Skip to content

Commit

Permalink
refactor: Add caching, tests, and cache scalar pdf/prob/condition/con…
Browse files Browse the repository at this point in the history
…strain/mutual-info fns in clj

Also bump the inferenceql.inference version to avoid the arrow constructor bug
  • Loading branch information
KingMob committed Mar 22, 2024
1 parent 4805515 commit e3842ca
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 39 deletions.
19 changes: 18 additions & 1 deletion deps-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,19 @@
"git-dir": "https/github.com/inferenceql/inferenceql.gpm.sppl",
"hash": "sha256-4EJqRFT3/95kyriEYMIZxn+TavhMFPE7Yxv6lC1s0lI="
},
{
"lib": "io.github.inferenceql/inferenceql.inference",
"url": "https://github.com/inferenceql/inferenceql.inference.git",
"rev": "c3cef474ba964a37fc2e5ff667055f5b77e12c45",
"git-dir": "https/github.com/inferenceql/inferenceql.inference",
"hash": "sha256-gjVmfbcAEeb0uHJTOUrFasDhZY+c4wI8M5W+itJynz4="
},
{
"lib": "io.github.inferenceql/inferenceql.inference",
"url": "https://github.com/inferenceql/inferenceql.inference.git",
"rev": "40e77dedf680b7936ce988b66186a86f5c4db6a5",
"git-dir": "https/github.com/inferenceql/inferenceql.inference",
"hash": "sha256-GlcGgWVVxyuP1oEGWdV+lBM11vgy1jUJOhkQzCFRTqk="
"hash": "sha256-UtH6AbOhOXzD0hhIYJRrS8k2NQwPj3ZzZp3HNUvevME="
},
{
"lib": "io.github.probcomp/metaprob",
Expand Down Expand Up @@ -1669,6 +1676,16 @@
"mvn-repo": "https://repo.maven.apache.org/maven2/",
"hash": "sha256-hML6t6Mso8HkDEGm7Mm9U26UezBYDne41dwjKjSSXqw="
},
{
"mvn-path": "org/clojure/core.memoize/1.0.257/core.memoize-1.0.257.jar",
"mvn-repo": "https://repo.maven.apache.org/maven2/",
"hash": "sha256-mg6RgW4hp3SY7+3r1HrUvcL7+X+dvEm8nZWU4gEkbpY="
},
{
"mvn-path": "org/clojure/core.memoize/1.0.257/core.memoize-1.0.257.pom",
"mvn-repo": "https://repo.maven.apache.org/maven2/",
"hash": "sha256-3QQaWFudj1eN30s82rhS8/XdKajjNl4d1ehftl4/c9w="
},
{
"mvn-path": "org/clojure/core.rrb-vector/0.0.11/core.rrb-vector-0.0.11.jar",
"mvn-repo": "https://repo.maven.apache.org/maven2/",
Expand Down
5 changes: 3 additions & 2 deletions deps.edn
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
metosin/muuntaja {:mvn/version "0.6.8"}
net.cgrand/macrovich {:mvn/version "0.2.1"}
net.cgrand/xforms {:mvn/version "0.19.2"}
io.github.inferenceql/inferenceql.inference {:git/sha "40e77dedf680b7936ce988b66186a86f5c4db6a5"}
io.github.inferenceql/inferenceql.inference {:git/sha "c3cef474ba964a37fc2e5ff667055f5b77e12c45"}
io.github.inferenceql/inferenceql.gpm.sppl {:git/sha "f745dbb0c17c1a9da21488b7bd3098338ab7d7a2"}
io.github.clojure/tools.build {:git/sha "8e78bccc35116f6b6fc0bf0c125dba8b8db8da6b"}
org.babashka/sci {:mvn/version "0.3.32"}
Expand All @@ -20,7 +20,8 @@
ring-cors/ring-cors {:mvn/version "0.1.13"}
ring/ring-core {:mvn/version "1.9.5"}
ring/ring-jetty-adapter {:mvn/version "1.9.5"}
tech.tablesaw/tablesaw-core {:mvn/version "0.43.1"}}
tech.tablesaw/tablesaw-core {:mvn/version "0.43.1"}
org.clojure/core.memoize {:mvn/version "1.0.257"}}
:paths ["src" "resources"]
:aliases {:test {:extra-paths ["test"]
:extra-deps {com.gfredericks/test.chuck {:mvn/version "0.2.13"}
Expand Down
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
{
"name": "inferenceql.query",
"version": "1.0.0",
"description": "![tests](https://github.com/OpenIQL/inferenceql.query/workflows/tests/badge.svg) ![linter](https://github.com/OpenIQL/inferenceql.query/workflows/linter/badge.svg)",
Expand All @@ -25,5 +25,6 @@
"shadow-cljs": "^2.27.5"
},
"dependencies": {
"memoizee": "^0.4.15"
}
}
20 changes: 20 additions & 0 deletions src/inferenceql/query/cache.cljc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
(ns inferenceql.query.cache
"For caching expensive results."
(:require #?(:clj [clojure.core.memoize :as memo]
:cljs [memoizee :as memoizee])
[clojure.string :as string]))

;; Cache-size threshold that `lru` uses when it is called without an explicit
;; threshold argument.
(def default-threshold 100)


(defn lru
  "Returns a memoized version of `f` whose cache evicts entries on a
  least-recently-used basis.

  `lru-threshold` bounds the number of cached results: it is handed to
  clojure.core.memoize as `:lru/threshold` on Clojure and to memoizee as
  `max` on ClojureScript. When omitted, `default-threshold` is used.

  NOTE(review): on ClojureScript the cache key is produced by
  `JSON.stringify` of the call arguments - presumably arguments that
  serialize to the same JSON text share one cache entry; confirm this is
  acceptable for every fn wrapped with `lru`."
  ([f]
   (lru f default-threshold))
  ([f lru-threshold]
   #?(:clj  (memo/lru f :lru/threshold lru-threshold)
      :cljs (let [options #js {"max"        lru-threshold
                               "normalizer" js/JSON.stringify}]
              (memoizee f options)))))
81 changes: 46 additions & 35 deletions src/inferenceql/query/scalar.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
[inferenceql.inference.approximate :as approx]
[inferenceql.inference.gpm :as gpm]
;; [inferenceql.inference.search.crosscat :as crosscat]
[inferenceql.query.cache :as cache]
#?(:clj [inferenceql.query.generative-table :as generative-table])
[inferenceql.query.literal :as literal]
[inferenceql.query.parser.tree :as tree]
Expand Down Expand Up @@ -112,23 +113,29 @@
:env m}))
result)))

(defn prob
  "Returns the probability of `event` under `model`.

  `event` is first converted with `inference-event`; the log-probability
  reported by `gpm/logprob` is then exponentiated."
  [model event]
  (->> (inference-event event)
       (gpm/logprob model)
       (math/exp)))

(defn pdf
  "Returns the density of `event` under `model`.

  The keys of `event` are converted to strings, `gpm/logpdf` is called with
  that map and an empty map as its third argument, and the resulting
  log-density is exponentiated."
  [model event]
  (let [string-keyed (update-keys event str)
        log-density  (gpm/logpdf model string-keyed {})]
    (math/exp log-density)))

(defn condition
  "Conditions `model` on `conditions`.

  Entries whose value is nil are dropped and the remaining keys are
  converted to strings. If nothing remains, `model` is returned unchanged;
  otherwise the result of `gpm/condition` is returned."
  [model conditions]
  (let [usable (update-keys (medley/filter-vals some? conditions) str)]
    (if (seq usable)
      (gpm/condition model usable)
      model)))
(def prob
  "Memoized (via `cache/lru`) fn of `[model event]` returning the
  probability of `event` under `model`: `event` is converted with
  `inference-event` and the result of `gpm/logprob` is exponentiated."
  (cache/lru
   (fn prob*
     [model event]
     (->> (inference-event event)
          (gpm/logprob model)
          (math/exp)))))

(def pdf
  "Memoized (via `cache/lru`) fn of `[model event]` returning the density
  of `event` under `model`: the keys of `event` are converted to strings,
  `gpm/logpdf` is called with that map and an empty map as its third
  argument, and the resulting log-density is exponentiated."
  (cache/lru
   (fn pdf*
     [model event]
     (let [string-keyed (update-keys event str)
           log-density  (gpm/logpdf model string-keyed {})]
       (math/exp log-density)))))

(def condition
  "Memoized (via `cache/lru`) fn of `[model conditions]`.

  Entries of `conditions` whose value is nil are dropped and the remaining
  keys are converted to strings. If nothing remains, `model` is returned
  unchanged; otherwise the result of `gpm/condition` is returned."
  (cache/lru
   (fn condition*
     [model conditions]
     (let [usable (update-keys (medley/filter-vals some? conditions) str)]
       (if (seq usable)
         (gpm/condition model usable)
         model)))))

(defn condition-all
[model bindings]
Expand Down Expand Up @@ -172,24 +179,28 @@
:else (remove nil? form)))))
event))

(defn constrain
  "Constrains `model` by `event`.

  `event` is passed through `strip-nils` and then `inference-event`. If the
  result is nil, `model` is returned unchanged; otherwise `gpm/constrain` is
  called with the converted event and this namespace's event accessors
  (`operation?`, `operands`, `operator`, `variable?`) as options."
  [model event]
  (let [prepared (inference-event (strip-nils event))]
    (if (some? prepared)
      (gpm/constrain model
                     prepared
                     {:operation? operation?
                      :operands operands
                      :operator operator
                      :variable? variable?})
      model)))

(defn mutual-info
  "Returns the result of `gpm/mutual-info` for `model` and the two events,
  each of which is first converted with `inference-event`."
  [model event-a event-b]
  (gpm/mutual-info model
                   (inference-event event-a)
                   (inference-event event-b)))
(def constrain
  "Memoized (via `cache/lru`) fn of `[model event]`.

  `event` is passed through `strip-nils` and then `inference-event`. If the
  result is nil, `model` is returned unchanged; otherwise `gpm/constrain` is
  called with the converted event and this namespace's event accessors
  (`operation?`, `operands`, `operator`, `variable?`) as options."
  (cache/lru
   (fn constrain*
     [model event]
     (let [prepared (inference-event (strip-nils event))]
       (if (some? prepared)
         (gpm/constrain model
                        prepared
                        {:operation? operation?
                         :operands operands
                         :operator operator
                         :variable? variable?})
         model)))))

(def mutual-info
  "Memoized (via `cache/lru`) fn of `[model event-a event-b]` returning the
  result of `gpm/mutual-info` for `model` and the two events, each of which
  is first converted with `inference-event`."
  (cache/lru
   (fn mutual-info*
     [model event-a event-b]
     (gpm/mutual-info model
                      (inference-event event-a)
                      (inference-event event-b)))))

(defn approx-mutual-info
[model vars-lhs vars-rhs]
Expand Down
37 changes: 37 additions & 0 deletions test/inferenceql/query/cache_test.cljc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
(ns inferenceql.query.cache-test
(:require [clojure.test :refer [deftest is testing]]
[inferenceql.query.cache :as cache]))

(deftest basic-caching
  ;; The wrapped fn ignores its argument and returns an ever-increasing
  ;; counter, so a repeated value means a cache hit and a new value a miss.
  (let [max-entries 2
        counter (atom 0)
        next-count! (fn [_cache-key-only]
                      (swap! counter inc))
        cached-count! (cache/lru next-count! max-entries)]

    ;; Two distinct keys: both are misses.
    (is (= 1 (cached-count! :foo)))
    (is (= 2 (cached-count! :bar)))

    ;; Same keys again: both are hits, the counter does not advance.
    (is (= 1 (cached-count! :foo)))
    (is (= 2 (cached-count! :bar)))

    ;; A third key is a miss and pushes the cache past its threshold.
    (is (= 3 (cached-count! :moop)))

    ;; :foo was the least recently used entry, so it was evicted and is
    ;; computed afresh.
    (is (= 4 (cached-count! :foo)))))

(deftest disambiguate-between-0-and-nil
  ;; 0 and nil must be cached under distinct keys.
  (let [max-entries 1000
        describe (fn [x]
                   (cond
                     (nil? x) "nil"
                     (= 0 x)  "zero"
                     :else    "other"))
        cached-describe (cache/lru describe max-entries)]
    ;; First pass populates the cache with both arguments.
    (is (= "zero" (cached-describe 0)))
    (is (= "nil" (cached-describe nil)))

    ;; Second pass is served from the cache and must not mix them up.
    (is (= "zero" (cached-describe 0)))
    (is (= "nil" (cached-describe nil)))))

0 comments on commit e3842ca

Please sign in to comment.