There seems to be an issue using `pullback` when the calculation involves sparse arrays on the GPU. The same calculation works as expected on the CPU. Consider the example below:
```julia
using CUDA
using Flux
using LinearAlgebra
using SparseArrays
using Zygote

# CPU case
M = sprand(10, 10, 0.2)
ps = Flux.params(M)
@show typeof(ps[1]) # this is a SparseMatrixCSC{Float64, Int64}
x = randn(10)
loss() = sum((M * x) .^ 2)
y, back = Zygote.pullback(loss, ps)
gs = back(one(y))
grad = collect(gs)[1]
@show typeof(grad) # this is also a SparseMatrixCSC{Float64, Int64}, which matches ps[1] (makes sense)
Flux.update!(Adam(), ps, gs) # ... and the update! works

# GPU case
M_gpu = Flux.gpu(M)
ps_gpu = Flux.params(M_gpu)
@show typeof(ps_gpu[1]) # this is a CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}
x_gpu = Flux.gpu(x)
loss_gpu() = sum((M_gpu * x_gpu) .^ 2)
y_gpu, back_gpu = Zygote.pullback(loss_gpu, ps_gpu)
gs_gpu = back_gpu(one(y_gpu))
grad_gpu = collect(gs_gpu)[1]
@show typeof(grad_gpu) # this is a CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, not sparse!
# The result is a full matrix and doesn't match numerically the result from the CPU (ignoring data types)
Flux.update!(Adam(), ps_gpu, gs_gpu) # ... and the update! throws an exception
```
[edited to add code block]
What I think is happening is that, on the CPU, gradient projection via ProjectTo(x::SparseMatrixCSC{T}) is re-creating the sparse array. Since there is no corresponding projection for CuSparseMatrixCSC, the result stays dense.
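To illustrate the CPU behaviour described above, here is a minimal sketch that applies the projection by hand, outside of Zygote. It assumes ChainRulesCore is installed and that its `ProjectTo` has a method for `SparseMatrixCSC`, as the comment above suggests. The names `dense_grad`, `proj`, and `sparse_grad` are only illustrative, and the dense gradient is written out analytically rather than taken from Zygote.

```julia
using SparseArrays
using ChainRulesCore

M = sprand(10, 10, 0.2)
x = randn(10)

# Analytic gradient of sum((M*x).^2) with respect to M, as a dense matrix:
# d/dM = 2 * (M*x) * x'
dense_grad = 2 .* (M * x) * x'

# Build the projector from the sparse primal and apply it to the dense gradient.
# Assumption: this keeps only the entries at M's stored positions.
proj = ProjectTo(M)
sparse_grad = proj(dense_grad)

@show typeof(dense_grad)
@show typeof(sparse_grad)
@show rowvals(sparse_grad) == rowvals(M)
```

If this reading is right, a comparable projection for `CuSparseMatrixCSC` would be what is needed to bring the GPU gradient back to a sparse type.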