Skip to content

Commit

Permalink
refactor: Add caching, tests, and cache scalar pdf/prob/condition/con…
Browse files Browse the repository at this point in the history
…strain/mutual-info fns in clj

Also bump up inferenceql.inference version to avoid arrow constructor bug
  • Loading branch information
KingMob committed Mar 19, 2024
1 parent be3b194 commit 7e9d431
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 38 deletions.
5 changes: 3 additions & 2 deletions deps.edn
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
metosin/muuntaja {:mvn/version "0.6.8"}
net.cgrand/macrovich {:mvn/version "0.2.1"}
net.cgrand/xforms {:mvn/version "0.19.2"}
io.github.inferenceql/inferenceql.inference {:git/sha "40e77dedf680b7936ce988b66186a86f5c4db6a5"}
io.github.inferenceql/inferenceql.inference {:git/sha "c3cef474ba964a37fc2e5ff667055f5b77e12c45"}
org.babashka/sci {:mvn/version "0.3.32"}
org.clojure/clojure {:mvn/version "1.11.1"}
org.clojure/clojurescript {:mvn/version "1.11.132"}
Expand All @@ -18,7 +18,8 @@
ring-cors/ring-cors {:mvn/version "0.1.13"}
ring/ring-core {:mvn/version "1.9.5"}
ring/ring-jetty-adapter {:mvn/version "1.9.5"}
tech.tablesaw/tablesaw-core {:mvn/version "0.43.1"}}
tech.tablesaw/tablesaw-core {:mvn/version "0.43.1"}
org.clojure/core.memoize {:mvn/version "1.0.257"}}
:paths ["src" "resources"]
:aliases {:test {:extra-paths ["test"]
:extra-deps {com.gfredericks/test.chuck {:mvn/version "0.2.13"}
Expand Down
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
{
"name": "inferenceql.query",
"version": "1.0.0",
"description": "![tests](https://github.com/OpenIQL/inferenceql.query/workflows/tests/badge.svg) ![linter](https://github.com/OpenIQL/inferenceql.query/workflows/linter/badge.svg)",
Expand All @@ -25,5 +25,6 @@
"shadow-cljs": "^2.27.5"
},
"dependencies": {
"memoizee": "^0.4.15"
}
}
20 changes: 20 additions & 0 deletions src/inferenceql/query/cache.cljc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
(ns inferenceql.query.cache
"For caching expensive results."
(:require #?(:clj [clojure.core.memoize :as memo]
:cljs [memoizee :as memoizee])
[clojure.string :as string]))

(def default-threshold 100)


(defn lru
"Memoizes a fn with a least-recently-used eviction policy.
After the number of cached results exceeds the threshold, the
least-recently-used ones will be evicted."
([f]
(lru f default-threshold))
([f lru-threshold]
#?(:clj (memo/lru f :lru/threshold lru-threshold)
:cljs (memoizee f #js {"max" lru-threshold
"normalizer" js/JSON.stringify}))))
76 changes: 41 additions & 35 deletions src/inferenceql/query/scalar.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
[inferenceql.inference.approximate :as approx]
[inferenceql.inference.gpm :as gpm]
;; [inferenceql.inference.search.crosscat :as crosscat]
[inferenceql.query.cache :as cache]
#?(:clj [inferenceql.query.generative-table :as generative-table])
[inferenceql.query.parser.tree :as tree]
[inferenceql.query.relation :as relation]
Expand Down Expand Up @@ -91,23 +92,26 @@
:else x))
event))

(defn prob
[model event]
(let [event (inference-event event)]
(math/exp (gpm/logprob model event))))

(defn pdf
[model event]
(let [event (update-keys event keyword)]
(math/exp (gpm/logpdf model event {}))))

(defn condition
[model conditions]
(let [conditions (-> (medley/filter-vals some? conditions)
(update-keys keyword))]
(cond-> model
(seq conditions)
(gpm/condition conditions))))
(def prob (cache/lru
(fn prob*
[model event]
(let [event (inference-event event)]
(math/exp (gpm/logprob model event))))))

(def pdf (cache/lru
(fn pdf*
[model event]
(let [event (update-keys event keyword)]
(math/exp (gpm/logpdf model event {}))))))

(def condition (cache/lru
(fn condition*
[model conditions]
(let [conditions (-> (medley/filter-vals some? conditions)
(update-keys keyword))]
(cond-> model
(seq conditions)
(gpm/condition conditions))))))

(defn condition-all
[model bindings]
Expand Down Expand Up @@ -151,24 +155,26 @@
:else (remove nil? form)))))
event))

(defn constrain
[model event]
(let [event (-> event
(strip-nils)
(inference-event))]
(cond-> model
(some? event)
(gpm/constrain event
{:operation? operation?
:operands operands
:operator operator
:variable? variable?}))))

(defn mutual-info
[model event-a event-b]
(let [event-a (inference-event event-a)
event-b (inference-event event-b)]
(gpm/mutual-info model event-a event-b)))
(def constrain (cache/lru
(fn constrain*
[model event]
(let [event (-> event
(strip-nils)
(inference-event))]
(cond-> model
(some? event)
(gpm/constrain event
{:operation? operation?
:operands operands
:operator operator
:variable? variable?}))))))

(def mutual-info (cache/lru
(fn mutual-info*
[model event-a event-b]
(let [event-a (inference-event event-a)
event-b (inference-event event-b)]
(gpm/mutual-info model event-a event-b)))))

(defn approx-mutual-info
[model vars-lhs vars-rhs]
Expand Down
21 changes: 21 additions & 0 deletions test/inferenceql/query/cache_test.cljc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
(ns inferenceql.query.cache-test
(:require [clojure.test :refer [deftest is testing]]
[inferenceql.query.cache :as cache]))

(deftest basic-caching
(let [cache-size 2
a (atom 0)
incrementer (fn [_ignored-but-cached-key]
(swap! a inc))
cached-incrementer (cache/lru incrementer cache-size)]

(is (= 1 (cached-incrementer :foo)))
(is (= 2 (cached-incrementer :bar)))

(is (= 1 (cached-incrementer :foo)))
(is (= 2 (cached-incrementer :bar)))

(is (= 3 (cached-incrementer :moop)))

;; cache cleared for :foo
(is (= 4 (cached-incrementer :foo)))))

0 comments on commit 7e9d431

Please sign in to comment.