From f343c88d34ab4fc7cfbaec4d0d68089b07fd1010 Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Thu, 18 Apr 2024 12:29:59 -0600 Subject: [PATCH 1/6] merge forward and reverse mode --- src/emmy/abstract/function.cljc | 26 ++- src/emmy/calculus/derivative.cljc | 12 ++ src/emmy/differential.cljc | 284 +----------------------------- src/emmy/tape.cljc | 194 +++++++++++++------- test/emmy/differential_test.cljc | 3 +- test/emmy/tape_test.cljc | 42 +++-- 6 files changed, 195 insertions(+), 366 deletions(-) diff --git a/src/emmy/abstract/function.cljc b/src/emmy/abstract/function.cljc index 8289ba2f..72ff544a 100644 --- a/src/emmy/abstract/function.cljc +++ b/src/emmy/abstract/function.cljc @@ -19,6 +19,7 @@ [emmy.numsymb :as sym] [emmy.polynomial] [emmy.structure :as s] + [emmy.tape :as tape] [emmy.util :as u] [emmy.value :as v]) #?(:clj @@ -262,14 +263,25 @@ (g/+ tangent (g/* (literal-apply partial primal-s) dx)))))))) +(defn- reverse-mode-fold [f primal-s tag] + (fn + ([] []) + ([partials] + (tape/make tag (apply f primal-s) partials)) + ([partials [entry path _]] + (if (and (tape/tape? entry) (= tag (tape/tape-tag entry))) + (let [partial (literal-partial f path)] + (conj partials [entry (literal-apply partial primal-s)])) + partials)))) + (defn- literal-derivative - "Takes a literal function `f` and a sequence of arguments `xs`, and generates - an expanded `((D f) xs)` by applying the chain rule and summing the partial - derivatives for each [[emmy.differential/Dual]] argument in the input - structure." + "Takes a literal function `f` and a sequence of arguments `xs`, and generates an + expanded `((D f) xs)` by applying the chain rule and summing the partial + derivatives for each perturbed argument in the input structure." [f s tag dx] - (let [fold-fn (cond (d/dual? dx) forward-mode-fold - :else (u/illegal "No tape or differential inputs.")) + (let [fold-fn (cond (tape/tape? dx) reverse-mode-fold + (d/dual? dx) forward-mode-fold + :else (u/illegal "No tape or differential inputs.")) primal-s (s/mapr (fn [x] (d/primal x tag)) s)] (s/fold-chain (fold-fn f primal-s tag) s))) @@ -305,7 +317,7 @@ (if-let [[tag dx] (s/fold-chain (fn ([] []) - ([acc] (apply d/tag+perturbation acc)) + ([acc] (apply tape/tag+perturbation acc)) ([acc [d]] (conj acc d))) s)] (literal-derivative f s tag dx) diff --git a/src/emmy/calculus/derivative.cljc b/src/emmy/calculus/derivative.cljc index 10bbdb61..9f4418ab 100644 --- a/src/emmy/calculus/derivative.cljc +++ b/src/emmy/calculus/derivative.cljc @@ -15,6 +15,7 @@ [emmy.operator :as o] [emmy.series :as series] [emmy.structure :as s] + [emmy.tape :as tape] [emmy.util :as u] [emmy.value :as v]) #?(:clj @@ -458,6 +459,17 @@ (o/make-operator #(g/partial-derivative % []) g/derivative-symbol)) +(def D-rev + "Reverse-mode derivative operator..." + (o/make-operator #(tape/gradient % []) + g/derivative-symbol)) + +(defn partial-rev + "Reverse-mode partial derivative." + [& selectors] + (o/make-operator #(tape/gradient % selectors) + `(~'partial ~@selectors))) + (defn D-as-matrix [F] (fn [s] (matrix/s->m diff --git a/src/emmy/differential.cljc b/src/emmy/differential.cljc index 5c07e136..e16542a7 100644 --- a/src/emmy/differential.cljc +++ b/src/emmy/differential.cljc @@ -12,7 +12,6 @@ (:refer-clojure :exclude [compare]) (:require [emmy.function] ;; for the side effect of making kind: MultiFn -> ::v/function [emmy.generic :as g] - [emmy.util :as u] [emmy.value :as v])) ;; ## Differentials, Dual Numbers and Automatic Differentiation @@ -557,36 +556,6 @@ (boolean (some #{tag} *active-tags*))) -(defn inner-tag - "Given any number of `tags`, returns the tag most recently bound - via [[with-active-tag]] (i.e., the tag connected with the _innermost_ call - to [[with-active-tag]]). - - If none of the tags are bound, returns `(apply max tags)`." - [& tags] - (or (some (apply hash-set tags) - *active-tags*) - (apply max tags))) - -(defn tag+perturbation - "Given any number of [[Dual]] instances `dxs`, returns a pair of the form - - [ ] - - containing the tag and instance of [[Dual]] associated with the inner-most - call to [[with-active-tag]] in the current call stack. - - If none of `dxs` has an active tag, returns `nil`." - ([& dxs] - (let [m (into {} (mapcat - (fn [dx] - (when-let [t (tag dx)] - {t dx}))) - dxs)] - (when (seq m) - (let [tag (apply inner-tag (keys m))] - [tag (m tag)]))))) - ;; ## Comparison, Control Flow ;; ;; Functions like `=`, `<` and friends don't have derivatives; instead, they're @@ -719,260 +688,9 @@ ;; Magically this will all Just Work if you pass an already-lifted function, or ;; a function built out of already-lifted components, as `df:dx` or `df:dy`. -(defn lift-1 - "Given: - - - some unary function `f` - - a function `df:dx` that computes the derivative of `f` with respect to its - single argument - - Returns a new unary function that operates on both the original type of `f` - and [[Dual]] instances. - - If called without `df:dx`, `df:dx` defaults to `(f :dfdx)`; this will return - the derivative registered to a generic function defined - with [[emmy.util.def/defgeneric]]. - - NOTE: `df:dx` has to ALREADY be able to handle [[Dual]] instances. The best - way to accomplish this is by building `df:dx` out of already-lifted functions, - and declaring them by forward reference if you need to." - ([f] - (if-let [df:dx (f :dfdx)] - (lift-1 f df:dx) - (u/illegal "No df:dx supplied for `f` or registered generically."))) - ([f df:dx] - (fn call [x] - (if-not (dual? x) - (f x) - (let [[px tx] (primal-tangent-pair x) - primal (call px) - tangent (g/* (df:dx px) tx)] - (bundle-element primal tangent (tag x))))))) - -(defn lift-2 - "Given: - - - some binary function `f` - - a function `df:dx` that computes the derivative of `f` with respect to its - single argument - - a function `df:dy`, similar to `df:dx` for the second arg - - Returns a new binary function that operates on both the original type of `f` - and [[Dual]] instances. - - NOTE: `df:dx` and `df:dy` have to ALREADY be able to handle [[Dual]] - instances. The best way to accomplish this is by building `df:dx` and `df:dy` - out of already-lifted functions, and declaring them by forward reference if - you need to." - ([f] - (let [df:dx (f :dfdx) - df:dy (f :dfdy)] - (if (and df:dx df:dy) - (lift-2 f df:dx df:dy) - (u/illegal "No df:dx, df:dy supplied for `f` or registered generically.")))) - ([f df:dx df:dy] - (fn call [x y] - (if-let [[tag _] (tag+perturbation x y)] - (let [[xe dx] (primal-tangent-pair x tag) - [ye dy] (primal-tangent-pair y tag) - primal (call xe ye) - tangent (g/+ (if (g/numeric-zero? dx) - dx - (g/* (df:dx xe ye) dx)) - (if (g/numeric-zero? dy) - dy - (g/* (df:dy xe ye) dy)))] - (bundle-element primal tangent tag)) - (f x y))))) - -(defn lift-n - "Given: - - - some function `f` that can handle 0, 1 or 2 arguments - - `df:dx`, a fn that returns the derivative wrt the single arg in the unary case - - `df:dx1` and `df:dx2`, fns that return the derivative with respect to the - first and second args in the binary case - - Returns a new any-arity function that operates on both the original type of - `f` and [[Dual]] instances. - - NOTE: The n-ary case of `f` is populated by nested calls to the binary case. - That means that this is NOT an appropriate lifting method for an n-ary - function that isn't built out of associative binary calls. If you need this - ability, please file an issue at the [emmy issue - tracker](https://github.com/mentat-collective/emmy/issues)." - [f df:dx df:dx1 df:dx2] - (let [f1 (lift-1 f df:dx) - f2 (lift-2 f df:dx1 df:dx2)] - (fn call - ([] (f)) - ([x] (f1 x)) - ([x y] (f2 x y)) - ([x y & more] - (reduce call (call x y) more))))) +;; TODO port docs above... ;; ## Generic Method Installation -;; -;; Armed with [[lift-1]] and [[lift-2]], we can install [[Dual]] into -;; the Emmy generic arithmetic system. -;; -;; Any function built out of these components will work with -;; the [[emmy.calculus.derivative/D]] operator. - -(defn- defunary - "Given: - - - a generic unary multimethod `generic-op` - - optionally, a corresponding single-arity lifted function - `differential-op` (defaults to `(lift-1 generic-op)`) - - installs an appropriate unary implementation of `generic-op` for `::dual` - instances." - ([generic-op] - (defunary generic-op (lift-1 generic-op))) - ([generic-op differential-op] - (defmethod generic-op [::dual] [a] (differential-op a)))) - -(defn- defbinary - "Given: - - - a generic binary multimethod `generic-op` - - optionally, a corresponding 2-arity lifted function - `differential-op` (defaults to `(lift-2 generic-op)`) - - installs an appropriate binary implementation of `generic-op` between `::dual` - and `::v/scalar` instances." - ([generic-op] - (defbinary generic-op (lift-2 generic-op))) - ([generic-op differential-op] - (doseq [signature [[::dual ::dual] - [::v/scalar ::dual] - [::dual ::v/scalar]]] - (defmethod generic-op signature [a b] (differential-op a b))))) - -(defn ^:no-doc by-primal - "Given some unary or binary function `f`, returns an augmented `f` that acts on - the primal entries of any [[Dual]] arguments encountered, irrespective of tag. - - Given a [[Dual]] with a [[Dual]] in its [[primal]] part, the returned `f` will - recursively descend until it hits a non-[[Dual]]." - [f] - (fn - ([x] (f (deep-primal x))) - ([x y] (f (deep-primal x) - (deep-primal y))))) - -;; And now we're off to the races. The rest of the namespace -;; provides [[defunary]] and [[defbinary]] calls for all of the generic -;; operations for which we know how to declare partial derivatives. - -;; First, install `equiv` as to perform proper equality between `Dual` -;; instances and scalars. `equiv` compares on only the finite part, not the -;; differential parts. - -(defbinary g/add) -(defunary g/negate) -(defbinary g/sub) - -(let [mul (lift-2 g/mul)] - (defbinary g/mul mul) - (defbinary g/dot-product mul)) -(defbinary g/expt) - -(defunary g/square) -(defunary g/cube) - -(defunary g/invert) -(defbinary g/div) - -(defunary g/abs - (fn [x] - (let [f (deep-primal x) - func (cond (< f 0) (lift-1 g/negate (fn [_] -1)) - (> f 0) (lift-1 identity (fn [_] 1)) - (= f 0) (u/illegal "Derivative of g/abs undefined at zero") - :else (u/illegal (str "error! derivative of g/abs at" x)))] - (func x)))) - -(defn- discont-at-integers [f dfdx] - (let [f (lift-1 f (fn [_] dfdx)) - f-name (g/freeze f)] - (fn [x] - (if (v/integral? (deep-primal x)) - (u/illegal - (str "Derivative of emmy.generic/" - f-name " undefined at integral points.")) - (f x))))) - -(defunary g/floor - (discont-at-integers g/floor 0)) - -(defunary g/ceiling - (discont-at-integers g/ceiling 0)) - -(defunary g/integer-part - (discont-at-integers g/integer-part 0)) - -(defunary g/fractional-part - (discont-at-integers g/fractional-part 1)) - -(let [div (lift-2 g/div)] - (defbinary g/solve-linear (fn [l r] (div r l))) - (defbinary g/solve-linear-right div)) - -(defunary g/sqrt) -(defunary g/log) -(defunary g/exp) - -(defunary g/cos) -(defunary g/sin) -(defunary g/tan) -(defunary g/cot) -(defunary g/sec) -(defunary g/csc) - -(defunary g/atan) -(defbinary g/atan) -(defunary g/asin) -(defunary g/acos) -(defunary g/acot) -(defunary g/asec) -(defunary g/acsc) - -(defunary g/cosh) -(defunary g/sinh) -(defunary g/tanh) -(defunary g/sech) -(defunary g/coth) -(defunary g/csch) - -(defunary g/acosh) -(defunary g/asinh) -(defunary g/atanh) -(defunary g/acoth) -(defunary g/asech) -(defunary g/acsch) - -(defunary g/sinc) -(defunary g/sinhc) -(defunary g/tanc) -(defunary g/tanhc) - -;; Non-differentiable generic operations - -(defbinary v/= (by-primal v/=)) -(defunary g/negative? (by-primal g/negative?)) -(defunary g/infinite? (by-primal g/infinite?)) - - -(defunary g/zero? - (fn [dx] - (let [[p t] (primal-tangent-pair dx)] - (and (g/zero? p) - (g/zero? t))))) - -(defunary g/one? one?) -(defunary g/identity? identity?) (defmethod g/zero-like [::dual] [_] 0) (defmethod g/one-like [::dual] [_] 1) diff --git a/src/emmy/tape.cljc b/src/emmy/tape.cljc index b902a029..210b3cb3 100644 --- a/src/emmy/tape.cljc +++ b/src/emmy/tape.cljc @@ -254,17 +254,6 @@ [^TapeCell tape] (.-tag tape)) -(defn tag-of - "More permissive version of [[tape-tag]] that returns `nil` when passed a - non-[[TapeCell]] instance. - - TODO this will need to be extended to - handle [[emmy.differential/Differential]] instances when these namespaces - merge." - [x] - (cond (tape? x) (tape-tag x) - :else nil)) - (defn tape-id "Returns the `-id` field of the supplied [[TapeCell]] object. Errors if any other type is supplied. @@ -317,15 +306,66 @@ :primal (.-primal t) :in->partial (.-in->partial t)}) +(defn tag-of + "More permissive version of [[tape-tag]] that returns `nil` when passed a + non-[[TapeCell]] instance. + + TODO note what we handle now." + [x] + (cond (tape? x) (tape-tag x) + (d/dual? x) (d/tag x) + :else nil)) + +;; TODO move tag stuff here? + +(defn inner-tag + "Given any number of `tags`, returns the tag most recently bound + via [[with-active-tag]] (i.e., the tag connected with the _innermost_ call + to [[with-active-tag]]). + + If none of the tags are bound, returns `(apply max tags)`." + [& tags] + (or (some (apply hash-set tags) + d/*active-tags*) + (apply max tags))) + +(defn tag+perturbation + "A COPY of the same function in `differential`. I'm adding this here to avoid + import nonsense, and I'll delete one of the copies on the next PR, when I add + support for mixing forward and reverse modes together." + ([& dxs] + (let [m (into {} (mapcat + (fn [dx] + (when-let [t (tag-of dx)] + {t dx}))) + dxs)] + (when (seq m) + (let [tag (apply inner-tag (keys m))] + [tag (m tag)]))))) + +(defn primal-of + "More permissive version of [[tape-primal]] that returns `v` when passed a + non-[[TapeCell]]-or-[[emmy.differential/Dual]] instance. + + TODO fix docstring" + ([v] + (primal-of v (tag-of v))) + ([v tag] + (cond (tape? v) (tape-primal v tag) + (d/dual? v) (d/primal v tag) + :else v))) + (defn deep-primal "Version of [[tape-primal]] that will descend recursively into any [[TapeCell]] instance returned by [[tape-primal]] until encountering a non-[[TapeCell]]. - Given a non-[[TapeCell]], acts as identity." - [v] - (if (tape? v) - (recur (tape-primal v)) - v)) + Given a non-[[TapeCell]], acts as identity. + + TODO say what we really do now" + ([v] + (cond (tape? v) (recur (tape-primal v)) + (d/dual? v) (recur (d/primal v)) + :else v))) ;; ### Comparison, Control Flow ;; @@ -699,16 +739,17 @@ - a function `df:dx` that computes the derivative of `f` with respect to its single argument - Returns a new unary function that operates on both the original type of `f` - and [[TapeCell]] instances. + Returns a new unary function that operates on both the original type of + `f`, [[TapeCell]] and [[emmy.differential/Dual]] instances. If called without `df:dx`, `df:dx` defaults to `(f :dfdx)`; this will return the derivative registered to a generic function defined with [[emmy.util.def/defgeneric]]. - NOTE: `df:dx` has to ALREADY be able to handle [[TapeCell]] instances. The - best way to accomplish this is by building `df:dx` out of already-lifted - functions, and declaring them by forward reference if you need to." + NOTE: `df:dx` has to ALREADY be able to handle [[TapeCell]] + and [[emmy.differential/Dual]] instances. The best way to accomplish this is + by building `df:dx` out of already-lifted functions, and declaring them by + forward reference if you need to." ([f] (if-let [df:dx (f :dfdx)] (lift-1 f df:dx) @@ -716,26 +757,19 @@ "No df:dx supplied for `f` or registered generically."))) ([f df:dx] (fn call [x] - (if (tape? x) - (let [primal (tape-primal x)] - (make (tape-tag x) - (call primal) - [[x (df:dx primal)]])) - (f x))))) - -(defn- tag+perturbation - "A COPY of the same function in `differential`. I'm adding this here to avoid - import nonsense, and I'll delete one of the copies on the next PR, when I add - support for mixing forward and reverse modes together." - ([& dxs] - (let [m (into {} (mapcat - (fn [dx] - (when-let [t (tag-of dx)] - {t dx}))) - dxs)] - (when (seq m) - (let [tag (apply d/inner-tag (keys m))] - [tag (m tag)]))))) + (cond (tape? x) + (let [primal (tape-primal x)] + (make (tape-tag x) + (call primal) + [[x (df:dx primal)]])) + + (d/dual? x) + (let [[px tx] (d/primal-tangent-pair x) + primal (call px) + tangent (g/* (df:dx px) tx)] + (d/bundle-element primal tangent (d/tag x))) + + :else (f x))))) (defn lift-2 "Given: @@ -745,13 +779,13 @@ single argument - a function `df:dy`, similar to `df:dx` for the second arg - Returns a new binary function that operates on both the original type of `f` - and [[TapeCell]] instances. + Returns a new binary function that operates on both the original type of + `f`, [[TapeCell]] and [[emmy.differential/Differential]] instances. NOTE: `df:dx` and `df:dy` have to ALREADY be able to handle [[TapeCell]] - instances. The best way to accomplish this is by building `df:dx` and `df:dy` - out of already-lifted functions, and declaring them by forward reference if - you need to." + and [[emmy.differential/Dual]] instances. The best way to accomplish this is + by building `df:dx` and `df:dy` out of already-lifted functions, and declaring + them by forward reference if you need to." ([f] (let [df:dx (f :dfdx) df:dy (f :dfdy)] @@ -761,7 +795,19 @@ "No df:dx, df:dy supplied for `f` or registered generically.")))) ([f df:dx df:dy] (fn call [x y] - (letfn [(operate [tag] + (letfn [(operate-forward [tag] + (let [[xe dx] (d/primal-tangent-pair x tag) + [ye dy] (d/primal-tangent-pair y tag) + primal (call xe ye) + tangent (g/+ (if (g/numeric-zero? dx) + dx + (g/* (df:dx xe ye) dx)) + (if (g/numeric-zero? dy) + dy + (g/* (df:dy xe ye) dy)))] + (d/bundle-element primal tangent tag))) + + (operate-reverse [tag] (let [primal-x (tape-primal x tag) primal-y (tape-primal y tag) partial-x (if (and (tape? x) (= tag (tape-tag x))) @@ -770,13 +816,15 @@ partial-y (if (and (tape? y) (= tag (tape-tag y))) [[y (df:dy primal-x primal-y)]] [])] + (make tag (call primal-x primal-y) (into partial-x partial-y))))] (if-let [[tag dx] (tag+perturbation x y)] - (cond (tape? dx) (operate tag) + (cond (tape? dx) (operate-reverse tag) + (d/dual? dx) (operate-forward tag) :else - (u/illegal "Non-tape perturbation!")) + (u/illegal "Non-tape or differential perturbation!")) (f x y)))))) (defn lift-n @@ -788,7 +836,7 @@ first and second args in the binary case Returns a new any-arity function that operates on both the original type of - `f` and [[TapeCell]] instances. + `f`, [[TapeCell]] and [[emmy.differential/Dual]] instances. NOTE: The n-ary case of `f` is populated by nested calls to the binary case. That means that this is NOT an appropriate lifting method for an n-ary @@ -817,11 +865,12 @@ - optionally, a corresponding single-arity lifted function `differential-op` (defaults to `(lift-1 generic-op)`) - installs an appropriate unary implementation of `generic-op` for `::tape` - instances." + installs an appropriate unary implementation of `generic-op` for `::tape` and + `:emmy.differential/dual` instances." ([generic-op] (defunary generic-op (lift-1 generic-op))) ([generic-op differential-op] + (defmethod generic-op [::d/dual] [a] (differential-op a)) (defmethod generic-op [::tape] [a] (differential-op a)))) (defn- defbinary @@ -831,14 +880,19 @@ - optionally, a corresponding 2-arity lifted function `differential-op` (defaults to `(lift-2 generic-op)`) - installs an appropriate binary implementation of `generic-op` between `:tape` - and `::v/scalar` instances." + installs an appropriate binary implementation of `generic-op` between + `::tape`, `::emmy.differential/dual` and `::v/scalar` instances." ([generic-op] (defbinary generic-op (lift-2 generic-op))) ([generic-op differential-op] (doseq [signature [[::tape ::tape] + [::d/dual ::d/dual] + [::tape ::d/dual] + [::d/dual ::tape] [::v/scalar ::tape] - [::tape ::v/scalar]]] + [::v/scalar ::d/dual] + [::tape ::v/scalar] + [::d/dual ::v/scalar]]] (defmethod generic-op signature [a b] (differential-op a b))))) (defn ^:no-doc by-primal @@ -847,7 +901,9 @@ tag. Given a [[TapeCell]] with a [[TapeCell]] in its [[primal-part]], the returned - `f` will recursively descend until it hits a non-[[TapeCell]]." + `f` will recursively descend until it hits a non-[[TapeCell]]. + + TODO fix docs" [f] (fn ([x] (f (deep-primal x))) @@ -944,9 +1000,29 @@ ;; Non-differentiable generic operations (defbinary v/= (by-primal v/=)) -(defunary g/zero? (by-primal g/zero?)) -(defunary g/one? (by-primal g/one?)) -(defunary g/identity? (by-primal g/identity?)) +(defunary g/zero? + (let [zero-p? (by-primal g/zero?)] + (fn [dx] + (if (tape? dx) + (zero-p? dx) + (let [[p t] (d/primal-tangent-pair dx)] + (and (g/zero? p) + (g/zero? t))))))) + +(defunary g/one? + (let [one-p? (by-primal g/one?)] + (fn [dx] + (if (tape? dx) + (one-p? dx) + (d/one? dx))))) + +(defunary g/identity? + (let [identity-p? (by-primal g/identity?)] + (fn [dx] + (if (tape? dx) + (identity-p? dx) + (d/identity? dx))))) + (defunary g/negative? (by-primal g/negative?)) (defunary g/infinite? (by-primal g/infinite?)) diff --git a/test/emmy/differential_test.cljc b/test/emmy/differential_test.cljc index 58225ccb..2373b88f 100644 --- a/test/emmy/differential_test.cljc +++ b/test/emmy/differential_test.cljc @@ -10,6 +10,7 @@ [emmy.generic :as g] [emmy.numerical.derivative :refer [D-numeric]] [emmy.simplify :refer [hermetic-simplify-fixture]] + [emmy.tape :as tape] [emmy.value :as v] [same.core :refer [ish? with-comparator]])) @@ -357,7 +358,7 @@ (is (g/one? ((derivative g/fractional-part) x))))))) (testing "lift-n" - (let [* (d/lift-n g/* (fn [_] 1) (fn [_ y] y) (fn [x _] x)) + (let [* (tape/lift-n g/* (fn [_] 1) (fn [_ y] y) (fn [x _] x)) Df7 (derivative (fn x**7 [x] (* x x x x x x x))) Df1 (derivative *) diff --git a/test/emmy/tape_test.cljc b/test/emmy/tape_test.cljc index 1bade4e4..05070c83 100644 --- a/test/emmy/tape_test.cljc +++ b/test/emmy/tape_test.cljc @@ -576,20 +576,30 @@ (is (= expected (g/simplify ((t/gradient (t/gradient f)) 'a 'b 'c 'd 'e 'f))) - "multivariable derivatives match (reverse-over-reverse)") + "multivariable derivatives match (reverse-over-reverse)")))) - ;; TODO enable this when we add support for tape and gradient comms in - ;; lift-2. - #_ - (is (= expected - (g/simplify - ((D (t/gradient f)) 'a 'b 'c 'd 'e 'f))) - "forward-over-reverse") - - ;; TODO enable this when we add support for tape and gradient comms in - ;; lift-2. - #_ - (is (= expected - (g/simplify - ((t/gradient (D f)) 'a 'b 'c 'd 'e 'f))) - "reverse-over-forward")))) +(deftest mixed-mode-tests + (testing "nested reverse mode" + (let [f (fn [x] + (fn [y] + (g/* (g/square x) (g/square y))))] + (is (= ((D ((t/gradient f) 'x)) 'y) + ((t/gradient ((D f) 'x)) 'y) + ((t/gradient ((t/gradient f) 'x)) 'y)) + "reverse-mode nests with forward-mode"))) + + (let [f (fn [a b c d e f] + [(g/* (g/cos a) (g/cos b)) + (g/* (g/cos c) (g/cos d)) + (g/* (g/cos e) (g/cos f))]) + expected (g/simplify + ((D (D f)) 'a 'b 'c 'd 'e 'f))] + (is (= expected + (g/simplify + ((D (t/gradient f)) 'a 'b 'c 'd 'e 'f))) + "forward-over-reverse") + + (is (= expected + (g/simplify + ((t/gradient (D f)) 'a 'b 'c 'd 'e 'f))) + "reverse-over-forward"))) From 69ee50a126655ca724fc9a1f5d4ea15c7a266712 Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Thu, 18 Apr 2024 12:34:53 -0600 Subject: [PATCH 2/6] tape works --- src/emmy/abstract/function.cljc | 2 +- src/emmy/calculus/derivative.cljc | 18 +++++---- src/emmy/tape.cljc | 61 ++++++++++++++++++++++++++++--- 3 files changed, 68 insertions(+), 13 deletions(-) diff --git a/src/emmy/abstract/function.cljc b/src/emmy/abstract/function.cljc index 72ff544a..5ddc2a1c 100644 --- a/src/emmy/abstract/function.cljc +++ b/src/emmy/abstract/function.cljc @@ -282,7 +282,7 @@ (let [fold-fn (cond (tape/tape? dx) reverse-mode-fold (d/dual? dx) forward-mode-fold :else (u/illegal "No tape or differential inputs.")) - primal-s (s/mapr (fn [x] (d/primal x tag)) s)] + primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)] (s/fold-chain (fold-fn f primal-s tag) s))) (defn- check-argument-type diff --git a/src/emmy/calculus/derivative.cljc b/src/emmy/calculus/derivative.cljc index 9f4418ab..da5a7584 100644 --- a/src/emmy/calculus/derivative.cljc +++ b/src/emmy/calculus/derivative.cljc @@ -543,12 +543,16 @@ (letfn [(process-term [term] (g/simplify (s/mapr (fn rec [x] - (if (d/dual? x) - (d/bundle-element - (rec (d/primal x)) - (rec (d/tangent x)) - (d/tag x)) - (-> (g/simplify x) - (x/substitute replace-m)))) + (cond (d/dual? x) + (d/bundle-element + (rec (d/primal x)) + (rec (d/tangent x)) + (d/tag x)) + + (tape/tape? x) + (u/illegal "TODO implement this using fmap style.") + + :else (-> (g/simplify x) + (x/substitute replace-m)))) term)))] (series/fmap process-term series))))) diff --git a/src/emmy/tape.cljc b/src/emmy/tape.cljc index 210b3cb3..10f42c1b 100644 --- a/src/emmy/tape.cljc +++ b/src/emmy/tape.cljc @@ -13,6 +13,7 @@ [emmy.function :as f] [emmy.generic :as g] [emmy.matrix :as matrix] + [emmy.operator :as o] [emmy.structure :as s] [emmy.util :as u] [emmy.value :as v])) @@ -329,10 +330,21 @@ d/*active-tags*) (apply max tags))) +;; TODO we could change `perturbed?` into something like +;; `possible-perturbations`, to get collection types to return sequence of +;; inputs for this. Then we could handle map-shaped inputs etc into literal +;; functions, if we had the proper descriptor language for it. + (defn tag+perturbation - "A COPY of the same function in `differential`. I'm adding this here to avoid - import nonsense, and I'll delete one of the copies on the next PR, when I add - support for mixing forward and reverse modes together." + "Given any number of `dxs`, returns a pair of the form + + [ ] + + containing the tag and instance of [[emmy.differential/Dual]] or [[TapeCell]] + associated with the inner-most call to [[with-active-tag]] in the current call + stack. + + If none of `dxs` has an active tag, returns `nil`." ([& dxs] (let [m (into {} (mapcat (fn [dx] @@ -479,6 +491,20 @@ ;; (defrecord Completed [v->partial] d/IPerturbed + ;; TODO note that this can happen because these can pop out from inside of + ;; ->partial-fn. And that is currently where the tag-rewriting has to occur. + ;; + ;; But that is going to be inefficient for lots of intermediate values... + ;; ideally we could call this AFTER we select out the IDs. That implies that + ;; we want to shove that inside of extract. + ;; + ;; TODO TODO TODO definitely do this, we definitely want that to happen, don't + ;; have those stacked levels, otherwise super inefficient to walk multiple + ;; times. + ;; + ;; TODO AND THEN if that's true then we can delete this implementation, since + ;; we'll already be pulled OUT of the completed map. + ;; NOTE that it's a problem that `replace-tag` is called on [[Completed]] ;; instances now. In a future refactor I want `get` calls out of ;; a [[Completed]] map to occur before tag replacement needs to happen. @@ -524,6 +550,10 @@ - the partial derivative of the output with respect to that value." [root] (let [nodes (topological-sort root) + + ;; TODO this is the spot where we want to wire in many sensitivities. So + ;; how would it work, if we set all of the sensitivities for the outputs + ;; at once? What would the ordering be as we walked backwards? sensitivities {(tape-id root) 1}] (->Completed (reduce process sensitivities nodes)))) @@ -540,6 +570,10 @@ (declare ->partials) +;; TODO fix the docstring, and think of how we can combine this into the +;; narrative of what we find in derivative. Maybe this should be the main +;; version? + (defn- ->partials-fn "Returns a new function that composes a 'tag extraction' step with `f`. The returned fn will @@ -582,12 +616,24 @@ (vector? output) (mapv #(->partials % tag) output) + ;; Here is an example of the subtlety. We MAY want to go one at a + ;; time... or we may want to insert some sensitivity entry into the + ;; entire structure and roll the entire structure back. We don't do that + ;; YET so I bet we can get away with ignoring it for this first PR. But + ;; we are close to needing that. (s/structure? output) (s/mapr #(->partials % tag) output) (f/function? output) (->partials-fn output tag) + (o/operator? output) + (o/->Operator (->partials-fn (o/procedure output) tag) + (o/arity output) + (o/name output) + (o/context output) + (meta output)) + (v/scalar? output) (->Completed {}) @@ -620,10 +666,15 @@ (s/mapr #(extract % id) output) (f/function? output) - ;; TODO this needs to handle perturbation confusion with tag - ;; replacement. Make something similar to extract-tangent-fn. (comp #(extract % id) output) + (o/operator? output) + (o/->Operator (extract (o/procedure output) id) + (o/arity output) + (o/name output) + (o/context output) + (meta output)) + :else 0)) ;; TODO note that [[interpret]] and [[tapify]] both need to become generic on From 3ece359b4806d709aba765cb0ef4d88ae431cff9 Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Mon, 22 Apr 2024 12:20:13 -0600 Subject: [PATCH 3/6] add tests --- CHANGELOG.md | 7 +++ src/emmy/abstract/function.cljc | 48 +++++++++++++-- src/emmy/calculus/derivative.cljc | 20 +++--- src/emmy/differential.cljc | 38 +++++------- src/emmy/tape.cljc | 88 +++++++++------------------ test/emmy/abstract/function_test.cljc | 18 ++++++ 6 files changed, 118 insertions(+), 101 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a7633cb..fc39aeb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,13 @@ ## [unreleased] +- #156: + + - Makes forward- and reverse-mode automatic differentiation compatible with + each other, allowing for proper mixed-mode AD + + - Adds support for derivatives of literal functions in reverse-mode + - #165: - Fixes Alexey's Amazing Bug for our tape implementation diff --git a/src/emmy/abstract/function.cljc b/src/emmy/abstract/function.cljc index 5ddc2a1c..c3e3abe8 100644 --- a/src/emmy/abstract/function.cljc +++ b/src/emmy/abstract/function.cljc @@ -251,7 +251,21 @@ (->Function fexp (f/arity f) (domain-types f) (range-type f)))) -(defn- forward-mode-fold [f primal-s tag] +(defn- forward-mode-fold + "Takes + + - a literal function `f` + - a structure `primal-s` of the primal components of the args to `f` (with + respect to `tag`) + - the `tag` of the innermost active derivative call + + And returns a folding function (designed for use + with [[emmy.structure/fold-chain]]) that + + generates a new [[emmy.differential/Dual]] by applying the chain rule and + summing the partial derivatives for each perturbed argument in the input + structure." + [f primal-s tag] (fn ([] 0) ([tangent] (d/bundle-element (apply f primal-s) tangent tag)) @@ -263,7 +277,18 @@ (g/+ tangent (g/* (literal-apply partial primal-s) dx)))))))) -(defn- reverse-mode-fold [f primal-s tag] +(defn- reverse-mode-fold + "Takes + + - a literal function `f` + - a structure `primal-s` of the primal components of the args to `f` (with + respect to `tag`) + - the `tag` of the innermost active derivative call + + And returns a folding function (designed for use + with [[emmy.structure/fold-chain]]) that assembles all partial derivatives of + `f` into a new [[emmy.tape/TapeCell]]." + [f primal-s tag] (fn ([] []) ([partials] @@ -275,9 +300,22 @@ partials)))) (defn- literal-derivative - "Takes a literal function `f` and a sequence of arguments `xs`, and generates an - expanded `((D f) xs)` by applying the chain rule and summing the partial - derivatives for each perturbed argument in the input structure." + "Takes + + - a literal function `f` + - a structure `s` of arguments + - the `tag` of the innermost active derivative call + - an instance of a perturbation `dx` associated with `tag` + + and generates the proper return value for `((D f) xs)`. + + In forward-mode AD this is a new [[emmy.differential/Dual]] generated by + applying the chain rule and summing the partial derivatives for each perturbed + argument in the input structure. + + In reverse-mode, this is a new [[emmy.tape/TapeCell]] containing a sequence of + pairs of each input paired with the partial derivative of `f` with respect to + that input." [f s tag dx] (let [fold-fn (cond (tape/tape? dx) reverse-mode-fold (d/dual? dx) forward-mode-fold diff --git a/src/emmy/calculus/derivative.cljc b/src/emmy/calculus/derivative.cljc index da5a7584..5f57e9a4 100644 --- a/src/emmy/calculus/derivative.cljc +++ b/src/emmy/calculus/derivative.cljc @@ -459,17 +459,6 @@ (o/make-operator #(g/partial-derivative % []) g/derivative-symbol)) -(def D-rev - "Reverse-mode derivative operator..." - (o/make-operator #(tape/gradient % []) - g/derivative-symbol)) - -(defn partial-rev - "Reverse-mode partial derivative." - [& selectors] - (o/make-operator #(tape/gradient % selectors) - `(~'partial ~@selectors))) - (defn D-as-matrix [F] (fn [s] (matrix/s->m @@ -550,7 +539,14 @@ (d/tag x)) (tape/tape? x) - (u/illegal "TODO implement this using fmap style.") + (tape/->TapeCell + (tape/tape-tag x) + (tape/tape-id x) + (rec (tape/tape-primal x)) + (mapv (fn [[node partial]] + [(rec node) + (rec partial)]) + (tape/tape-partials x))) :else (-> (g/simplify x) (x/substitute replace-m)))) diff --git a/src/emmy/differential.cljc b/src/emmy/differential.cljc index e16542a7..9ff101ea 100644 --- a/src/emmy/differential.cljc +++ b/src/emmy/differential.cljc @@ -102,7 +102,7 @@ ;; ;; $$f(a+b\varepsilon) = f(a)+ (Df(a)b)\varepsilon$$ ;; -;; > NOTE: See [[lift-1]] for an implementation of this idea. +;; > NOTE: See [[emmy.tape/lift-1]] for an implementation of this idea. ;; ;; This justifies our claim above: applying a function to some dual number ;; $a+\varepsilon$ returns a new dual number, where @@ -257,9 +257,9 @@ ;; ### What Return Values are Allowed? ;; ;; Before we discuss the implementation of dual -;; numbers (called [[Differential]]), [[lift-1]], [[lift-2]] and the rest of the -;; machinery that makes this all possible; what sorts of objects is `f` allowed -;; to return? +;; numbers (called [[Differential]]), [[emmy.tape/lift-1]], [[emmy.tape/lift-2]] +;; and the rest of the machinery that makes this all possible; what sorts of +;; objects is `f` allowed to return? ;; ;; The dual number approach is beautiful because we can bring to bear all sorts ;; of operations in Clojure that never even _see_ dual numbers. For example, @@ -668,29 +668,19 @@ ;; ## Chain Rule and Lifted Functions ;; -;; Finally, we come to the heart of it! [[lift-1]] and [[lift-2]] "lift", or -;; augment, unary or binary functions with the ability to -;; handle [[Dual]] instances in addition to whatever other types they -;; previously supported. +;; For the rest of the story, please see the implementations +;; of [[emmy.tape/lift-1]] and [[emmy.tape/lift-2]]. These functions "lift", or +;; augment, unary or binary functions with the ability to handle [[Dual]] +;; instances in addition to whatever other types they previously supported. ;; -;; These functions are implementations of the single and multivariable Taylor -;; series expansion methods discussed at the beginning of the namespace. -;; -;; There is yet another subtlety here, noted in the docstrings below. [[lift-1]] -;; and [[lift-2]] really are able to lift functions like [[clojure.core/+]] that -;; can't accept [[Dual]]s. But the first-order derivatives that you have -;; to supply _do_ have to be able to take [[Dual]] instances. -;; -;; This is because the [[tangent]] of [[Dual]] might still be a [[Dual]], and -;; for `Df` to handle this we need to be able to take the second-order -;; derivative. -;; -;; Magically this will all Just Work if you pass an already-lifted function, or -;; a function built out of already-lifted components, as `df:dx` or `df:dy`. - -;; TODO port docs above... +;; The [[dual?]] branches inside these functions are implementations of the +;; single and multivariable Taylor series expansion methods discussed at the +;; beginning of the namespace. ;; ## Generic Method Installation +;; +;; These generic methods don't need to be lifted, so live here alongside +;; the [[Dual]] type definition. (defmethod g/zero-like [::dual] [_] 0) (defmethod g/one-like [::dual] [_] 1) diff --git a/src/emmy/tape.cljc b/src/emmy/tape.cljc index 10f42c1b..afdf0c0f 100644 --- a/src/emmy/tape.cljc +++ b/src/emmy/tape.cljc @@ -309,16 +309,12 @@ (defn tag-of "More permissive version of [[tape-tag]] that returns `nil` when passed a - non-[[TapeCell]] instance. - - TODO note what we handle now." + non-perturbation." [x] (cond (tape? x) (tape-tag x) (d/dual? x) (d/tag x) :else nil)) -;; TODO move tag stuff here? - (defn inner-tag "Given any number of `tags`, returns the tag most recently bound via [[with-active-tag]] (i.e., the tag connected with the _innermost_ call @@ -330,11 +326,6 @@ d/*active-tags*) (apply max tags))) -;; TODO we could change `perturbed?` into something like -;; `possible-perturbations`, to get collection types to return sequence of -;; inputs for this. Then we could handle map-shaped inputs etc into literal -;; functions, if we had the proper descriptor language for it. - (defn tag+perturbation "Given any number of `dxs`, returns a pair of the form @@ -357,9 +348,7 @@ (defn primal-of "More permissive version of [[tape-primal]] that returns `v` when passed a - non-[[TapeCell]]-or-[[emmy.differential/Dual]] instance. - - TODO fix docstring" + non-perturbation." ([v] (primal-of v (tag-of v))) ([v tag] @@ -368,12 +357,11 @@ :else v))) (defn deep-primal - "Version of [[tape-primal]] that will descend recursively into any [[TapeCell]] - instance returned by [[tape-primal]] until encountering a non-[[TapeCell]]. + "Version of [[tape-primal]] that will descend recursively into any perturbation + instance returned by [[tape-primal]] or [[emmy.differential/primal]] until + encountering a non-perturbation. - Given a non-[[TapeCell]], acts as identity. - - TODO say what we really do now" + Given a non-perturbation, acts as identity." ([v] (cond (tape? v) (recur (tape-primal v)) (d/dual? v) (recur (d/primal v)) @@ -491,20 +479,6 @@ ;; (defrecord Completed [v->partial] d/IPerturbed - ;; TODO note that this can happen because these can pop out from inside of - ;; ->partial-fn. And that is currently where the tag-rewriting has to occur. - ;; - ;; But that is going to be inefficient for lots of intermediate values... - ;; ideally we could call this AFTER we select out the IDs. That implies that - ;; we want to shove that inside of extract. - ;; - ;; TODO TODO TODO definitely do this, we definitely want that to happen, don't - ;; have those stacked levels, otherwise super inefficient to walk multiple - ;; times. - ;; - ;; TODO AND THEN if that's true then we can delete this implementation, since - ;; we'll already be pulled OUT of the completed map. - ;; NOTE that it's a problem that `replace-tag` is called on [[Completed]] ;; instances now. In a future refactor I want `get` calls out of ;; a [[Completed]] map to occur before tag replacement needs to happen. @@ -550,10 +524,6 @@ - the partial derivative of the output with respect to that value." [root] (let [nodes (topological-sort root) - - ;; TODO this is the spot where we want to wire in many sensitivities. So - ;; how would it work, if we set all of the sensitivities for the outputs - ;; at once? What would the ordering be as we walked backwards? sensitivities {(tape-id root) 1}] (->Completed (reduce process sensitivities nodes)))) @@ -570,10 +540,6 @@ (declare ->partials) -;; TODO fix the docstring, and think of how we can combine this into the -;; narrative of what we find in derivative. Maybe this should be the main -;; version? - (defn- ->partials-fn "Returns a new function that composes a 'tag extraction' step with `f`. The returned fn will @@ -616,11 +582,6 @@ (vector? output) (mapv #(->partials % tag) output) - ;; Here is an example of the subtlety. We MAY want to go one at a - ;; time... or we may want to insert some sensitivity entry into the - ;; entire structure and roll the entire structure back. We don't do that - ;; YET so I bet we can get away with ignoring it for this first PR. But - ;; we are close to needing that. (s/structure? output) (s/mapr #(->partials % tag) output) @@ -755,11 +716,14 @@ (matrix/seq-> (cons x more))))))) ;; ## Lifted Functions + +;; [[lift-1]] and [[lift-2]] "lift", or augment, unary or binary functions with +;; the ability to handle [[emmy.differential/Dual]] and [[TapeCell]] instances +;; in addition to whatever other types they previously supported. ;; -;; NOTE these next two functions are similar to the functions -;; in [[emmy.differential]]; both of these should be merged and install methods -;; that can handle the interaction between [[TapeCell]] -;; and [[emmy.differential/Differential]] instances. +;; Forward-mode support for [[emmy.differential/Dual]] is an implementation of +;; the single and multivariable Taylor series expansion methods discussed at the +;; beginning of [[emmy.differential]]. ;; ;; To support reverse-mode automatic differentiation, When a unary or binary ;; function `f` encounters a [[TapeCell]] `x` (and `y` in the binary case) it @@ -778,10 +742,19 @@ ;; ```` ;; ;; in the binary case. + +;; There is a subtlety here, noted in the docstrings below. [[lift-1]] +;; and [[lift-2]] really are able to lift functions like [[clojure.core/+]] that +;; can't accept [[emmy.differential/Dual]] and [[TapeCell]]s. But the +;; first-order derivatives that you have to supply _do_ have to be able to take +;; instances of these types. ;; -;; The partial derivative implementations are passed in directly or retrieved -;; from the generic implementation using the same method as in -;; the [[emmy.differential]] versions, hinting again that we should unify these. +;; This is because, for example, the [[emmy.differential/tangent]] of [[Dual]] +;; might still be a [[Dual]], and will hit the first-order derivative via the +;; chain rule. +;; +;; Magically this will all Just Work if you pass an already-lifted function, or +;; a function built out of already-lifted components, as `df:dx` or `df:dy`. (defn lift-1 "Given: @@ -875,7 +848,7 @@ (cond (tape? dx) (operate-reverse tag) (d/dual? dx) (operate-forward tag) :else - (u/illegal "Non-tape or differential perturbation!")) + (u/illegal "Non-tape or dual perturbation!")) (f x y)))))) (defn lift-n @@ -948,13 +921,8 @@ (defn ^:no-doc by-primal "Given some unary or binary function `f`, returns an augmented `f` that acts on - the primal entries of any [[TapeCell]] arguments encountered, irrespective of - tag. - - Given a [[TapeCell]] with a [[TapeCell]] in its [[primal-part]], the returned - `f` will recursively descend until it hits a non-[[TapeCell]]. - - TODO fix docs" + the primal entries of any perturbed arguments encountered, irrespective of + tag." [f] (fn ([x] (f (deep-primal x))) diff --git a/test/emmy/abstract/function_test.cljc b/test/emmy/abstract/function_test.cljc index d46e31f7..f82631a5 100644 --- a/test/emmy/abstract/function_test.cljc +++ b/test/emmy/abstract/function_test.cljc @@ -17,6 +17,7 @@ [emmy.structure :as s :refer [literal-up literal-down up down]] + [emmy.tape :as t] [emmy.value :as v :refer [=]])) (use-fixtures :each hermetic-simplify-fixture) @@ -331,3 +332,20 @@ (is (= '((((partial 1) F) x y) (((partial 1) G) x y) 0 0) (simp4 (((partial 1) T) 'x 'y)))) (is (= '((((partial 0) W) (up r θ)) (((partial 0) Z) (up r θ)) 0 0) (simp4 (((partial 0) U) (up 'r 'θ))))) (is (= '((((partial 1) W) (up r θ)) (((partial 1) Z) (up r θ)) 0 0) (simp4 (((partial 1) U) (up 'r 'θ)))))))) + +(deftest literal-derivative-tests + (testing "reverse-mode matches forward-mode" + (let [f (af/literal-function 'f)] + (is (= ((emmy.tape/gradient f) 'x) + ((D f) 'x)) + "gradient matches D with internal partials")) + + (let [f (af/literal-function 'f '(-> (UP Real (UP Real Real) (UP Real Real)) Real)) + s (up 't (up 'x 'y) (up 'px 'py))] + (is (= ((t/gradient (t/gradient f)) s) + ((D (D f)) s)) + "gradient matches D with internal partials") + + (is (= ((D (t/gradient f)) s) + ((t/gradient (D f)) s)) + "mixed-mode")))) From deedf550ad472dde29ff021b7b07d7af880e0839 Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Mon, 22 Apr 2024 12:49:06 -0600 Subject: [PATCH 4/6] remove deep-primal --- src/emmy/differential.cljc | 10 ---------- test/emmy/differential_test.cljc | 12 +----------- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/src/emmy/differential.cljc b/src/emmy/differential.cljc index 9ff101ea..b35a3842 100644 --- a/src/emmy/differential.cljc +++ b/src/emmy/differential.cljc @@ -482,16 +482,6 @@ (-> (primal-tangent-pair dx tag) (nth 0)))) -(defn deep-primal - "Version of [[primal]] that will descend recursively into any [[Dual]] instance - returned by [[primal]] until encountering a non-[[Dual]]. - - Given a non-[[Dual]], acts as identity." - [dx] - (if (dual? dx) - (recur (.-primal ^Dual dx)) - dx)) - (defn tangent "If `dx` is an instance of [[Dual]] returns the `tangent` component. Else, returns 0. diff --git a/test/emmy/differential_test.cljc b/test/emmy/differential_test.cljc index 2373b88f..7cc48823 100644 --- a/test/emmy/differential_test.cljc +++ b/test/emmy/differential_test.cljc @@ -329,17 +329,7 @@ (let [tag (d/tag diff) [primal tangent] (d/primal-tangent-pair diff tag)] (is (d/eq primal (d/primal diff tag))) - (is (d/eq tangent (d/tangent diff tag))))) - - (checking "deep-primal fetches bottom primal" 100 - [diff real-dual-gen] - (loop [diff diff] - (let [primal (d/primal diff)] - (if (d/dual? primal) - (recur primal) - (is (= primal (d/primal diff)) - "recursing to the bottom with primal gives the same result - as jumping straight there with deep-primal")))))) + (is (d/eq tangent (d/tangent diff tag)))))) (deftest lifted-fn-tests (letfn [(breaks? [f x] From 2617c36e38780ec8e87d1e8cc8ebefe5b941ce1d Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Mon, 22 Apr 2024 14:44:36 -0600 Subject: [PATCH 5/6] new test --- test/emmy/calculus/derivative_test.cljc | 27 ++++++++++++++------- test/emmy/tape_test.cljc | 32 +++++++++++++++++++++---- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/test/emmy/calculus/derivative_test.cljc b/test/emmy/calculus/derivative_test.cljc index 4921f9e9..35759dae 100644 --- a/test/emmy/calculus/derivative_test.cljc +++ b/test/emmy/calculus/derivative_test.cljc @@ -19,6 +19,7 @@ [emmy.series :as series] [emmy.simplify :refer [hermetic-simplify-fixture]] [emmy.structure :as s] + [emmy.tape :as tape] [emmy.util :as u] [emmy.value :as v] [same.core :refer [ish? with-comparator]])) @@ -243,10 +244,10 @@ (* (η t) ((D g) (q t)))) (simplify (((δη (+ F G)) q) 't))))) -(testing "scalar product rule for variation: δ(cF) = cδF" + (testing "scalar product rule for variation: δ(cF) = cδF" (is (= '(* c (η t) ((D f) (q t))) (simplify (((δη (* 'c F)) q) 't))))) -(testing "product rule for variation: δ(FG) = δF G + F δG" + (testing "product rule for variation: δ(FG) = δF G + F δG" (is (= (simplify (+ (* (((δη F) q) 't) ((G q) 't)) (* ((F q) 't) (((δη G) q) 't)))) (simplify (((δη (* F G)) q) 't))))) @@ -695,8 +696,8 @@ (testing "f -> Series" (let [F (fn [k] (series/series - (fn [t] (g/* k t)) - (fn [t] (g/* k k t))))] + (fn [t] (g/* k t)) + (fn [t] (g/* k k t))))] (is (= '((* q z) (* (expt q 2) z) 0 0) (simp4 ((F 'q) 'z)))) (is (= '(z (* 2 q z) 0 0) (simp4 (((D F) 'q) 'z))))))) @@ -1272,7 +1273,7 @@ ;; This means that the tangents of the `x` instances captured by `f1` and ;; `f2` can no longer interact. There is no context waiting to bind them ;; together! -))) + ))) (deftest dvl-bug-examples ;; These tests and comments all come from Alexey Radul's @@ -1342,7 +1343,7 @@ ;; The "linear" comment matters because if you only combine the dropped-down ;; pieces linearly, then their tangents wouldn't have interacted anyway, so ;; you can't tell that there are different cases here. -) + ) (testing "amazing bug 4" ;; The same as amazing-bug-3.dvl, but supplies the arguments to f in the @@ -1576,10 +1577,20 @@ `D`; this shows that it can do proper symbolic replacement inside of differential instances.") + (is (v/= [0 1 0 0] + ((tape/gradient + (fn [y] + (into [] (take 4 (d/symbolic-taylor-series + (fn [x] (g/* x y)) + 0))))) + 'a)) + "works with gradient too! TODO once gradients support series outputs, + massage this into better shape...") + (testing "compare, one stays symbolic:" (letfn [(f [[a b]] - (* (sin (* 3 a)) - (cos (* 4 b))))] + (* (sin (* 3 a)) + (cos (* 4 b))))] (is (ish? [-0.020532965943782493 (s/down 0.4321318251769156 -0.558472974950351)] diff --git a/test/emmy/tape_test.cljc b/test/emmy/tape_test.cljc index 05070c83..8032dfd2 100644 --- a/test/emmy/tape_test.cljc +++ b/test/emmy/tape_test.cljc @@ -11,6 +11,7 @@ [emmy.generators :as sg] [emmy.generic :as g] [emmy.numerical.derivative :refer [D-numeric]] + [emmy.operator :as o] [emmy.simplify :refer [hermetic-simplify-fixture]] [emmy.structure :as s] [emmy.tape :as t] @@ -262,11 +263,24 @@ (let [cell (t/make tag 1)] (is (= (t/tag-of cell) (t/tape-tag cell)) - "for tape cells, these should match"))) + "for tape cells, these should match")))) - (checking "for any other type tag == nil" 100 [x gen/any] - (is (nil? (t/tag-of x)) - "for tape cells, these should match"))) + (testing "primal-of" + (checking "for any other type primal-of == identity" 100 [x gen/any-equatable] + (is (= x (t/primal-of x)))) + + (checking "vs tape-primal" 100 [tape (sg/tapecell gen/symbol)] + (is (= (t/primal-of tape) + (t/tape-primal tape)) + "primal-of eq with and without tag") + + (is (= (t/primal-of tape) + (t/primal-of tape (t/tape-tag tape))) + "primal-of eq with and without tag") + + (is (= (t/tape-primal tape) + (t/tape-primal tape (t/tape-tag tape))) + "tape-primal eq with and without tag"))) (checking "deep-primal returns nested primal" 100 [p gen/any-equatable] (let [cell (t/make 0 (t/make 1 p))] @@ -547,6 +561,16 @@ "partial selector provided to a fn of a non-structural argument throws on fn application") + (testing "Operator" + (letfn [(f [x] + (o/make-operator (g/* x g/sin) 'D-op))] + (is (o/operator? ((t/gradient f) 'x)) + "if f returns an operator, (gradient f) does too.") + (is (= '(sin y) + (g/freeze + (((t/gradient f) 'x) 'y))) + "gradient pushes into the operator's fn."))) + (testing "partial derivative" (let [f (fn [[x y z]] (g/+ (g/expt x 4) (g/* x y z (g/cos x))))] From 1e86dc527800117840fae56538864b283507cdb1 Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Tue, 23 Apr 2024 16:14:15 -0600 Subject: [PATCH 6/6] add another test --- src/emmy/tape.cljc | 10 +++++----- test/emmy/calculus/derivative_test.cljc | 11 +++++++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/emmy/tape.cljc b/src/emmy/tape.cljc index afdf0c0f..ac3a7cf1 100644 --- a/src/emmy/tape.cljc +++ b/src/emmy/tape.cljc @@ -337,11 +337,11 @@ If none of `dxs` has an active tag, returns `nil`." ([& dxs] - (let [m (into {} (mapcat - (fn [dx] - (when-let [t (tag-of dx)] - {t dx}))) - dxs)] + (let [xform (map + (fn [dx] + (when-let [t (tag-of dx)] + [t dx]))) + m (into {} xform dxs)] (when (seq m) (let [tag (apply inner-tag (keys m))] [tag (m tag)]))))) diff --git a/test/emmy/calculus/derivative_test.cljc b/test/emmy/calculus/derivative_test.cljc index 35759dae..8bf2b0a9 100644 --- a/test/emmy/calculus/derivative_test.cljc +++ b/test/emmy/calculus/derivative_test.cljc @@ -104,7 +104,14 @@ (/ -2 (expt x 2)) (/ -3 (expt x 2))) (simplify - ((D f) 'x))))))) + ((D f) 'x)))))) + + (let [f (fn [a b c d e] [d e c b a]) + M ((D f) 'a 'b 'c 'd 'e)] + (is (= (s/up 4 5 3 2 1) + (g/* M [1 2 3 4 5])) + "the Jacobian of a permutation is the permutation matrix of that + permutation"))) (deftest derivative-return-tests (testing "Series, PowerSeries" @@ -1313,7 +1320,7 @@ (fn [g] (fn [z] (g (+ x z)))))))] (is (ish? ((fn [y] (+ (* 3 (cos (* 3 y))) - (* y (cos (* 3 y))))) + (* y (cos (* 3 y))))) (+ 3 Math/PI)) (((D f) 3) (fn [g-hat f-hat]