-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
87 lines (65 loc) · 2.71 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import equinox as eqx
import jax
import jax.numpy as jnp
# Side length, in pixels, of the square input images (MNIST digits are 28x28).
# The models in this module size their linear layers (and mnist_unet its
# reshape back to a feature map) from this value.
image_dim: int = 28
class mnist_model(eqx.Module):
    """Small convolutional classifier for a single (unbatched) square image.

    Architecture: two 3x3 same-padded convolutions (1 -> 32 -> 64 channels),
    a 2x2 max-pool, then two linear layers (-> 128 -> ``num_classes``), with a
    ``log_softmax`` on the output.
    """

    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key, num_classes=10):
        """Initialise all layers from a single PRNG key.

        Args:
            key: JAX PRNG key; split four ways, one sub-key per layer.
            num_classes: Width of the output layer. Defaults to 10 (the MNIST
                digits), which reproduces the original architecture exactly.
        """
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # 3x3 kernels with stride 1 and padding=1 keep the spatial size at
        # image_dim x image_dim.
        self.conv1 = eqx.nn.Conv2d(1, 32, 3, 1, padding=1, key=key1)
        self.conv2 = eqx.nn.Conv2d(32, 64, 3, 1, padding=1, key=key2)
        # The MaxPool2d(kernel_size=2) in __call__ uses Equinox's default
        # stride of 1, so each spatial side shrinks by one: the flattened
        # feature vector has 64 * (image_dim - 1) ** 2 entries.
        self.linear1 = eqx.nn.Linear(64 * (image_dim - 1) ** 2, 128, key=key3)
        self.linear2 = eqx.nn.Linear(128, num_classes, key=key4)

    # NOTE(review): Equinox generally recommends eqx.filter_jit for Module
    # methods; plain jax.jit is kept here for behavioural parity (self is
    # passed through as a traced pytree argument).
    @jax.jit
    def __call__(self, x):
        """Return class log-probabilities for one image.

        Args:
            x: Array of shape ``(1, image_dim, image_dim)`` -- channels first,
                no batch axis (the whole result is flattened before the linear
                layers). Map over a batch with ``jax.vmap``.

        Returns:
            Array of shape ``(num_classes,)`` holding ``log_softmax`` scores.
        """
        x = jax.nn.relu(self.conv1(x))
        x = jax.nn.relu(self.conv2(x))
        x = eqx.nn.MaxPool2d(kernel_size=2)(x)
        x = jnp.ravel(x)
        x = jax.nn.relu(self.linear1(x))
        x = self.linear2(x)
        return jax.nn.log_softmax(x)
class mnist_unet(eqx.Module):
    """Label- and timestep-conditioned encoder/decoder for a single image.

    Encoder: three 3x3 same-padded convolutions (1 -> 16 -> 32 -> 64
    channels), a 2x2 max-pool and a linear layer down to a 128-d bottleneck.
    Learned label and timestep embeddings are added at that bottleneck.
    Decoder: a linear layer back to a 64-channel feature map, concatenated
    with the last encoder feature map (skip connection), then three
    transposed convolutions (128 -> 64 -> 32 -> 1 channels) and a sigmoid.
    """

    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    conv3: eqx.nn.Conv2d
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    label_embedding: eqx.nn.Embedding
    time_embedding: eqx.nn.Embedding
    conv4: eqx.nn.ConvTranspose2d
    conv5: eqx.nn.ConvTranspose2d
    conv6: eqx.nn.ConvTranspose2d

    def __init__(self, key, num_classes=10, num_timesteps=50):
        """Initialise every layer from a single PRNG key.

        Args:
            key: JAX PRNG key; split ten ways, one sub-key per layer.
            num_classes: Number of rows in the label embedding table. Defaults
                to 10 (the MNIST digits), matching the original model.
            num_timesteps: Number of rows in the timestep embedding table, so
                valid ``time`` values are ``0 .. num_timesteps - 1``. Defaults
                to 50, matching the original model.
        """
        (key1, key2, key3, key4, key5,
         key6, key7, key8, key9, key10) = jax.random.split(key, 10)
        # Encoder: 3x3 kernels, stride 1, padding=1 keep the spatial size.
        self.conv1 = eqx.nn.Conv2d(1, 16, 3, 1, padding=1, key=key1)
        self.conv2 = eqx.nn.Conv2d(16, 32, 3, 1, padding=1, key=key2)
        self.conv3 = eqx.nn.Conv2d(32, 64, 3, 1, padding=1, key=key3)
        # Bottleneck: the stride-1 2x2 max-pool in __call__ shrinks each side
        # by one, so the flattened encoder output has 64 * (image_dim - 1) ** 2
        # entries; linear2 expands back to a full 64-channel feature map.
        self.linear1 = eqx.nn.Linear(64 * (image_dim - 1) ** 2, 128, key=key4)
        self.linear2 = eqx.nn.Linear(128, 64 * image_dim**2, key=key5)
        # Conditioning embeddings share the 128-d bottleneck width so they can
        # be added directly to linear1's output.
        self.label_embedding = eqx.nn.Embedding(num_classes, 128, key=key6)
        self.time_embedding = eqx.nn.Embedding(num_timesteps, 128, key=key7)
        # Decoder: conv4 sees 64 projected + 64 skip channels = 128. With
        # stride 1, kernel 3 and padding 1, (H - 1) - 2 + 3 = H, so each
        # transposed convolution keeps the spatial size.
        self.conv4 = eqx.nn.ConvTranspose2d(128, 64, 3, 1, padding=1, key=key8)
        self.conv5 = eqx.nn.ConvTranspose2d(64, 32, 3, 1, padding=1, key=key9)
        self.conv6 = eqx.nn.ConvTranspose2d(32, 1, 3, 1, padding=1, key=key10)

    # NOTE(review): Equinox generally recommends eqx.filter_jit for Module
    # methods. Plain jax.jit is kept on purpose: it traces Python-int
    # ``label``/``time`` arguments, whereas filter_jit would treat them as
    # static and recompile for every distinct value.
    @jax.jit
    def __call__(self, x, label, time):
        """Map one image plus its conditioning to a same-sized output image.

        Args:
            x: Array of shape ``(1, image_dim, image_dim)`` -- channels first,
                no batch axis (the encoder output is flattened in full). Map
                over a batch with ``jax.vmap``.
            label: Integer index into ``label_embedding``.
            time: Integer index into ``time_embedding``; presumably the
                diffusion step -- TODO confirm against the training loop.
                Index ranges are not validated here.

        Returns:
            Array of shape ``(1, image_dim, image_dim)`` with values in (0, 1)
            produced by the final sigmoid.
        """
        label_encoding = self.label_embedding(label)
        time_encoding = self.time_embedding(time)
        # Encoder; x3 (64 x image_dim x image_dim) is kept for the skip path.
        x = jax.nn.relu(self.conv1(x))
        x = jax.nn.relu(self.conv2(x))
        x3 = jax.nn.relu(self.conv3(x))
        x = eqx.nn.MaxPool2d(kernel_size=2)(x3)
        x = jnp.ravel(x)
        # Conditioning is injected additively at the 128-d bottleneck.
        x = jax.nn.relu(self.linear1(x) + label_encoding + time_encoding)
        x = jax.nn.relu(self.linear2(x))
        x = jnp.reshape(x, (64, image_dim, image_dim))
        # Concatenate along axis 0 (channels): 64 + 64 = 128, matching conv4.
        x4_input = jnp.concatenate([x, x3])
        x = jax.nn.relu(self.conv4(x4_input))
        x = jax.nn.relu(self.conv5(x))
        return jax.nn.sigmoid(self.conv6(x))