-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbypassdlgradients.m
31 lines (30 loc) · 1.09 KB
/
bypassdlgradients.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
% BYPASSDLGRADIENTS Bypass gradient for non-differentiable operations
% Y = bypassdlgradients(FUN, X) evaluates FUN(X)
% while overriding the derivative calculation used during backward
% propagation to an identity function instead.
%
% Examples:
% a = dlarray([1.0 2.5]); % point at which to evaluate gradient
%
% % non-differentiable gradient calculation
% function [y,grad] = objectiveAndGradient(x)
% y = round(x(1) + x(2));
% grad = dlgradient(y,x);
% end
% [val,grad] = dlfeval(@objectiveAndGradient,a);
% % val is dlarray(4)
% % grad is dlarray([0 0]) because the derivative of 'round' is zero
%
% % non-differentiable gradient calculation with a straight-through
% % estimator for the 'round' function
% function [y,grad] = steObjectiveAndGradient(x)
% y = bypassdlgradients(@round, x(1) + x(2));
% grad = dlgradient(y,x);
% end
% [val,grad] = dlfeval(@steObjectiveAndGradient,a);
% % val is dlarray(4)
% % grad is dlarray([1 1])
%
% See also: DLARRAY, DLACCELERATE, EXTRACTDATA
%
% Copyright 2023 The MathWorks, Inc.