-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: use the correct dispatches for device overloads (#1118)
* fix: use the correct dispatches for device overloads * fix: correct dispatches for adapt_structure * refactor: simplify get_device(_type) code * fix: partial revert of the ancestor patch * fix: restore unrolled_mapreduce for julia 1.10
- Loading branch information
Showing
13 changed files
with
72 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,15 @@ | ||
name = "LuxCore" | ||
uuid = "bb33d45b-7691-41d6-9220-0943567d0623" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "1.2.0" | ||
version = "1.2.1" | ||
|
||
[deps] | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
|
||
[weakdeps] | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" | ||
|
@@ -25,11 +26,12 @@ LuxCoreArrayInterfaceTrackerExt = ["ArrayInterface", "Tracker"] | |
LuxCoreChainRulesCoreExt = "ChainRulesCore" | ||
LuxCoreEnzymeCoreExt = "EnzymeCore" | ||
LuxCoreFunctorsExt = "Functors" | ||
LuxCoreMLDataDevicesExt = "MLDataDevices" | ||
LuxCoreMLDataDevicesExt = ["Adapt", "MLDataDevices"] | ||
LuxCoreReactantExt = "Reactant" | ||
LuxCoreSetfieldExt = "Setfield" | ||
|
||
[compat] | ||
Adapt = "4.1" | ||
ArrayInterface = "7.9" | ||
ChainRulesCore = "1.24" | ||
Compat = "4.16" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,16 @@ | ||
module LuxCoreMLDataDevicesExt | ||
|
||
using LuxCore: LuxCore | ||
using MLDataDevices: MLDataDevices | ||
using Adapt: Adapt | ||
using LuxCore: LuxCore, AbstractLuxLayer | ||
using MLDataDevices: MLDataDevices, AbstractDevice | ||
|
||
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) | ||
ldev = Symbol(dev, :Device) | ||
@eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractLuxLayer) | ||
@warn "Lux layers are stateless and hence don't participate in device transfers. \ | ||
Apply this function on the parameters and states generated using \ | ||
`LuxCore.setup`." | ||
return NN | ||
end | ||
MLDataDevices.isleaf(::AbstractLuxLayer) = true | ||
|
||
function Adapt.adapt_storage(::AbstractDevice, x::AbstractLuxLayer) | ||
@warn "Lux layers are stateless and hence don't participate in device transfers. \ | ||
Apply this function on the parameters and states generated using \ | ||
`LuxCore.setup`." | ||
return x | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "MLDataDevices" | ||
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "1.6.2" | ||
version = "1.6.3" | ||
|
||
[deps] | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,13 @@ | ||
module MLDataDevicesChainRulesExt | ||
|
||
using Adapt: Adapt | ||
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, | ||
ReactantDevice | ||
using MLDataDevices: CPUDevice, AbstractDevice, Internal | ||
using ChainRules: OneElement | ||
|
||
Adapt.adapt_structure(::CPUDevice, x::OneElement) = x | ||
Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x)) | ||
|
||
for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, | ||
CUDADevice{Nothing}, AMDGPUDevice{Nothing}) | ||
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) | ||
@eval Adapt.adapt_structure(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) | ||
end | ||
Internal.get_device(::OneElement) = CPUDevice() | ||
Internal.get_device_type(::OneElement) = CPUDevice | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,13 @@ | ||
module MLDataDevicesFillArraysExt | ||
|
||
using Adapt: Adapt | ||
using FillArrays: FillArrays, AbstractFill | ||
using MLDataDevices: MLDataDevices, CPUDevice, AbstractDevice | ||
using FillArrays: AbstractFill | ||
using MLDataDevices: CPUDevice, AbstractDevice, Internal | ||
|
||
Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x | ||
Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) | ||
|
||
Internal.get_device(::AbstractFill) = CPUDevice() | ||
Internal.get_device_type(::AbstractFill) = CPUDevice | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,13 @@ | ||
module MLDataDevicesZygoteExt | ||
|
||
using Adapt: Adapt | ||
using MLDataDevices: CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, | ||
ReactantDevice | ||
using MLDataDevices: CPUDevice, AbstractDevice, Internal | ||
using Zygote: OneElement | ||
|
||
Adapt.adapt_structure(::CPUDevice, x::OneElement) = x | ||
Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x)) | ||
|
||
for dev in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice, | ||
CUDADevice{Nothing}, AMDGPUDevice{Nothing}) | ||
# use `@eval` to avoid ambiguity with adapt_storage(::CUDADevice, ::AbstractArray) | ||
@eval Adapt.adapt_structure(to::$(dev), x::OneElement) = Adapt.adapt(to, collect(x)) | ||
end | ||
Internal.get_device(::OneElement) = CPUDevice() | ||
Internal.get_device_type(::OneElement) = CPUDevice | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters