From 103dbfb7f258a0a3900a81b8e5ff9ec6ed6955ea Mon Sep 17 00:00:00 2001
From: Schaechtle
Date: Tue, 2 Apr 2024 15:36:11 -0400
Subject: [PATCH 1/4] feat: (WIP) Add poisson primitive

WIP because simulate is not implemented yet.
---
Reviewer notes (text between the "---" marker and the diff is not part of
the commit message):

* The Poisson primitive pairs a Poisson likelihood with a Gamma(a, b) prior
  on the rate; `logpdf` returns the log posterior predictive of one value
  given the incorporated sufficient statistics.
* `hyper-grid` is keyed by :a and :b so that it matches the hyperparameter
  keys used by `spec->poisson` and `logpdf`.
* `simulate` throws via `ex-info`, which exists on both Clojure and
  ClojureScript (the file is .cljc).
* Context-line indentation was reconstructed and the "index" blob ids were
  not recomputed; use `git apply --ignore-whitespace` if a hunk is rejected.

 .../inference/gpm/primitive_gpms.cljc         |  8 ++-
 .../inference/gpm/primitive_gpms/poisson.cljc | 64 +++++++++++++++++++
 2 files changed, 70 insertions(+), 2 deletions(-)
 create mode 100644 src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc

diff --git a/src/inferenceql/inference/gpm/primitive_gpms.cljc b/src/inferenceql/inference/gpm/primitive_gpms.cljc
index 1be5735..24752fe 100644
--- a/src/inferenceql/inference/gpm/primitive_gpms.cljc
+++ b/src/inferenceql/inference/gpm/primitive_gpms.cljc
@@ -1,7 +1,8 @@
 (ns inferenceql.inference.gpm.primitive-gpms
   (:require [inferenceql.inference.gpm.primitive-gpms.bernoulli :as bernoulli]
             [inferenceql.inference.gpm.primitive-gpms.categorical :as categorical]
-            [inferenceql.inference.gpm.primitive-gpms.gaussian :as gaussian]))
+            [inferenceql.inference.gpm.primitive-gpms.gaussian :as gaussian]
+            [inferenceql.inference.gpm.primitive-gpms.poisson :as poisson]))
 
 (defn primitive?
   "Checks whether the given GPM is a primitive GPM."
@@ -9,7 +10,8 @@
   (and (record? stattype)
        (or (bernoulli/bernoulli? stattype)
            (categorical/categorical? stattype)
-           (gaussian/gaussian? stattype))))
+           (gaussian/gaussian? stattype)
+           (poisson/poisson? stattype))))
 
 (defn hyper-grid
   [stattype data & {:keys [n-grid] :or {n-grid 30}}]
@@ -19,6 +21,7 @@
       :bernoulli (bernoulli/hyper-grid data n-grid)
       :categorical (categorical/hyper-grid data n-grid)
       :gaussian (gaussian/hyper-grid data n-grid)
+      :poisson (poisson/hyper-grid data n-grid)
       (throw (ex-info (str "pGPM doesn't exist: " stattype)
                       {:stattype stattype
                        :data data})))))
@@ -39,5 +42,6 @@
     :bernoulli (bernoulli/spec->bernoulli var-name :suff-stats suff-stats :hyperparameters hyperparameters)
     :categorical (categorical/spec->categorical var-name :suff-stats suff-stats :hyperparameters hyperparameters :options options)
     :gaussian (gaussian/spec->gaussian var-name :suff-stats suff-stats :hyperparameters hyperparameters)
+    :poisson (poisson/spec->poisson var-name :suff-stats suff-stats :hyperparameters hyperparameters)
     (throw (ex-info (str "pGPM doesn't exist for var-name: " primitive " for " var-name)
                     {:primitive primitive}))))
diff --git a/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc b/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc
new file mode 100644
index 0000000..2aa8eaf
--- /dev/null
+++ b/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc
@@ -0,0 +1,64 @@
+(ns inferenceql.inference.gpm.primitive-gpms.poisson
+  (:require [clojure.math :as math]
+            [inferenceql.inference.gpm.proto :as gpm.proto]
+            [inferenceql.inference.primitives :as primitives]
+            [inferenceql.inference.distributions :as dist]
+            [inferenceql.inference.utils :as utils]))
+
+(defn posterior-hypers
+  [n sum-x a b]
+  [(+ a sum-x) (+ b n)]) ; conjugate Gamma update: [a + sum-x, b + n]
+
+(defn calc-log-Z
+  [a b]
+  (- (dist/log-gamma a) (* a (math/log b)))) ; log Z(a, b) = log-gamma(a) - a * log(b)
+
+(defrecord Poisson [var-name suff-stats hyperparameters]
+  gpm.proto/GPM
+  (logpdf [_ targets constraints]
+    (let [x (get targets var-name)
+          x' (get constraints var-name)
+          constrained? (not (nil? x'))]
+      (cond
+        (nil? x) 0 ; nothing to score for this variable
+        constrained? (if (= x x') 0 ##-Inf) ; target must agree with the constraint
+        :else (let [n (:n suff-stats)
+                    sum-x (:sum-x suff-stats)
+                    sum-log-fact (:sum-log-fact suff-stats)
+                    a (:a hyperparameters)
+                    b (:b hyperparameters)
+                    [an bn] (posterior-hypers n sum-x a b) ; posterior given current data
+                    [am bm] (posterior-hypers (+ n 1) (+ sum-x x) a b) ; posterior with x added
+                    Zn (calc-log-Z an bn)
+                    Zm (calc-log-Z am bm)]
+                (- Zm Zn (dist/log-gamma (+ x 1))))))) ; Zm - Zn - log-gamma(x + 1)
+
+  (simulate [this _ _]
+    (throw (ex-info "Poisson simulate not implemented" {:var-name var-name})))
+
+
+  gpm.proto/Variables
+  (variables [{:keys [var-name]}]
+    #{var-name}))
+
+(defn poisson?
+  "Checks if the given pGPM is Poisson."
+  [stattype]
+  (and (record? stattype)
+       (instance? Poisson stattype)))
+
+(defn hyper-grid
+  "Hyperparameter grid for the Poisson variable, used in column hyperparameter inference
+  for Column GPMs."
+  [data n-grid]
+  (let [grid (utils/log-linspace 1 (count data) n-grid)]
+    {:a grid :b grid})) ; same keys as the :a/:b hyperparameters read by logpdf
+
+(defn spec->poisson
+  "Casts a CrossCat category spec to a Poisson pGPM.
+  Requires a variable name, optionally takes by key
+  sufficient statistics, options, and hyperparameters."
+  [var-name & {:keys [hyperparameters suff-stats options]}]
+  (let [suff-stats' (if-not (nil? suff-stats) suff-stats {:n 0 :sum-x 0 :sum-log-fact 0})
+        hyperparameters' (if-not (nil? hyperparameters) hyperparameters {:a 1 :b 1})]
+    (->Poisson var-name suff-stats' hyperparameters')))

From 0b0b6bc724a5682a2f4807050ae278d28cb6f7fd Mon Sep 17 00:00:00 2001
From: Schaechtle
Date: Tue, 2 Apr 2024 17:07:33 -0400
Subject: [PATCH 2/4] test: (WIP) test poisson logpdf

---
 src/inferenceql/inference/gpm.cljc            |  2 ++
 .../gpm/primitive_gpms/poisson_test.cljc      | 22 +++++++++++++++++++
 2 files changed, 24 insertions(+)
 create mode 100644 test/inferenceql/inference/gpm/primitive_gpms/poisson_test.cljc

diff --git a/src/inferenceql/inference/gpm.cljc b/src/inferenceql/inference/gpm.cljc
index f61ca4c..6a3ee1b 100644
--- a/src/inferenceql/inference/gpm.cljc
+++ b/src/inferenceql/inference/gpm.cljc
@@ -12,6 +12,7 @@
             [inferenceql.inference.gpm.primitive-gpms.bernoulli :as bernoulli]
             [inferenceql.inference.gpm.primitive-gpms.categorical :as categorical]
             [inferenceql.inference.gpm.primitive-gpms.gaussian :as gaussian]
+            [inferenceql.inference.gpm.primitive-gpms.poisson :as poisson]
             [inferenceql.inference.gpm.proto :as gpm-proto]
             [inferenceql.inference.gpm.view :as view]))
 
@@ -152,6 +153,7 @@
    'inferenceql.inference.gpm.primitive_gpms.bernoulli.Bernoulli bernoulli/map->Bernoulli
    'inferenceql.inference.gpm.primitive_gpms.categorical.Categorical categorical/map->Categorical
    'inferenceql.inference.gpm.primitive_gpms.gaussian.Gaussian gaussian/map->Gaussian
+   'inferenceql.inference.gpm.primitive_gpms.poisson.Poisson poisson/map->Poisson
    'inferenceql.inference.gpm.view.View view/map->View})
 
 (defn as-gpm
diff --git a/test/inferenceql/inference/gpm/primitive_gpms/poisson_test.cljc b/test/inferenceql/inference/gpm/primitive_gpms/poisson_test.cljc
new file mode 100644
index 0000000..cb0b108
--- /dev/null
+++ b/test/inferenceql/inference/gpm/primitive_gpms/poisson_test.cljc
@@ -0,0 +1,22 @@
+(ns inferenceql.inference.gpm.primitive-gpms.poisson-test
+  (:require [clojure.math :as math]
+            [clojure.test :as test :refer [deftest is]]
+            [inferenceql.inference.gpm :as gpm]
+            [inferenceql.inference.gpm.primitive-gpms.poisson :as poisson]
+            [inferenceql.inference.gpm.proto :as gpm.proto]
+            [inferenceql.inference.utils :as utils]))
+
+(def var-name "poisson")
+
+(def poisson-pgpm
+  (let [suff-stats {:n 0 :sum-x 0 :sum-log-fact 0}
+        hypers {:a 2 :b 2}]
+    (poisson/spec->poisson var-name :suff-stats suff-stats :hyperparameters hypers)))
+
+(deftest logpdf
+  (let [targets {"poisson" 0}
+        constraints {"poisson" 1}]
+    (is (= 1.0 (math/exp (gpm.proto/logpdf poisson-pgpm {} {}))))
+    (is (= 1.0 (math/exp (gpm.proto/logpdf poisson-pgpm targets targets))))
+    (is (= ##-Inf (gpm.proto/logpdf poisson-pgpm targets constraints))))
+  (is (utils/almost-equal? -1.2163953243244932 (gpm.proto/logpdf poisson-pgpm {"poisson" 1} {}) utils/relerr 1E-8)))

From 24ab3da56337adb5dc578cecf7333f875c38c7b1 Mon Sep 17 00:00:00 2001
From: Schaechtle
Date: Tue, 2 Apr 2024 17:24:07 -0400
Subject: [PATCH 3/4] WIP incorporate

---
 .../inference/gpm/primitive_gpms/poisson.cljc | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc b/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc
index 2aa8eaf..3c3fcc7 100644
--- a/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc
+++ b/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc
@@ -36,6 +36,19 @@
   (simulate [this _ _]
     (throw (ex-info "Poisson simulate not implemented" {:var-name var-name})))
 
+  gpm.proto/Incorporate
+  (incorporate [this values]
+    (let [x (get values var-name)]
+      (assoc this :suff-stats (-> suff-stats
+                                  (update :n inc)
+                                  (update :sum-x #(+ % x))
+                                  (update :sum-log-fact #(+ % (log-gamma (+ x 1))))))))
+  (unincorporate [this values]
+    (let [x (get values var-name)]
+      (assoc this :suff-stats (-> suff-stats
+                                  (update :n dec)
+                                  (update :sum-x #(- % x))
+                                  (update :sum-log-fact #(- % (log-gamma (+ x 1))))))))
 
   gpm.proto/Variables
   (variables [{:keys [var-name]}]

From 95f59a4ea5b0ef97cd928d1a586febcc5ac0e3e0 Mon Sep 17 00:00:00 2001
From: Schaechtle
Date: Tue, 2 Apr 2024 17:26:33 -0400
Subject: [PATCH 4/4] WIP

---
 src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc b/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc
index 3c3fcc7..377f235 100644
--- a/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc
+++ b/src/inferenceql/inference/gpm/primitive_gpms/poisson.cljc
@@ -42,13 +42,13 @@
       (assoc this :suff-stats (-> suff-stats
                                   (update :n inc)
                                   (update :sum-x #(+ % x))
-                                  (update :sum-log-fact #(+ % (log-gamma (+ x 1))))))))
+                                  (update :sum-log-fact #(+ % (dist/log-gamma (+ x 1))))))))
   (unincorporate [this values]
     (let [x (get values var-name)]
       (assoc this :suff-stats (-> suff-stats
                                   (update :n dec)
                                   (update :sum-x #(- % x))
-                                  (update :sum-log-fact #(- % (log-gamma (+ x 1))))))))
+                                  (update :sum-log-fact #(- % (dist/log-gamma (+ x 1))))))))
 
   gpm.proto/Variables
   (variables [{:keys [var-name]}]