Commit

add tests
sritchie committed Apr 22, 2024
1 parent 69ee50a commit 227276b
Showing 5 changed files with 111 additions and 90 deletions.
48 changes: 43 additions & 5 deletions src/emmy/abstract/function.cljc
@@ -251,7 +251,21 @@
(->Function
fexp (f/arity f) (domain-types f) (range-type f))))

(defn- forward-mode-fold [f primal-s tag]
(defn- forward-mode-fold
"Takes
- a literal function `f`
- a structure `primal-s` of the primal components of the args to `f` (with
respect to `tag`)
- the `tag` of the innermost active derivative call
And returns a folding function (designed for use
with [[emmy.structure/fold-chain]]) that
generates a new [[emmy.differential/Dual]] by applying the chain rule and
summing the partial derivatives for each perturbed argument in the input
structure."
[f primal-s tag]
(fn
([] 0)
([tangent] (d/bundle-element (apply f primal-s) tangent tag))
@@ -263,7 +277,18 @@
(g/+ tangent (g/* (literal-apply partial primal-s)
dx))))))))

(defn- reverse-mode-fold [f primal-s tag]
(defn- reverse-mode-fold
"Takes
- a literal function `f`
- a structure `primal-s` of the primal components of the args to `f` (with
respect to `tag`)
- the `tag` of the innermost active derivative call
And returns a folding function (designed for use
with [[emmy.structure/fold-chain]]) that assembles all partial derivatives of
`f` into a new [[emmy.tape/TapeCell]]."
[f primal-s tag]
(fn
([] [])
([partials]
@@ -275,9 +300,22 @@
partials))))

Check warning on line 300 in src/emmy/abstract/function.cljc (Codecov / codecov/patch): added line #L300 was not covered by tests.

(defn- literal-derivative
"Takes a literal function `f` and a sequence of arguments `xs`, and generates an
expanded `((D f) xs)` by applying the chain rule and summing the partial
derivatives for each perturbed argument in the input structure."
"Takes
- a literal function `f`
- a structure `s` of arguments
- the `tag` of the innermost active derivative call
- an instance of a perturbation `dx` associated with `tag`
and generates the proper return value for `((D f) s)`.
In forward-mode AD this is a new [[emmy.differential/Dual]] generated by
applying the chain rule and summing the partial derivatives for each perturbed
argument in the input structure.
In reverse-mode, this is a new [[emmy.tape/TapeCell]] containing a sequence of
pairs of each input paired with the partial derivative of `f` with respect to
that input."
[f s tag dx]
(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
(d/dual? dx) forward-mode-fold
9 changes: 8 additions & 1 deletion src/emmy/calculus/derivative.cljc
@@ -550,7 +550,14 @@
(d/tag x))

(tape/tape? x)
(u/illegal "TODO implement this using fmap style.")
(tape/->TapeCell
(tape/tape-tag x)
(tape/tape-id x)
(rec (tape/tape-primal x))
(mapv (fn [[node partial]]
[(rec node)
(rec partial)])
(tape/tape-partials x)))

Check warning on line 560 in src/emmy/calculus/derivative.cljc (Codecov / codecov/patch): added lines #L553 - L560 were not covered by tests.

:else (-> (g/simplify x)
(x/substitute replace-m))))
38 changes: 14 additions & 24 deletions src/emmy/differential.cljc
@@ -102,7 +102,7 @@
;;
;; $$f(a+b\varepsilon) = f(a)+ (Df(a)b)\varepsilon$$
;;
;; > NOTE: See [[lift-1]] for an implementation of this idea.
;; > NOTE: See [[emmy.tape/lift-1]] for an implementation of this idea.
;;
;; This justifies our claim above: applying a function to some dual number
;; $a+\varepsilon$ returns a new dual number, where
@@ -257,9 +257,9 @@
;; ### What Return Values are Allowed?
;;
;; Before we discuss the implementation of dual
;; numbers (called [[Differential]]), [[lift-1]], [[lift-2]] and the rest of the
;; machinery that makes this all possible; what sorts of objects is `f` allowed
;; to return?
;; numbers (called [[Differential]]), [[emmy.tape/lift-1]], [[emmy.tape/lift-2]]
;; and the rest of the machinery that makes this all possible; what sorts of
;; objects is `f` allowed to return?
;;
;; The dual number approach is beautiful because we can bring to bear all sorts
;; of operations in Clojure that never even _see_ dual numbers. For example,
@@ -668,29 +668,19 @@

;; ## Chain Rule and Lifted Functions
;;
;; Finally, we come to the heart of it! [[lift-1]] and [[lift-2]] "lift", or
;; augment, unary or binary functions with the ability to
;; handle [[Dual]] instances in addition to whatever other types they
;; previously supported.
;; For the rest of the story, please see the implementations
;; of [[emmy.tape/lift-1]] and [[emmy.tape/lift-2]]. These functions "lift", or
;; augment, unary or binary functions with the ability to handle [[Dual]]
;; instances in addition to whatever other types they previously supported.
;;
;; These functions are implementations of the single and multivariable Taylor
;; series expansion methods discussed at the beginning of the namespace.
;;
;; There is yet another subtlety here, noted in the docstrings below. [[lift-1]]
;; and [[lift-2]] really are able to lift functions like [[clojure.core/+]] that
;; can't accept [[Dual]]s. But the first-order derivatives that you have
;; to supply _do_ have to be able to take [[Dual]] instances.
;;
;; This is because the [[tangent]] of [[Dual]] might still be a [[Dual]], and
;; for `Df` to handle this we need to be able to take the second-order
;; derivative.
;;
;; Magically this will all Just Work if you pass an already-lifted function, or
;; a function built out of already-lifted components, as `df:dx` or `df:dy`.

;; TODO port docs above...
;; The [[dual?]] branches inside these functions are implementations of the
;; single and multivariable Taylor series expansion methods discussed at the
;; beginning of the namespace.

;; ## Generic Method Installation
;;
;; These generic methods don't need to be lifted, so live here alongside
;; the [[Dual]] type definition.

(defmethod g/zero-like [::dual] [_] 0)
(defmethod g/one-like [::dual] [_] 1)
88 changes: 28 additions & 60 deletions src/emmy/tape.cljc
@@ -309,16 +309,12 @@

(defn tag-of
"More permissive version of [[tape-tag]] that returns `nil` when passed a
non-[[TapeCell]] instance.
TODO note what we handle now."
non-perturbation."
[x]
(cond (tape? x) (tape-tag x)
(d/dual? x) (d/tag x)
:else nil))

;; TODO move tag stuff here?

(defn inner-tag
"Given any number of `tags`, returns the tag most recently bound
via [[with-active-tag]] (i.e., the tag connected with the _innermost_ call
@@ -330,11 +326,6 @@
d/*active-tags*)
(apply max tags)))

;; TODO we could change `perturbed?` into something like
;; `possible-perturbations`, to get collection types to return sequence of
;; inputs for this. Then we could handle map-shaped inputs etc into literal
;; functions, if we had the proper descriptor language for it.

(defn tag+perturbation
"Given any number of `dxs`, returns a pair of the form
@@ -357,9 +348,7 @@

(defn primal-of
"More permissive version of [[tape-primal]] that returns `v` when passed a
non-[[TapeCell]]-or-[[emmy.differential/Dual]] instance.
TODO fix docstring"
non-perturbation."
([v]
(primal-of v (tag-of v)))

Check warning on line 353 in src/emmy/tape.cljc (Codecov / codecov/patch): added line #L353 was not covered by tests.
([v tag]
@@ -368,12 +357,11 @@
:else v)))

(defn deep-primal
"Version of [[tape-primal]] that will descend recursively into any [[TapeCell]]
instance returned by [[tape-primal]] until encountering a non-[[TapeCell]].
"Version of [[tape-primal]] that will descend recursively into any perturbation
instance returned by [[tape-primal]] or [[emmy.differential/primal]] until
encountering a non-perturbation.
Given a non-[[TapeCell]], acts as identity.
TODO say what we really do now"
Given a non-perturbation, acts as identity."
([v]
(cond (tape? v) (recur (tape-primal v))
(d/dual? v) (recur (d/primal v))
@@ -491,20 +479,6 @@
;;
(defrecord Completed [v->partial]
d/IPerturbed
;; TODO note that this can happen because these can pop out from inside of
;; ->partial-fn. And that is currently where the tag-rewriting has to occur.
;;
;; But that is going to be inefficient for lots of intermediate values...
;; ideally we could call this AFTER we select out the IDs. That implies that
;; we want to shove that inside of extract.
;;
;; TODO TODO TODO definitely do this, we definitely want that to happen, don't
;; have those stacked levels, otherwise super inefficient to walk multiple
;; times.
;;
;; TODO AND THEN if that's true then we can delete this implementation, since
;; we'll already be pulled OUT of the completed map.

;; NOTE that it's a problem that `replace-tag` is called on [[Completed]]
;; instances now. In a future refactor I want `get` calls out of
;; a [[Completed]] map to occur before tag replacement needs to happen.
@@ -550,10 +524,6 @@
- the partial derivative of the output with respect to that value."
[root]
(let [nodes (topological-sort root)

;; TODO this is the spot where we want to wire in many sensitivities. So
;; how would it work, if we set all of the sensitivities for the outputs
;; at once? What would the ordering be as we walked backwards?
sensitivities {(tape-id root) 1}]
(->Completed
(reduce process sensitivities nodes))))
@@ -570,10 +540,6 @@

(declare ->partials)

;; TODO fix the docstring, and think of how we can combine this into the
;; narrative of what we find in derivative. Maybe this should be the main
;; version?

(defn- ->partials-fn
"Returns a new function that composes a 'tag extraction' step with `f`. The
returned fn will
@@ -616,11 +582,6 @@
(vector? output)
(mapv #(->partials % tag) output)

;; Here is an example of the subtlety. We MAY want to go one at a
;; time... or we may want to insert some sensitivity entry into the
;; entire structure and roll the entire structure back. We don't do that
;; YET so I bet we can get away with ignoring it for this first PR. But
;; we are close to needing that.
(s/structure? output)
(s/mapr #(->partials % tag) output)

@@ -755,11 +716,14 @@
(matrix/seq-> (cons x more)))))))

;; ## Lifted Functions

;; [[lift-1]] and [[lift-2]] "lift", or augment, unary or binary functions with
;; the ability to handle [[emmy.differential/Dual]] and [[TapeCell]] instances
;; in addition to whatever other types they previously supported.
;;
;; NOTE these next two functions are similar to the functions
;; in [[emmy.differential]]; both of these should be merged and install methods
;; that can handle the interaction between [[TapeCell]]
;; and [[emmy.differential/Differential]] instances.
;; Forward-mode support for [[emmy.differential/Dual]] is an implementation of
;; the single and multivariable Taylor series expansion methods discussed at the
;; beginning of [[emmy.differential]].
;;
;; To support reverse-mode automatic differentiation, when a unary or binary
;; function `f` encounters a [[TapeCell]] `x` (and `y` in the binary case) it
@@ -778,10 +742,19 @@
;; ````
;;
;; in the binary case.

;; There is a subtlety here, noted in the docstrings below. [[lift-1]]
;; and [[lift-2]] really are able to lift functions like [[clojure.core/+]] that
;; can't accept [[emmy.differential/Dual]] and [[TapeCell]]s. But the
;; first-order derivatives that you have to supply _do_ have to be able to take
;; instances of these types.
;;
;; The partial derivative implementations are passed in directly or retrieved
;; from the generic implementation using the same method as in
;; the [[emmy.differential]] versions, hinting again that we should unify these.
;; This is because, for example, the [[emmy.differential/tangent]] of
;; an [[emmy.differential/Dual]] might itself be an [[emmy.differential/Dual]],
;; and will hit the first-order derivative via the chain rule.
;;
;; Magically this will all Just Work if you pass an already-lifted function, or
;; a function built out of already-lifted components, as `df:dx` or `df:dy`.

(defn lift-1
"Given:
@@ -875,7 +848,7 @@
(cond (tape? dx) (operate-reverse tag)
(d/dual? dx) (operate-forward tag)
:else
(u/illegal "Non-tape or differential perturbation!"))
(u/illegal "Non-tape or dual perturbation!"))

Check warning on line 851 in src/emmy/tape.cljc (Codecov / codecov/patch): added line #L851 was not covered by tests.
(f x y))))))

(defn lift-n
@@ -948,13 +921,8 @@

(defn ^:no-doc by-primal
"Given some unary or binary function `f`, returns an augmented `f` that acts on
the primal entries of any [[TapeCell]] arguments encountered, irrespective of
tag.
Given a [[TapeCell]] with a [[TapeCell]] in its [[primal-part]], the returned
`f` will recursively descend until it hits a non-[[TapeCell]].
TODO fix docs"
the primal entries of any perturbed arguments encountered, irrespective of
tag."
[f]
(fn
([x] (f (deep-primal x)))
18 changes: 18 additions & 0 deletions test/emmy/abstract/function_test.cljc
@@ -17,6 +17,7 @@
[emmy.structure :as s :refer [literal-up
literal-down
up down]]
[emmy.tape :as t]
[emmy.value :as v :refer [=]]))

(use-fixtures :each hermetic-simplify-fixture)
@@ -331,3 +332,20 @@
(is (= '((((partial 1) F) x y) (((partial 1) G) x y) 0 0) (simp4 (((partial 1) T) 'x 'y))))
(is (= '((((partial 0) W) (up r θ)) (((partial 0) Z) (up r θ)) 0 0) (simp4 (((partial 0) U) (up 'r 'θ)))))
(is (= '((((partial 1) W) (up r θ)) (((partial 1) Z) (up r θ)) 0 0) (simp4 (((partial 1) U) (up 'r 'θ))))))))

(deftest literal-derivative-tests
(testing "reverse-mode matches forward-mode"
(let [f (af/literal-function 'f)]
(is (= ((emmy.tape/gradient f) 'x)
((D f) 'x))
"gradient matches D with internal partials"))

(let [f (af/literal-function 'f '(-> (UP Real (UP Real Real) (UP Real Real)) Real))
s (up 't (up 'x 'y) (up 'px 'py))]
(is (= ((t/gradient (t/gradient f)) s)
((D (D f)) s))
"gradient matches D with internal partials")

(is (= ((D (t/gradient f)) s)
((t/gradient (D f)) s))
"mixed-mode"))))
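
The new tests check that reverse-mode (`gradient`) and forward-mode (`D`) derivatives agree. As a rough illustration of why they should, here is a minimal sketch in plain Clojure that does not use Emmy at all. Every name in it (`Dual`, `Tape`, `lift-1`, `forward-derivative`, `reverse-derivative`, `sensitivity`) is a hypothetical stand-in chosen for this example, not Emmy's actual API; it has no tags, handles only unary functions, and supports only first-order derivatives.

```clojure
;; Hypothetical stand-ins, NOT Emmy's types: a dual number carries a primal
;; and a tangent; a tape cell carries a primal and a vector of
;; [input-cell local-partial] pairs.
(defrecord Dual [primal tangent])
(defrecord Tape [primal partials])

(defn lift-1
  "Sketch: lifts unary `f`, with first derivative `df`, so that it accepts
  Dual (forward mode) and Tape (reverse mode) inputs as well as numbers."
  [f df]
  (fn [x]
    (cond
      ;; Forward mode: apply the chain rule immediately.
      (instance? Dual x)
      (->Dual (f (:primal x))
              (* (df (:primal x)) (:tangent x)))

      ;; Reverse mode: record the local partial for a later backward pass.
      (instance? Tape x)
      (->Tape (f (:primal x))
              [[x (df (:primal x))]])

      :else (f x))))

(defn forward-derivative
  "Seeds the input with tangent 1 and reads the tangent of the output."
  [f]
  (fn [x] (:tangent (f (->Dual x 1)))))

(defn sensitivity
  "Walks backward from `node`, multiplying local partials along each path,
  and sums the products of all paths that reach `input`."
  [node input sens]
  (if (identical? node input)
    sens
    (reduce + 0
            (map (fn [[parent partial]]
                   (sensitivity parent input (* sens partial)))
                 (:partials node)))))

(defn reverse-derivative
  "Runs `f` on a fresh Tape input, then accumulates the output's
  sensitivity with respect to that input."
  [f]
  (fn [x]
    (let [input (->Tape x [])]
      (sensitivity (f input) input 1))))

(def square (lift-1 (fn [x] (* x x))   (fn [x] (* 2 x))))
(def cube   (lift-1 (fn [x] (* x x x)) (fn [x] (* 3 x x))))

;; x -> (x^2)^3 = x^6, whose derivative is 6x^5.
(def sixth-power (comp cube square))

((forward-derivative sixth-power) 2) ;; => 192
((reverse-derivative sixth-power) 2) ;; => 192
```

Both modes compute the same product of local partials and differ only in when the multiplication happens: forward mode multiplies as each function is applied, while reverse mode records the partials and multiplies on the way back. The sketch avoids tags entirely, so it cannot express the nested `(D (t/gradient f))` case exercised by the "mixed-mode" test.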
