Skip to content

Commit

Permalink
get primal back
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Apr 23, 2024
1 parent 08d1063 commit 002d575
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion src/emmy/calculus/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,50 @@

(defn vjp [f v]
(let [g (fn [x]
(if (or (and (v/scalar? x) (v/scalar? v))
(if (or (v/scalar? x)
(v/scalar? v)
(s/compatible-for-contraction? x v))
(g/* x v)
(u/illegal "Incompatible structures!")))]
(gradient (comp g f))))

;; TODO this would be better, if we had a clear way of pulling the primals out
;; from the IPerturbed protocol... that would let `jvp` do its thing as well.

(defn primal-and-derivative [f]
(fn [x]
(let [tag (d/fresh-tag)
lifted (d/bundle-element x 1 tag)
output (d/with-active-tag tag f [lifted])]
[(d/primal output)
(d/extract-tangent output tag)])))

(defn primal-and-jvp [f v]
(fn [x]
(let [g (fn [r]
(f (g/+ x (g/* r v))))]
((primal-and-derivative g) 0))))

(defn vjp* [f]
(fn [x]
(let [tag (d/fresh-tag)
inputs (tape/tapify x tag)
output (d/with-active-tag tag f [inputs])]
[(tape/tape-primal output)
(fn [v]
(let [output (if (or (v/scalar? x)
(v/scalar? v)
(s/compatible-for-contraction? x v))
(g/* output v)
(u/illegal "Incompatible structures!"))
completed (tape/->partials output tag)]
(tape/interpret inputs completed tag)))])))

;; attempting to return the primal too

(defn primal-and-vjp [f]
(multi vjp* f))

(defn hvp [f v]
(jvp (gradient f) v))

Expand Down

0 comments on commit 002d575

Please sign in to comment.