Skip to content

Commit

Permalink
Convert transposeLinear to transpose_linear
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasschmitz authored and apaszke committed May 10, 2022
1 parent a41e6b4 commit c5e2401
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion examples/manifold-gradients.dx
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def manifoldGrad
[TangentSpace manifoldA tangentA, VSpace tangentA]
(f: manifoldA -> Float) (x: manifoldA) : tangentA =
linearized: (Float & (tangentA --o Float)) = manifoldLinearize f x
transposeLinear (snd linearized) 1.0
transpose_linear (snd linearized) 1.0


' ### Equivalence to standard differentiation
Expand Down
4 changes: 2 additions & 2 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -974,11 +974,11 @@ def cumsum {n} (xs: n=>Float) : n=>Float =
-- TODO: add vector space constraints
def linearize {a b} (f:a->b) (x:a) : (b & a --o b) = %linearize f x
def jvp {a b} (f:a->b) (x:a) : a --o b = snd (linearize f x)
def transposeLinear {a b} (f:a --o b) : b --o a = %linearTranspose f
def transpose_linear {a b} (f:a --o b) : b --o a = %linearTranspose f

def vjp {a b} (f:a->b) (x:a) : (b & b --o a) =
(y, df) = linearize f x
(y, transposeLinear df)
(y, transpose_linear df)

def grad {a} (f:a->Float) (x:a) : a = snd (vjp f x) 1.0

Expand Down
6 changes: 3 additions & 3 deletions python/dex/interop/jax/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def tuple_string(prefix, index_set):
# parameter at index 1, the evaluated string should look like:
# ```
# \ x0 x1 x2 u1 ct.
# transposeLinear (\(t0, t2). linearized x0 x1 x2 t0 u1 t2) ct
# transpose_linear (\(t0, t2). linearized x0 x1 x2 t0 u1 t2) ct
# ```
# - The `x` variables are the (constant) inputs to the primal function. These
# should always be supplied by JAX.
Expand All @@ -373,9 +373,9 @@ def tuple_string(prefix, index_set):
arg_string("x", range(num_primals)) + " " + linearized_tangent_inputs)

# \ x0 x1 x2 u1 ct.
# transposeLinear (\(t0, t2). linearized x0 x1 x2 t0 u1 t2) ct
# transpose_linear (\(t0, t2). linearized x0 x1 x2 t0 u1 t2) ct
transposed = module.eval(
f"\\ {transposed_atom_params}. transposeLinear " +
f"\\ {transposed_atom_params}. transpose_linear " +
f"(\ {linear_lambda_params}. {linearized_name} {linearized_inputs}) ct"
)

Expand Down
22 changes: 11 additions & 11 deletions tests/ad-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,27 @@ def sum' {n} (xs:n=>Float) : Float = yieldAccum (AddMonoid Float) \ref. for i. r

:p
f : Float --o Float = \x. x
transposeLinear f 2.0
transpose_linear f 2.0
> 2.

:p
f : Float --o Float = \x. x + x
transposeLinear f 1.0
transpose_linear f 1.0
> 2.

:p
f : Float --o Float = \x. x + (x + x) * 2.0
transposeLinear f 1.0
transpose_linear f 1.0
> 5.

:p
f : Float --o Float = \x. x * 2.0
transposeLinear f 1.0
transpose_linear f 1.0
> 2.

:p
f : Float --o Float = \x. 2.0 * x
transposeLinear f 1.0
transpose_linear f 1.0
> 2.

:p grad (\x. x * x) 1.0
Expand Down Expand Up @@ -85,12 +85,12 @@ f : Float -> Float = \x. yieldAccum (AddMonoid Float) \ref. ref += x

:p
f : Float --o (Float & Float) = \x. (x, 2.0 * x)
transposeLinear f (1.0, 3.0)
transpose_linear f (1.0, 3.0)
> 7.

:p
f : (Float & Float) --o Float = \(x,y). x + 2.0 * y
transposeLinear f 1.0
transpose_linear f 1.0
> (1., 2.)

:p deriv cos 0.0
Expand Down Expand Up @@ -148,22 +148,22 @@ tripleit : Float --o Float = \x. x + x + x
:p tripleit 1.0
> 3.

:p transposeLinear tripleit 1.0
:p transpose_linear tripleit 1.0
> 3.

:p transposeLinear (transposeLinear tripleit) 1.0
:p transpose_linear (transpose_linear tripleit) 1.0
> 3.

:p
f : n:Type ?-> Ix n ?=> Float --o n=>Float = \x. for i. x

transposeLinear f [1.0, 2.0]
transpose_linear f [1.0, 2.0]
> 3.

:p
f : n:Type ?-> n=>Float --o n=>Float = \x. for i. x.i * 2.0

transposeLinear f [1.0, 2.0]
transpose_linear f [1.0, 2.0]
> [2., 4.]

myOtherSquare : Float -> Float =
Expand Down
6 changes: 3 additions & 3 deletions tests/uexpr-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -227,19 +227,19 @@ def myOtherFst {a b} ((x, _):(a&b)) : a = x
:p
f : Float --o Float =
\x. 2.0 * (x + x)
transposeLinear f 1.0
transpose_linear f 1.0
> 4.

-- FIXME: This fails due to shadowing!
--def transpose' (x:n=>m=>Float) --o : m=>n=>Float = for i j. x.j.i
--
--:p transposeLinear transpose' [[1.0, 2.0, 3.0]]
--:p transpose_linear transpose' [[1.0, 2.0, 3.0]]
--> [[1.0], [2.0], [3.0]]

:p
f : Float --o (Fin 3=>Float) =
\x. for i. x * 2.0
transposeLinear f [1.0, 2.0, 3.0]
transpose_linear f [1.0, 2.0, 3.0]
> 12.

id'' : b:Type ?-> b -> b = id
Expand Down

0 comments on commit c5e2401

Please sign in to comment.