I don't understand what pullback is producing when using sparse arrays on the GPU #1313

Open
salbert83 opened this issue Sep 24, 2022 · 1 comment

salbert83 commented Sep 24, 2022

There seems to be an issue using pullback when the calculation involves sparse arrays on the GPU. Consider the example below, which first runs on the CPU (where everything works) and then on the GPU:

using CUDA
using Flux
using LinearAlgebra
using SparseArrays
using Zygote

M = sprand(10, 10, 0.2)
ps = Flux.params(M)
@show typeof(ps[1]) # this is a SparseMatrixCSC{Float64, Int64}
x = randn(10)
loss() = sum((M*x).^2)
y, back = Zygote.pullback(loss, ps)
gs = back(one(y))
grad = collect(gs)[1]
@show typeof(grad) # this is also a SparseMatrixCSC{Float64, Int64}, which matches ps[1] (makes sense)
Flux.update!(Adam(), ps, gs) # ... and the update! works


M_gpu = Flux.gpu(M)
ps_gpu = Flux.params(M_gpu)
@show typeof(ps_gpu[1]) # this is a CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}
x_gpu = Flux.gpu(x)
loss_gpu() = sum((M_gpu*x_gpu).^2)
y_gpu, back_gpu = Zygote.pullback(loss_gpu, ps_gpu)
gs_gpu = back_gpu(one(y_gpu))
grad_gpu = collect(gs_gpu)[1]
@show typeof(grad_gpu) # this is a CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, not sparse!
# The result is a dense matrix and does not numerically match the CPU result (even ignoring element types)
Flux.update!(Adam(), ps_gpu, gs_gpu) # ... and the update! throws an exception

[edited to add code block]

@mcabbott
Member

What I think is happening is that, on the CPU, gradient projection via ProjectTo(x::SparseMatrixCSC{T}) is re-creating the sparse array. Since there is no corresponding projection for CuSparseMatrixCSC, the result stays dense.

Code is here, and as noted there (and in JuliaDiff/ChainRulesCore.jl#571) it's really a sketch of what's desired.
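
To make the projection step concrete, here is a minimal CPU-only sketch of what the explanation above describes. It assumes that the unprojected gradient of `sum((M*x).^2)` with respect to `M` is the dense outer product `2 * (M*x) * x'`, and that projection then keeps only the entries in `M`'s stored sparsity pattern. The variable names `dense_grad` and `sparse_grad` are illustrative, and this is not a reproduction of Zygote's internal code path:

```julia
using ChainRulesCore
using LinearAlgebra
using SparseArrays

M = sprand(10, 10, 0.2)
x = randn(10)

# Gradient of sum((M*x).^2) w.r.t. M if sparsity is ignored: a dense outer product
dense_grad = 2 .* (M * x) * x'

# ProjectTo records the structure of M (including its sparsity pattern) ...
project = ProjectTo(M)

# ... and applying it to a dense gradient keeps only the structurally stored entries
sparse_grad = project(dense_grad)

@show typeof(dense_grad)
@show typeof(sparse_grad)
@show nnz(sparse_grad) == nnz(M)
```

If no comparable `ProjectTo` method is defined for `CuSparseMatrixCSC`, this last step would presumably be skipped on the GPU, which would be consistent with `grad_gpu` coming back as a dense `CuArray`.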
