-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetaprob.hs
378 lines (330 loc) · 12.9 KB
/
metaprob.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
import Control.Monad.Random
--
-- METAPROB
--
-- The first section of this code, INTERFACE, shows how Metaprob rests
-- atop an abstract theory of distributions (and traces). It
-- clarifies which aspects of distributions are required to define
-- generative functions, the key transformations among them, and the
-- meta-circular evaluator.
--
-- The second section, EXAMPLES, provides implementations of the
-- interface. Our implementations of distributions include both
-- measure-theoretic and sampler-theoretic ones. We set up some
-- sample computations too.
--
-- In all cases, the goal is conceptual clarity, ignoring
-- computational efficiency.
-- This file is intended to be loaded in ghci via
-- > :l metaprob.hs
-- followed by the suggestions under the "try:" comments in the
-- examples below.
-- TODOS:
-- * Probabilistic programs of type "A -> B";
-- currently just type A ~= "() -> A" is supported
-- * Some notion of continuous distributions
--
-- INTERFACE
--
-- Here is some of our notation compared with the Metaprob paper:
-- * The space K of trace keys is represented by `key`.
-- * The type A of elements we wish to compute about is `a`.
-- We assume that A is essentially a finite discrete set,
-- though the code might sometimes work more generally.
-- * We fix some parametric type f(A) in the world of the paper.
-- * Key examples: f(A) = A, and f(A) = A x Trace.
-- The rest of the correspondence is documented as we go.
-- Abstract interface to distributions.  In the paper's notation this
-- describes
--   * distributions R(f(A)) = `distr`
-- on
--   * elements f(A) = `elt a`
class Distr distr where
  -- Carry a distribution over `elt a` to one over `elt' a` along a
  -- function between the element types.
  pushForward :: (elt a -> elt' a) -> distr elt a -> distr elt' a

  -- The distribution built from a single element.
  dirac :: elt a -> distr elt a

  -- Combine a distribution with a family of distributions indexed by
  -- its elements.  The literature also calls this a "compound"
  -- distribution.
  mixture ::
    Eq (elt a) =>
    distr elt a ->
    (elt a -> distr elt a) ->
    distr elt a
-- A distribution is morally just a monad applied to the element type.
-- `MDistr` wraps a value of type `m (elt a)`; `mDistr` unwraps it.
newtype MDistr m elt a = MDistr
  { mDistr :: m (elt a)
  }
-- Any monad yields a `Distr`: push-forward is `fmap`, `dirac` is
-- `return`, and `mixture` is monadic bind (all under the `MDistr`
-- wrapper).
instance Monad m => Distr (MDistr m) where
  pushForward f (MDistr mx) = MDistr (fmap f mx)
  dirac x = MDistr (return x)
  mixture (MDistr mx) next = MDistr (mx >>= \x -> mDistr (next x))
-- Defines generative functions P(f(A)) in terms of f(A) and R(f(A)).
data GenFn key distr elt a
  = -- A random choice: a trace key, the distribution to draw from, and
    -- a scoring function on elements.
    Sample key (distr elt a) (elt a -> Double)
  | -- A fixed element.
    Ret (elt a)
  | -- Sequencing: a first program, and a second program that depends
    -- on the element produced by the first.
    Semicolon (GenFn key distr elt a) (elt a -> GenFn key distr elt a)
-- Defines the "Gen" interpretation [[ ]]_g from P(f(A)) to R(f(A)):
--   * a `Sample` is interpreted as its distribution (key and score are
--     ignored here),
--   * a `Ret` as the `dirac` distribution on its element,
--   * a `Semicolon` as the `mixture` of the interpretations of its parts.
runGen ::
  (Distr distr, Eq (elt a)) =>
  GenFn key distr elt a ->
  distr elt a
runGen prog = case prog of
  Sample _ sampler _ -> sampler
  Ret e -> dirac e
  Semicolon first rest -> mixture (runGen first) (\x -> runGen (rest x))
-- The value stored in a trace under a key.
--   * `TNone` is what `traceValue` returns when the key is absent.
--   * `Traced` is what `tracing` and `infer` record for a choice.
--   * `Intervene` and `Observe` are the constraints that `infer`
--     looks for in its input trace.
data TValue a
  = TNone
  | Traced a
  | Intervene a
  | Observe a
  deriving (Eq, Show)
-- In the context of `class Trace`,
--   * `elt a` corresponds to f(A) = A,
--   * `traced a` corresponds to f(A) = A x Trace, and
--   * `wtraced a` corresponds to f(A) = A x Trace x R^+.
-- And, yes, `trace` corresponds to Trace.
-- Generative functions of these types are then related by `tracing`
-- and `infer`, below.
--
-- The functional dependencies say that the trace type fixes all the
-- other parameters, that `traced` fixes `trace`, and that `wtraced`
-- fixes `traced`.
class
  Trace trace key elt traced wtraced
    | trace -> key elt traced wtraced
    , traced -> trace
    , wtraced -> traced
  where
  -- View a trace as an association list from keys to trace values.
  getTrace :: trace a -> [(key, TValue (elt a))]

  -- The trace with no entries.
  emptyTrace :: trace a

  -- A trace holding a single key/value entry.
  kvTrace :: key -> TValue (elt a) -> trace a

  -- Combine two traces into one.
  appendTrace :: trace a -> trace a -> trace a

  -- Convert between a traced element and its (element, trace) pair.
  getTraced :: traced a -> (elt a, trace a)
  makeTraced :: (elt a, trace a) -> traced a

  -- Convert between a weighted traced element and its
  -- (element, trace, weight) triple.
  getWTraced :: wtraced a -> (elt a, trace a, Double)
  makeWTraced :: (elt a, trace a, Double) -> wtraced a
-- Look up the value stored under key `k` in trace `t`.  The first
-- entry with a matching key wins; `TNone` is returned when there is
-- no such entry.
traceValue ::
  (Trace trace key elt traced wtraced, Eq key) =>
  trace a ->
  key ->
  TValue (elt a)
traceValue t k =
  case lookup k (getTrace t) of
    Just v -> v
    Nothing -> TNone
-- Extend a function on elements to traced elements: apply `f` to the
-- element when the attached trace has no entries, and return 0
-- otherwise.
extendByZero ::
  Trace trace key elt traced wtraced =>
  (elt a -> Double) ->
  traced a ->
  Double
extendByZero f xt
  | null (getTrace t) = f x
  | otherwise = 0.0
  where
    (x, t) = getTraced xt
-- Extend a function on elements to weighted traced elements: apply
-- `f` to the element when the attached trace has no entries, and
-- return 0 otherwise.  The weight component is ignored.
extendByZeroW ::
  Trace trace key elt traced wtraced =>
  (elt a -> Double) ->
  wtraced a ->
  Double
extendByZeroW f xtw
  | null (getTrace t) = f x
  | otherwise = 0.0
  where
    (x, t, _) = getWTraced xtw
-- Defines tracing from P(A) to P(A x Trace): turn a generative
-- function on elements into one on traced elements.
--   * `Sample`: sample with an empty trace attached, then record the
--     sampled value under the sample's key as `Traced`.
--   * `Ret`: attach the empty trace.
--   * `Semicolon`: trace both parts and append the second trace to
--     the first.
tracing ::
  (Trace trace key elt traced wtraced, Distr distr) =>
  GenFn key distr elt a ->
  GenFn key distr traced a
tracing prog = case prog of
  Sample k dist score ->
    Semicolon
      ( Sample
          k
          (pushForward (\x -> makeTraced (x, emptyTrace)) dist)
          (extendByZero score)
      )
      ( \sampled ->
          let x = fst (getTraced sampled)
           in Ret (makeTraced (x, kvTrace k (Traced x)))
      )
  Ret x -> Ret (makeTraced (x, emptyTrace))
  Semicolon p1 p2 ->
    Semicolon
      (tracing p1)
      ( \firstResult ->
          let (x, s) = getTraced firstResult
           in Semicolon
                (tracing (p2 x))
                ( \secondResult ->
                    let (y, t) = getTraced secondResult
                     in Ret (makeTraced (y, appendTrace s t))
                )
      )
-- Defines infer_t from P(A) to P(A x Trace x R^+): turn a generative
-- function on elements into one on weighted traced elements, guided
-- by the input trace `tr`.
--   * `Sample`: consult `tr` at the sample's key.  An `Observe`d value
--     is used with weight `score x`; an `Intervene`d value is used
--     with weight 1; otherwise a value is sampled with weight 1.  In
--     every case the chosen value is then recorded under the key as
--     `Traced`.
--   * `Ret`: attach the empty trace and weight 1.
--   * `Semicolon`: infer both parts, append the traces and multiply
--     the weights.
infer ::
  (Trace trace key elt traced wtraced, Distr distr, Eq key) =>
  trace a ->
  GenFn key distr elt a ->
  GenFn key distr wtraced a
infer tr prog = case prog of
  Sample k dist score ->
    Semicolon
      ( case traceValue tr k of
          Observe x -> Ret (makeWTraced (x, emptyTrace, score x))
          Intervene x -> Ret (makeWTraced (x, emptyTrace, 1.0))
          _ ->
            Semicolon
              ( Sample
                  k
                  (pushForward (\x -> makeWTraced (x, emptyTrace, 1.0)) dist)
                  (extendByZeroW score)
              )
              ( \sampled ->
                  let (x, _, _) = getWTraced sampled
                   in Ret (makeWTraced (x, emptyTrace, 1.0))
              )
      )
      ( \chosen ->
          let (y, _, w) = getWTraced chosen
           in Ret (makeWTraced (y, kvTrace k (Traced y), w))
      )
  Ret x -> Ret (makeWTraced (x, emptyTrace, 1.0))
  Semicolon p1 p2 ->
    Semicolon
      (infer tr p1)
      ( \firstResult ->
          let (x, s, v) = getWTraced firstResult
           in Semicolon
                (infer tr (p2 x))
                ( \secondResult ->
                    let (y, t, w) = getWTraced secondResult
                     in Ret (makeWTraced (y, appendTrace s t, v * w))
                )
      )
--
-- EXAMPLES
--
--
-- Default implementation of shared items
--
-- Trace-related things
-- The identity wrapper on elements: f(A) = A.
newtype MyElt a = MyElt { myElt :: a } deriving Eq

-- A `MyElt` is displayed exactly like the value it wraps.  The
-- instance delegates `showsPrec` (rather than defining only `show`)
-- so that the surrounding precedence reaches the wrapped value;
-- otherwise a nested value that needs parentheses, e.g. a negative
-- number under a constructor, would be rendered without them.
instance Show a => Show (MyElt a) where
  showsPrec d (MyElt a) = showsPrec d a
-- A trace represented as an association list from keys to trace
-- values.
newtype MyTrace key a = MyTrace
  { myTrace :: [(key, TValue (MyElt a))]
  }
  deriving Eq

-- Displayed as the word "Trace" followed by the underlying list.
instance (Show key, Show a) => Show (MyTrace key a) where
  show tr = "Trace " ++ show (myTrace tr)
-- An element paired with a trace: f(A) = A x Trace.
newtype MyTraced key a = MyTraced
  { myTraced :: (MyElt a, MyTrace key a)
  }
  deriving Eq

-- Displayed as a plain pair of the unwrapped element and the
-- unwrapped trace list.
instance (Show key, Show a) => Show (MyTraced key a) where
  show traced = show (x, entries)
    where
      (MyElt x, MyTrace entries) = myTraced traced
-- An element paired with a trace and a weight:
-- f(A) = A x Trace x R^+.
newtype MyWTraced key a = MyWTraced
  { myWTraced :: (MyElt a, MyTrace key a, Double)
  }
  deriving Eq

-- Displayed as a plain triple of the unwrapped element, the unwrapped
-- trace list and the weight.
instance (Show key, Show a) => Show (MyWTraced key a) where
  show wtraced = show (x, entries, w)
    where
      (MyElt x, MyTrace entries, w) = myWTraced wtraced
-- The list-based implementation of the `Trace` interface: traces are
-- association lists, appended by list concatenation, and the traced
-- and weighted-traced types are newtype-wrapped tuples.
instance Trace (MyTrace key) key MyElt (MyTraced key) (MyWTraced key) where
  getTrace (MyTrace entries) = entries
  emptyTrace = MyTrace []
  kvTrace k v = MyTrace [(k, v)]
  appendTrace (MyTrace xs) (MyTrace ys) = MyTrace (xs ++ ys)
  getTraced (MyTraced pair) = pair
  makeTraced pair = MyTraced pair
  getWTraced (MyWTraced triple) = triple
  makeWTraced triple = MyWTraced triple
-- Example/computation-related things
-- The two-element set the examples compute over.
data MySet
  = Tails
  | Heads
  deriving (Show, Eq)

-- Swap the two elements of `MySet`.
myNot :: MySet -> MySet
myNot side = case side of
  Tails -> Heads
  Heads -> Tails
-- The constant program that returns `Tails`, for any key and
-- distribution type.
input :: GenFn key distr MyElt MySet
input = Ret (MyElt Tails)
-- Like `input`, but expressed as a `Sample` at key 0: the
-- distribution is the `dirac` on `Tails`, and the score is 1 on
-- `Tails` and 0 otherwise.
input' :: Distr distr => GenFn Int distr MyElt MySet
input' = Sample (0 :: Int) (dirac (MyElt Tails)) score
  where
    score (MyElt x)
      | x == Tails = 1.0
      | otherwise = 0.0
-- The "drunken not" as a weighted list: the negation of `x` with
-- weight 0.9, and `x` itself with weight 0.1.
drunkenNotList :: (Fractional t) => MySet -> [(MyElt MySet, t)]
drunkenNotList x =
  [ (MyElt (myNot x), 0.9)
  , (MyElt x, 0.1)
  ]
-- The score of outcome `y` for the "drunken not" applied to `x`:
-- 0.1 if `y` equals `x`, 0.9 if `y` is the negation of `x`.
drunkenNotScore :: MySet -> MyElt MySet -> Double
drunkenNotScore x (MyElt y)
  | y == x = 0.1
  | y == myNot x = 0.9
  | otherwise = 0.0 -- Not reachable but syntactically required
-- Build the "drunken not" sampling step at key `k` for input element
-- `x`, given a function `d` producing the distribution for each
-- input.  The score is `drunkenNotScore x`.
drunkenNot ::
  (MySet -> distr MyElt MySet) ->
  key ->
  MyElt MySet ->
  GenFn key distr MyElt MySet
drunkenNot d k wrapped = Sample k (d x) (drunkenNotScore x)
  where
    x = myElt wrapped
-- An input trace that observes `Heads` at key 0.
tObs :: MyTrace Int MySet
tObs = MyTrace [(0, Observe (MyElt Heads))]

-- An input trace that intervenes with `Heads` at key 0.
tInt :: MyTrace Int MySet
tInt = MyTrace [(0, Intervene (MyElt Heads))]
--
-- Example 1:
-- Measure-theoretically, tracking point masses in the support of the
-- distribution
--
-- Merge entries with equal first components by adding their weights.
-- Each distinct value appears once in the result, at the position of
-- its first occurrence in the input.
squashDiracs :: Eq a => [(a, Double)] -> [(a, Double)]
squashDiracs = foldr absorb []
  where
    -- `acc` is the already-squashed tail, so it holds at most one
    -- entry for `x`.
    absorb (x, v) acc =
      case lookup x acc of
        Nothing -> (x, v) : acc
        Just w -> (x, v + w) : filter (\(y, _) -> y /= x) acc
-- The latent monad in this example is similar to the common
-- list/"nondeterminism" monad, but here the list values are distinct,
-- and each value is paired with a weight in R^+. Whereas the list
-- monad's bind fans out over all possibilities and concatenates, this
-- monad's bind fans out, multiplying weights as it goes, and when it
-- concatenates it collects duplicate values while adding their
-- weights (the effect of `squashDiracs`).
-- The reason we do not define a monad instance is that
-- `squashDiracs`, and therefore the bind operation, requires the
-- condition `Eq (elt a)`, which monads do not let us impose.
-- A distribution represented as a list of elements paired with
-- weights.
newtype Diracs elt a = Diracs
  { diracs :: [(elt a, Double)]
  }

-- Displayed as the word "Diracs" followed by one (element, weight)
-- pair per line.
instance Show (elt a) => Show (Diracs elt a) where
  show (Diracs entries) = "Diracs" ++ concatMap showEntry entries
    where
      showEntry entry = "\n " ++ show entry
-- Weighted point masses as a `Distr`:
--   * `pushForward` maps the function over the elements and keeps
--     the weights,
--   * `dirac` is a single element with weight 1,
--   * `mixture` fans out over every entry of the first distribution,
--     multiplies the weights, and merges duplicates with
--     `squashDiracs`.
instance Distr Diracs where
  pushForward f (Diracs entries) = Diracs [(f x, u) | (x, u) <- entries]
  dirac x = Diracs [(x, 1.0)]
  mixture d1 d2 =
    Diracs
      ( squashDiracs
          [ (y, u * v)
          | (x, u) <- diracs d1
          , (y, v) <- diracs (d2 x)
          ]
      )
-- `tracing`, `infer`, `input` and `input'` specialised to Int keys
-- and the `Diracs` distribution.
tracing1 ::
  GenFn Int Diracs MyElt MySet ->
  GenFn Int Diracs (MyTraced Int) MySet
tracing1 = tracing

infer1 ::
  MyTrace Int MySet ->
  GenFn Int Diracs MyElt MySet ->
  GenFn Int Diracs (MyWTraced Int) MySet
infer1 = infer

input1 :: GenFn Int Diracs MyElt MySet
input1 = input

input1' :: GenFn Int Diracs MyElt MySet
input1' = input'

-- The "drunken not" step whose distribution is the weighted list
-- from `drunkenNotList`.
drunkenNot1 :: key -> MyElt MySet -> GenFn key Diracs MyElt MySet
drunkenNot1 = drunkenNot (Diracs . drunkenNotList)

-- The input followed by two "drunken not" steps, at keys 1 and 2.
computed1 :: GenFn Int Diracs MyElt MySet
computed1 = (input1 `Semicolon` drunkenNot1 1) `Semicolon` drunkenNot1 2

computed1' :: GenFn Int Diracs MyElt MySet
computed1' = (input1' `Semicolon` drunkenNot1 1) `Semicolon` drunkenNot1 2
-- try:
-- > runGen computed1
-- > runGen computed1'
-- > runGen $ tracing1 computed1
-- > runGen $ tracing1 computed1'
-- > runGen $ infer1 tObs computed1'
-- ...
--
-- Example 2:
-- Executing deterministic sampling procedures. This is a conceptual
-- warm-up for Example 3 below.
--
-- Sampling procedures are functions with one value, `() -> elt a`,
-- otherwise written `(->) () (elt a)`:
type DSampler = MDistr ((->) ())

-- Displayed by running the procedure on `()` and showing the result
-- after the prefix "() -> ".
instance Show (elt a) => Show (DSampler elt a) where
  show (MDistr run) = "() -> " ++ show (run ())
-- `tracing`, `infer`, `input` and `input'` specialised to Int keys
-- and the `DSampler` distribution.
tracing2 ::
  GenFn Int DSampler MyElt MySet ->
  GenFn Int DSampler (MyTraced Int) MySet
tracing2 = tracing

infer2 ::
  MyTrace Int MySet ->
  GenFn Int DSampler MyElt MySet ->
  GenFn Int DSampler (MyWTraced Int) MySet
infer2 = infer

input2 :: GenFn Int DSampler MyElt MySet
input2 = input

input2' :: GenFn Int DSampler MyElt MySet
input2' = input'

-- This is especially not stochastic: the procedure always returns
-- the negation of its input.
drunkenNot2 :: key -> MyElt MySet -> GenFn key DSampler MyElt MySet
drunkenNot2 = drunkenNot (\x -> MDistr (const (MyElt (myNot x))))

-- The input followed by two "drunken not" steps, at keys 1 and 2.
computed2 :: GenFn Int DSampler MyElt MySet
computed2 = (input2 `Semicolon` drunkenNot2 1) `Semicolon` drunkenNot2 2

computed2' :: GenFn Int DSampler MyElt MySet
computed2' = (input2' `Semicolon` drunkenNot2 1) `Semicolon` drunkenNot2 2
-- try:
-- > runGen computed2
-- > runGen computed2'
-- > runGen $ tracing2 computed2
-- > runGen $ tracing2 computed2'
-- > runGen $ infer2 tObs computed2'
-- ...
--
-- Example 3:
-- Executing randomized/stochastic sampling procedures.
--
-- The Rand monad carries along the state of a pseudorandom number
-- generator seed for us:
type RSampler g = MDistr (Rand g)

-- This plays the role of the Show instance: run the sampler in IO
-- via `evalRandIO` and return the sampled element.
rsample :: RSampler StdGen elt a -> IO (elt a)
rsample (MDistr sampler) = evalRandIO sampler
-- `tracing`, `infer`, `input` and `input'` specialised to Int keys
-- and the `RSampler StdGen` distribution.
tracing3 ::
  GenFn Int (RSampler StdGen) MyElt MySet ->
  GenFn Int (RSampler StdGen) (MyTraced Int) MySet
tracing3 = tracing

infer3 ::
  MyTrace Int MySet ->
  GenFn Int (RSampler StdGen) MyElt MySet ->
  GenFn Int (RSampler StdGen) (MyWTraced Int) MySet
infer3 = infer

input3 :: GenFn Int (RSampler StdGen) MyElt MySet
input3 = input

input3' :: GenFn Int (RSampler StdGen) MyElt MySet
input3' = input'

-- The "drunken not" step that draws from the weighted list in
-- `drunkenNotList` using `fromList`.
drunkenNot3 ::
  key -> MyElt MySet -> GenFn key (RSampler StdGen) MyElt MySet
drunkenNot3 = drunkenNot (MDistr . fromList . drunkenNotList)

-- The input followed by two "drunken not" steps, at keys 1 and 2.
computed3 :: GenFn Int (RSampler StdGen) MyElt MySet
computed3 = (input3 `Semicolon` drunkenNot3 1) `Semicolon` drunkenNot3 2

computed3' :: GenFn Int (RSampler StdGen) MyElt MySet
computed3' = (input3' `Semicolon` drunkenNot3 1) `Semicolon` drunkenNot3 2
-- Draw `n` samples from `computed3` and print the fraction of them
-- that came out as `Tails`.
test3 :: Int -> IO ()
test3 n = do
  flips <- sequence (replicate n (rsample (runGen computed3)))
  let tailsCount = length (filter (== MyElt Tails) flips)
  -- Output will be ~0.82:
  print (fromIntegral tailsCount / fromIntegral n)
-- try:
-- > rsample $ runGen computed3
-- > rsample $ runGen computed3'
-- > rsample . runGen $ tracing3 computed3
-- > rsample . runGen $ tracing3 computed3'
-- > rsample . runGen $ infer3 tObs computed3'
-- ...
-- > test3 100