Commit 22e3d64: update vae straight through
DorinDaniil committed Nov 5, 2024 (1 parent: 0bea1e0)
Showing 3 changed files with 70 additions and 13 deletions.
demo/vae_straight_through_bernoulli.py (22 changes: 10 additions & 12 deletions)
@@ -62,32 +62,30 @@ def decode(self, z):
         return torch.sigmoid(self.fc4(h3))

     def forward(self, x, hard=False):
-        a = self.encode(x.view(-1, 784))
+        a = self.encode(x.view(-1, 784)).float()
         q_z = StraightThroughBernoulli(a)
-        z = q_z.rsample()  # sample with reparameterization
+        z = q_z.rsample().float()  # sample with reparameterization

         if hard:
             # No step function in torch, so using sign instead
             z_hard = 0.5 * (torch.sign(z) + 1)
             z = z + (z_hard - z).detach()

-        return self.decode(z), a
+        return self.decode(z), z


 model = VAE().to(device)
 optimizer = optim.Adam(model.parameters(), lr=1e-3)


 # Reconstruction + KL divergence losses summed over all elements and batch
-def loss_function(recon_x, x, a, prior=0.5, eps=1e-10):
+def loss_function(recon_x, x, q_z, prior=0.5, eps=1e-10):
     BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
     # You can also compute p(x|z) as below; for binary output it reduces
     # to binary cross-entropy error, for Gaussian output it reduces to
     # mean squared error

-    # p = probability of sampling 1 in straight through
-    p = torch.nn.functional.sigmoid(a)
-    t1 = p * ((p + eps) / prior).log()
-    t2 = (1 - p) * ((1 - p + eps) / (1 - prior)).log()
+    t1 = q_z * ((q_z + eps) / prior).log()
+    t2 = (1 - q_z) * ((1 - q_z + eps) / (1 - prior)).log()
     KLD = torch.sum(t1 + t2, dim=-1).sum()

     return BCE + KLD
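The `hard=True` branch above implements the straight-through estimator: the forward pass emits the hard 0/1 sample, while gradients flow through the relaxed `z` unchanged. A minimal standalone check of that trick in plain PyTorch (tensor values chosen arbitrarily):

```python
import torch

z = torch.tensor([-0.3, 0.2, 0.8], requires_grad=True)

# No step function in torch, so use sign, exactly as in forward(hard=True)
z_hard = 0.5 * (torch.sign(z) + 1)

# Straight-through: forward value is z_hard, gradient is that of z
z_st = z + (z_hard - z).detach()
print(z_st)            # tensor([0., 1., 1.], grad_fn=<AddBackward0>)

z_st.sum().backward()
print(z.grad)          # tensor([1., 1., 1.]), identity gradient through the hard sample
```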
Expand All @@ -100,8 +98,8 @@ def train(epoch):
     for batch_idx, (data, _) in enumerate(train_loader):
         data = data.to(device)
         optimizer.zero_grad()
-        recon_batch, a = model(data)
-        loss = loss_function(recon_batch, data, a)
+        recon_batch, z = model(data)
+        loss = loss_function(recon_batch, data, z)
         loss.backward()
         train_loss += loss.item()
         optimizer.step()
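The `t1 + t2` term in `loss_function` is the closed-form KL divergence KL(Bernoulli(q) || Bernoulli(prior)) that this training loop minimizes alongside the reconstruction error. A quick sanity check against `torch.distributions.kl_divergence`, under the assumption that the values passed in are probabilities in [0, 1]:

```python
import torch
from torch.distributions import Bernoulli, kl_divergence

q_z = torch.tensor([0.1, 0.5, 0.9])   # assumed Bernoulli probabilities
prior, eps = 0.5, 1e-10

# Closed form used in loss_function
t1 = q_z * ((q_z + eps) / prior).log()
t2 = (1 - q_z) * ((1 - q_z + eps) / (1 - prior)).log()
manual = t1 + t2

# Reference value from torch
ref = kl_divergence(Bernoulli(probs=q_z), Bernoulli(probs=torch.full_like(q_z, prior)))
print(torch.allclose(manual, ref, atol=1e-6))  # True
```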
@@ -124,8 +122,8 @@ def test(epoch):
     with torch.no_grad():
         for i, (data, _) in enumerate(test_loader):
             data = data.to(device)
-            recon_batch, a = model(data)
-            test_loss += loss_function(recon_batch, data, a).item()
+            recon_batch, z = model(data)
+            test_loss += loss_function(recon_batch, data, z).item()
             if i == 0:
                 n = min(data.size(0), 8)
                 comparison = torch.cat([data[:n],
demo/visualization.ipynb (59 changes: 59 additions & 0 deletions)

Large diffs are not rendered by default.

src/relaxit/distributions/StraightThroughBernoulli.py (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@ def __init__(self, a: torch.Tensor, validate_args: bool = None):
"""

self.a = a.float() # Ensure a is a float tensor
self.uniform = torch.distributions.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
self.uniform = torch.distributions.Uniform(torch.tensor([0.0], device=self.a.device), torch.tensor([1.0], device=self.a.device))
super().__init__(validate_args=validate_args)

@property
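Building both endpoints of the `Uniform` on `self.a.device` keeps the noise on the same device as the logits, which otherwise raises a device mismatch once the model is moved to GPU. For intuition, a hypothetical sketch of how a straight-through Bernoulli sample can be drawn from such uniform noise; the function name and the logits-to-probability convention `p = sigmoid(a)` are assumptions here, and relaxit's actual `rsample` may differ:

```python
import torch

def st_bernoulli_rsample(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical straight-through Bernoulli sample from logits a."""
    p = torch.sigmoid(a)
    u = torch.rand_like(p)      # uniform noise, same device and shape as a
    hard = (u < p).float()      # non-differentiable 0/1 sample
    # Forward value is hard; the gradient is that of p (straight-through)
    return hard + p - p.detach()
```

With this convention the returned samples are exactly 0 or 1, while the gradient of a sample with respect to `a` equals the sigmoid's derivative.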
