Skip to content

Commit

Permalink
Fix accumulate_param_gradients! for Map and Unfold.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Oct 5, 2024
1 parent 614a10a commit 07ae12e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/gen_fn_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,9 @@ If an argument is not annotated with `(grad)`, the corresponding value in
Also increment the gradient accumulators for the trainable parameters \$Θ\$ of
the function by:
```math
∇_Θ \\left( \\log P(t; x) + J \\right)
s * ∇_Θ \\left( \\log P(t; x) + J \\right)
```
where \$s\$ is `scale_factor`.
"""
function accumulate_param_gradients!(trace, retgrad, scale_factor)
error("Not implemented")
Expand Down
4 changes: 2 additions & 2 deletions src/modeling_library/map/backprop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function choice_gradients(trace::VectorTrace{MapType,T,U}, selection::Selection,
((arg_grad...,), value_choices, gradient_choices)
end

function accumulate_param_gradients!(trace::VectorTrace{MapType,T,U}, retval_grad) where {T,U}
function accumulate_param_gradients!(trace::VectorTrace{MapType,T,U}, retval_grad, scale_factor) where {T,U}

args = get_args(trace)
n_args = length(args)
Expand All @@ -54,7 +54,7 @@ function accumulate_param_gradients!(trace::VectorTrace{MapType,T,U}, retval_gra
for key=1:len
subtrace = trace.subtraces[key]
kernel_retval_grad = (retval_grad == nothing) ? nothing : retval_grad[key]
kernel_arg_grad::Tuple = accumulate_param_gradients!(subtrace, kernel_retval_grad)
kernel_arg_grad::Tuple = accumulate_param_gradients!(subtrace, kernel_retval_grad, scale_factor)
for (i, grad, has_grad) in zip(1:n_args, kernel_arg_grad, has_grads)
if has_grad
arg_grad[i][key] = grad
Expand Down
4 changes: 2 additions & 2 deletions src/modeling_library/unfold/backprop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selecti
((nothing, kernel_arg_grads[2], map(_sum, params_grad)...), value_choices, gradient_choices)
end

function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_grad) where {T,U}
function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_grad, scale_factor) where {T,U}
kernel_has_grads = has_argument_grads(trace.gen_fn.kernel)
if kernel_has_grads[1]
error("Cannot differentiate with respect to index in unfold")
Expand All @@ -76,7 +76,7 @@ function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_
if state_has_grad
kernel_retval_grad = fold_sum(kernel_retval_grad, kernel_arg_grads[2])
end
kernel_arg_grads = accumulate_param_gradients!(subtrace, kernel_retval_grad)
kernel_arg_grads = accumulate_param_gradients!(subtrace, kernel_retval_grad, scale_factor)
@assert kernel_arg_grads[1] == nothing
state_has_grad || @assert kernel_arg_grads[2] == nothing
for (i, (grad, has_grad)) in enumerate(zip(kernel_arg_grads[3:end], params_has_grad))
Expand Down

0 comments on commit 07ae12e

Please sign in to comment.