From 68000ea5efbc7d3a1dd32160c9a348bf921ab2f1 Mon Sep 17 00:00:00 2001
From: Matthew Davidson
Date: Thu, 14 Mar 2024 21:48:51 +0700
Subject: [PATCH] refactor: Add caching, tests, and cache scalar pdf/prob/condition/constrain/mutual-info fns in clj

Also bump up inferenceql.inference version to avoid arrow constructor bug
---
NOTE(review): everything between the "---" line above and the first
"diff --git" line is commentary; git does not record it in the commit.

* Line breaks and indentation in this patch were restored from a copy in
  which all whitespace had been collapsed. Line counts were checked against
  every hunk header and the diffstat, but the indentation (context lines in
  particular) is a conventional reconstruction. TODO confirm against the
  repository, or apply with `git apply --ignore-whitespace`.
* src/inferenceql/query/cache.cljc requires [clojure.string :as string], but
  nothing in the new namespace uses the `string` alias.
* In the :cljs branch of cache/lru the memoizee cache key is produced by
  js/JSON.stringify of the call arguments, and scalar.cljc now routes models
  and event forms through it. TODO confirm that distinct models/events always
  serialize to distinct strings (and never throw) under JSON.stringify.
* prob, pdf, condition, constrain and mutual-info change from `defn` to
  `def`, so the vars no longer get the :arglists metadata that `defn`
  attaches automatically; consider adding it by hand for tooling.
* The message above mentions bumping inferenceql.inference, but no hunk in
  this patch changes an inferenceql.inference version. TODO confirm.

 deps-lock.json                         | 10 ++++
 deps.edn                               |  3 +-
 package.json                           |  1 +
 src/inferenceql/query/cache.cljc       | 20 +++++++
 src/inferenceql/query/scalar.cljc      | 81 +++++++++++++++-----------
 test/inferenceql/query/cache_test.cljc | 37 ++++++++++++
 6 files changed, 116 insertions(+), 36 deletions(-)
 create mode 100644 src/inferenceql/query/cache.cljc
 create mode 100644 test/inferenceql/query/cache_test.cljc

diff --git a/deps-lock.json b/deps-lock.json
index c254b13..8d51e3b 100644
--- a/deps-lock.json
+++ b/deps-lock.json
@@ -1547,6 +1547,16 @@
       "mvn-repo": "https://repo.maven.apache.org/maven2/",
       "hash": "sha256-hML6t6Mso8HkDEGm7Mm9U26UezBYDne41dwjKjSSXqw="
     },
+    {
+      "mvn-path": "org/clojure/core.memoize/1.0.257/core.memoize-1.0.257.jar",
+      "mvn-repo": "https://repo.maven.apache.org/maven2/",
+      "hash": "sha256-mg6RgW4hp3SY7+3r1HrUvcL7+X+dvEm8nZWU4gEkbpY="
+    },
+    {
+      "mvn-path": "org/clojure/core.memoize/1.0.257/core.memoize-1.0.257.pom",
+      "mvn-repo": "https://repo.maven.apache.org/maven2/",
+      "hash": "sha256-3QQaWFudj1eN30s82rhS8/XdKajjNl4d1ehftl4/c9w="
+    },
     {
       "mvn-path": "org/clojure/core.rrb-vector/0.0.11/core.rrb-vector-0.0.11.jar",
       "mvn-repo": "https://repo.maven.apache.org/maven2/",
diff --git a/deps.edn b/deps.edn
index 348f68b..add75f0 100644
--- a/deps.edn
+++ b/deps.edn
@@ -19,7 +19,8 @@
         ring-cors/ring-cors {:mvn/version "0.1.13"}
         ring/ring-core {:mvn/version "1.9.5"}
         ring/ring-jetty-adapter {:mvn/version "1.9.5"}
-        tech.tablesaw/tablesaw-core {:mvn/version "0.43.1"}}
+        tech.tablesaw/tablesaw-core {:mvn/version "0.43.1"}
+        org.clojure/core.memoize {:mvn/version "1.0.257"}}
  :paths ["src" "resources"]
  :aliases {:test {:extra-paths ["test"]
                   :extra-deps {com.gfredericks/test.chuck {:mvn/version "0.2.13"}
diff --git a/package.json b/package.json
index 7ea8601..9468540 100644
--- a/package.json
+++ b/package.json
@@ -25,5 +25,6 @@
     "shadow-cljs": "^2.27.5"
   },
   "dependencies": {
+    "memoizee": "^0.4.15"
   }
 }
diff --git a/src/inferenceql/query/cache.cljc b/src/inferenceql/query/cache.cljc
new file mode 100644
index 0000000..d816062
--- /dev/null
+++ b/src/inferenceql/query/cache.cljc
@@ -0,0 +1,20 @@
+(ns inferenceql.query.cache
+  "For caching expensive results."
+  (:require #?(:clj [clojure.core.memoize :as memo]
+               :cljs ["memoizee" :as memoizee])
+            [clojure.string :as string]))
+
+(def default-threshold 100)
+
+
+(defn lru
+  "Memoizes a fn with a least-recently-used eviction policy.
+
+  After the number of cached results exceeds the threshold, the
+  least-recently-used ones will be evicted."
+  ([f]
+   (lru f default-threshold))
+  ([f lru-threshold]
+   #?(:clj (memo/lru f :lru/threshold lru-threshold)
+      :cljs (memoizee f #js {"max" lru-threshold
+                             "normalizer" js/JSON.stringify}))))
diff --git a/src/inferenceql/query/scalar.cljc b/src/inferenceql/query/scalar.cljc
index 7c890c7..a3a8594 100644
--- a/src/inferenceql/query/scalar.cljc
+++ b/src/inferenceql/query/scalar.cljc
@@ -8,6 +8,7 @@
             [inferenceql.inference.approximate :as approx]
             [inferenceql.inference.gpm :as gpm]
             ;; [inferenceql.inference.search.crosscat :as crosscat]
+            [inferenceql.query.cache :as cache]
             #?(:clj [inferenceql.query.generative-table :as generative-table])
             [inferenceql.query.literal :as literal]
             [inferenceql.query.parser.tree :as tree]
@@ -112,23 +113,29 @@
                          :env m}))
       result)))
 
-(defn prob
-  [model event]
-  (let [event (inference-event event)]
-    (math/exp (gpm/logprob model event))))
-
-(defn pdf
-  [model event]
-  (let [event (update-keys event str)]
-    (math/exp (gpm/logpdf model event {}))))
-
-(defn condition
-  [model conditions]
-  (let [conditions (-> (medley/filter-vals some? conditions)
-                       (update-keys str))]
-    (cond-> model
-      (seq conditions)
-      (gpm/condition conditions))))
+(def prob
+  (cache/lru
+   (fn prob*
+     [model event]
+     (let [event (inference-event event)]
+       (math/exp (gpm/logprob model event))))))
+
+(def pdf
+  (cache/lru
+   (fn pdf*
+     [model event]
+     (let [event (update-keys event str)]
+       (math/exp (gpm/logpdf model event {}))))))
+
+(def condition
+  (cache/lru
+   (fn condition*
+     [model conditions]
+     (let [conditions (-> (medley/filter-vals some? conditions)
+                          (update-keys str))]
+       (cond-> model
+         (seq conditions)
+         (gpm/condition conditions))))))
 
 (defn condition-all
   [model bindings]
@@ -172,24 +179,28 @@
                         :else (remove nil? form)))))
                  event))
 
-(defn constrain
-  [model event]
-  (let [event (-> event
-                  (strip-nils)
-                  (inference-event))]
-    (cond-> model
-      (some? event)
-      (gpm/constrain event
-                     {:operation? operation?
-                      :operands operands
-                      :operator operator
-                      :variable? variable?}))))
-
-(defn mutual-info
-  [model event-a event-b]
-  (let [event-a (inference-event event-a)
-        event-b (inference-event event-b)]
-    (gpm/mutual-info model event-a event-b)))
+(def constrain
+  (cache/lru
+   (fn constrain*
+     [model event]
+     (let [event (-> event
+                     (strip-nils)
+                     (inference-event))]
+       (cond-> model
+         (some? event)
+         (gpm/constrain event
+                        {:operation? operation?
+                         :operands operands
+                         :operator operator
+                         :variable? variable?}))))))
+
+(def mutual-info
+  (cache/lru
+   (fn mutual-info*
+     [model event-a event-b]
+     (let [event-a (inference-event event-a)
+           event-b (inference-event event-b)]
+       (gpm/mutual-info model event-a event-b)))))
 
 (defn approx-mutual-info
   [model vars-lhs vars-rhs]
diff --git a/test/inferenceql/query/cache_test.cljc b/test/inferenceql/query/cache_test.cljc
new file mode 100644
index 0000000..6c8c202
--- /dev/null
+++ b/test/inferenceql/query/cache_test.cljc
@@ -0,0 +1,37 @@
+(ns inferenceql.query.cache-test
+  (:require [clojure.test :refer [deftest is testing]]
+            [inferenceql.query.cache :as cache]))
+
+(deftest basic-caching
+  (let [cache-size 2
+        a (atom 0)
+        incrementer (fn [_ignored-but-cached-key]
+                      (swap! a inc))
+        cached-incrementer (cache/lru incrementer cache-size)]
+
+    (is (= 1 (cached-incrementer :foo)))
+    (is (= 2 (cached-incrementer :bar)))
+
+    (is (= 1 (cached-incrementer :foo)))
+    (is (= 2 (cached-incrementer :bar)))
+
+    (is (= 3 (cached-incrementer :moop)))
+
+    ;; cache cleared for :foo
+    (is (= 4 (cached-incrementer :foo)))))
+
+(deftest disambiguate-between-0-and-nil
+  (let [cache-size 1000
+        englishize (fn [x]
+                     (case x
+                       0 "zero"
+                       nil "nil"
+                       "other"))
+        cached-englishize (cache/lru englishize cache-size)]
+    ;; Add them both.
+    (is (= "zero" (cached-englishize 0)))
+    (is (= "nil" (cached-englishize nil)))
+
+    ;; Check that they return the correct values.
+    (is (= "zero" (cached-englishize 0)))
+    (is (= "nil" (cached-englishize nil)))))