Skip to content

Commit

Permalink
feat: make MLUtils into a weakdep & suppport MLDataDevices
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 21, 2024
1 parent 904cac0 commit 75e308c
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 3 deletions.
12 changes: 10 additions & 2 deletions lib/OptimizationOptimisers/Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
name = "OptimizationOptimisers"
uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.3.0"
version = "0.3.1"

[deps]
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

[extensions]
OptimizationOptimisersMLDataDevicesExt = "MLDataDevices"
OptimizationOptimisersMLUtilsExt = "MLUtils"

[weakdeps]
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"

[compat]
MLDataDevices = "1.1"
MLUtils = "0.4.4"
Optimisers = "0.2, 0.3"
Optimization = "4"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module OptimizationOptimisersMLDataDevicesExt

using MLDataDevices
using OptimizationOptimisers

OptimizationOptimisers.isa_dataiterator(::DeviceIterator) = true

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module OptimizationOptimisersMLUtilsExt

using MLUtils
using OptimizationOptimisers

OptimizationOptimisers.isa_dataiterator(::MLUtils.DataLoader) = true

end
5 changes: 4 additions & 1 deletion lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ function SciMLBase.__init(
kwargs...)
end

isa_dataiterator(data) = false

function SciMLBase.__solve(cache::OptimizationCache{
F,
RC,
Expand Down Expand Up @@ -57,13 +59,14 @@ function SciMLBase.__solve(cache::OptimizationCache{
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
end

if cache.p isa MLUtils.DataLoader
if isa_dataiterator(cache.p)
data = cache.p
dataiterate = true
else
data = [cache.p]
dataiterate = false
end

opt = cache.opt
θ = copy(cache.u0)
G = copy(θ)
Expand Down

0 comments on commit 75e308c

Please sign in to comment.