Support nested forward/reverse mode #156

Merged
merged 6 commits into from
Apr 23, 2024
Changes from 5 commits
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,13 @@

 ## [unreleased]

+- #156:
+
+- Makes forward- and reverse-mode automatic differentiation compatible with
+each other, allowing for proper mixed-mode AD
+
+- Adds support for derivatives of literal functions in reverse-mode
+
 - #165:

 - Fixes Alexey's Amazing Bug for our tape implementation
68 changes: 59 additions & 9 deletions src/emmy/abstract/function.cljc
@@ -19,6 +19,7 @@
[emmy.numsymb :as sym]
[emmy.polynomial]
[emmy.structure :as s]
+[emmy.tape :as tape]
[emmy.util :as u]
[emmy.value :as v])
#?(:clj
@@ -250,7 +251,21 @@
(->Function
fexp (f/arity f) (domain-types f) (range-type f))))

-(defn- forward-mode-fold [f primal-s tag]
+(defn- forward-mode-fold
+"Takes

+- a literal function `f`
+- a structure `primal-s` of the primal components of the args to `f` (with
+respect to `tag`)
+- the `tag` of the innermost active derivative call
+
+And returns a folding function (designed for use
+with [[emmy.structure/fold-chain]]) that
+
+generates a new [[emmy.differential/Dual]] by applying the chain rule and
+summing the partial derivatives for each perturbed argument in the input
+structure."
+[f primal-s tag]
(fn
([] 0)
([tangent] (d/bundle-element (apply f primal-s) tangent tag))
@@ -262,15 +277,50 @@
(g/+ tangent (g/* (literal-apply partial primal-s)
dx))))))))

+(defn- reverse-mode-fold
Member Author

Building on the cryptic `fold-chain` from a previous PR, we can now drop this in to support reverse mode.

+"Takes
+
+- a literal function `f`
+- a structure `primal-s` of the primal components of the args to `f` (with
+respect to `tag`)
+- the `tag` of the innermost active derivative call
+
+And returns a folding function (designed for use
+with [[emmy.structure/fold-chain]]) that assembles all partial derivatives of
+`f` into a new [[emmy.tape/TapeCell]]."
+[f primal-s tag]
+(fn
+([] [])
+([partials]
+(tape/make tag (apply f primal-s) partials))
+([partials [entry path _]]
+(if (and (tape/tape? entry) (= tag (tape/tape-tag entry)))
+(let [partial (literal-partial f path)]
+(conj partials [entry (literal-apply partial primal-s)]))
+partials))))

Codecov / codecov/patch: added line #L300 of src/emmy/abstract/function.cljc was not covered by tests.
Contributor

It is odd that nothing triggers L300, which means `(and (tape/tape? entry) (= tag (tape/tape-tag entry)))` holds in every tested case.

Member Author

This would happen with nested gradient calls to a literal function; I just don't have those in the tests yet. In that case the innermost tag wins, and any tape with a different tag is treated as a scalar (i.e. its partial derivative is 0, so we never add it to the partials).

Member Author

I want to share tests between forward mode and reverse mode... when I do that this should get hit.
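
To make the skipped branch concrete, here is a minimal sketch of the filtering step, assuming plain maps as stand-ins for tape cells. `toy-tape`, `toy-tape?`, and `toy-fold` are hypothetical names for illustration only and are not part of Emmy; the real fold is driven by `emmy.structure/fold-chain` rather than `reduce`.

```clojure
;; Toy stand-in for a tape cell: a plain map, NOT emmy.tape/TapeCell.
(defn toy-tape [tag primal]
  {:tape? true :tag tag :primal primal})

(defn toy-tape? [x]
  (and (map? x) (true? (:tape? x))))

(defn toy-fold
  "Hypothetical analogue of the accumulating step of `reverse-mode-fold`:
  keeps an `[entry partial]` pair only when `entry` is a toy tape cell
  carrying the active `tag`."
  [tag partial-of]
  (fn [partials entry]
    (if (and (toy-tape? entry) (= tag (:tag entry)))
      (conj partials [entry (partial-of entry)])
      ;; Different tag, or not a tape at all: treated as a constant.
      partials)))

(def inner (toy-tape 1 3.0))   ; carries the active (innermost) tag
(def outer (toy-tape 0 5.0))   ; carries an outer tag

;; `(constantly :df)` is a placeholder for "the partial derivative applied
;; to the primals"; it is an assumption made to keep the sketch small.
(def result
  (reduce (toy-fold 1 (constantly :df)) [] [inner outer 7]))
```

Only `inner` carries the active tag here, so `outer` and the plain number both fall through the else branch, which is the case the coverage warning refers to.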


(defn- literal-derivative
-"Takes a literal function `f` and a sequence of arguments `xs`, and generates
-an expanded `((D f) xs)` by applying the chain rule and summing the partial
-derivatives for each [[emmy.differential/Dual]] argument in the input
-structure."
+"Takes
+
+- a literal function `f`
+- a structure `s` of arguments
+- the `tag` of the innermost active derivative call
+- an instance of a perturbation `dx` associated with `tag`
+
+and generates the proper return value for `((D f) xs)`.
+
+In forward-mode AD this is a new [[emmy.differential/Dual]] generated by
+applying the chain rule and summing the partial derivatives for each perturbed
+argument in the input structure.
+
+In reverse-mode, this is a new [[emmy.tape/TapeCell]] containing a sequence of
+pairs of each input paired with the partial derivative of `f` with respect to
+that input."
[f s tag dx]
-(let [fold-fn (cond (d/dual? dx) forward-mode-fold
-:else (u/illegal "No tape or differential inputs."))
-primal-s (s/mapr (fn [x] (d/primal x tag)) s)]
+(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
+(d/dual? dx) forward-mode-fold
+:else (u/illegal "No tape or differential inputs."))

Codecov / codecov/patch: added line #L322 of src/emmy/abstract/function.cljc was not covered by tests.
+primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)]
(s/fold-chain (fold-fn f primal-s tag) s)))

(defn- check-argument-type
@@ -305,7 +355,7 @@
(if-let [[tag dx] (s/fold-chain
(fn
([] [])
-([acc] (apply d/tag+perturbation acc))
+([acc] (apply tape/tag+perturbation acc))
([acc [d]] (conj acc d)))
s)]
(literal-derivative f s tag dx)
26 changes: 19 additions & 7 deletions src/emmy/calculus/derivative.cljc
@@ -15,6 +15,7 @@
[emmy.operator :as o]
[emmy.series :as series]
[emmy.structure :as s]
+[emmy.tape :as tape]
[emmy.util :as u]
[emmy.value :as v])
#?(:clj
@@ -531,12 +532,23 @@
(letfn [(process-term [term]
(g/simplify
(s/mapr (fn rec [x]
-(if (d/dual? x)
-(d/bundle-element
-(rec (d/primal x))
-(rec (d/tangent x))
-(d/tag x))
-(-> (g/simplify x)
-(x/substitute replace-m))))
+(cond (d/dual? x)
+(d/bundle-element
+(rec (d/primal x))
+(rec (d/tangent x))
+(d/tag x))
+
+(tape/tape? x)
Member Author

this feels janky, and is the only place where we deliberately introspect these types.

+(tape/->TapeCell
+(tape/tape-tag x)
+(tape/tape-id x)
+(rec (tape/tape-primal x))
+(mapv (fn [[node partial]]
Contributor

It looks like the inner thing is `mapv` too. Does writing it that way feel less janky? `[node partial]` is almost an instance of `Dual`, right, except that the tag belongs to the containing tape. Armed with that, you could unify this part with the part above, but (I say sincerely) that would be too clever. So I don't think this is janky: it looks like about what we want to see here.

Member Author

this whole function is too clever, and my cleverness bit me here: #168

Member Author

cleverness in trying to make this compatible with D, that is...
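
For readers following the thread, here is a minimal sketch of the recursive walk being discussed, assuming tagged maps as stand-ins for `Dual` and `TapeCell`. The function `rewrite` and the `:type` key are hypothetical and not Emmy's API; the sketch only shows that each `[node partial]` pair is rewritten the same way as a dual's primal and tangent.

```clojure
(defn rewrite
  "Applies `f` to every leaf of `x`, recursing through toy dual and tape maps."
  [f x]
  (cond
    ;; Toy dual: rewrite primal and tangent, keep everything else.
    (= :dual (:type x))
    (-> x
        (update :primal #(rewrite f %))
        (update :tangent #(rewrite f %)))

    ;; Toy tape: rewrite the primal and both halves of each pair.
    (= :tape (:type x))
    (-> x
        (update :primal #(rewrite f %))
        (update :partials
                (fn [pairs]
                  (mapv (fn [[node partial]]
                          [(rewrite f node) (rewrite f partial)])
                        pairs))))

    ;; Anything else is a leaf.
    :else (f x)))

(def leaf-tape {:type :tape :tag 0 :primal 2 :partials []})
(def cell {:type :tape :tag 0 :primal 1 :partials [[leaf-tape 10]]})

;; `inc` stands in for the simplify-and-substitute step applied to leaves.
(def rewritten (rewrite inc cell))
```

The tag is left untouched in both branches, which mirrors the point that the tag belongs to the containing cell rather than to the individual pairs.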

+[(rec node)
+(rec partial)])
+(tape/tape-partials x)))
+
+:else (-> (g/simplify x)
+(x/substitute replace-m))))
term)))]
(series/fmap process-term series)))))