posterior calculation #4

Open
oneoftwo opened this issue Feb 29, 2024 · 0 comments

@oneoftwo

Hi, I have a question regarding the posterior sampling code.
In my understanding, in the posterior sampling process (during both training and generation)

q(x_{t-1}|x_{t}, x_0)

x_0 is the predicted type (provided as a log-probability during generation),
and x_{t} is the current state, i.e. a type sampled from the probability predicted at the previous step.

However, the code in MolDiff seems to pass x_{t} as the probability vector obtained from the previous generation step, rather than a sampled type.
Could you please clarify whether there are any processing steps I may have misunderstood?
Thank you.

log_node_type = self.node_transition.q_v_posterior(log_node_recon, log_node_type, time_step, batch_node, v0_prob=True)
node_type_prev = log_sample_categorical(log_node_type)
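
For reference, the posterior computed in q_v_posterior below appears to follow the standard D3PM-style factorization (same notation as above; Q_t is the one-step transition matrix, Qbar_{t-1} the cumulative transition matrix up to step t-1):

q(x_{t-1} | x_{t}, x_0) ∝ (x_{t} · Q_t^T) ⊙ (x_0 · Qbar_{t-1})

fact1 and fact2 in the code are these two factors, so the question is whether the x_{t} entering fact1 should be a sampled one-hot vector or the full probability vector carried over from the previous step.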

def q_v_posterior(self, log_v0, log_vt, t, batch, v0_prob):
    # q(v_{t-1} | v_t, v_0) = q(v_t | v_{t-1}, v_0) * q(v_{t-1} | v_0) / q(v_t | v_0)
    t_minus_1 = t - 1
    t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)  # clamp negative indices; the t == 0 case is overwritten below anyway

    fact1 = extract(self.transpopse_q_onestep_mats, t, batch, ndim=1)
    # class_vt = log_vt.argmax(dim=-1)
    # fact1 = fact1[torch.arange(len(class_vt)), class_vt]
    fact1 = torch.einsum('bj,bjk->bk', torch.exp(log_vt), fact1)  # (batch, N)
    
    if not v0_prob:  # log_v0 is directly transformed to onehot
        fact2 = extract(self.q_mats, t_minus_1, batch, ndim=1)
        class_v0 = log_v0.argmax(dim=-1)
        fact2 = fact2[torch.arange(len(class_v0)), class_v0]
    else:  # log_v0 contains the probability information
        fact2 = extract(self.q_mats, t_minus_1, batch, ndim=1)  # (batch, N, N)
        fact2 = torch.einsum('bj,bjk->bk', torch.exp(log_v0), fact2)  # (batch, N)
    
    ndim = log_v0.ndim
    if ndim == 2:
        t_expand = t[batch].unsqueeze(-1)
    elif ndim == 3:
        t_expand = t[batch].unsqueeze(-1).unsqueeze(-1)
    else:
        raise NotImplementedError('ndim not supported')
    
    out = torch.log(fact1 + self.eps).clamp_min(-32.) + torch.log(fact2 + self.eps).clamp_min(-32.)
    out = out - torch.logsumexp(out, dim=-1, keepdim=True)
    out_t0 = log_v0
    out = torch.where(t_expand == 0, out_t0, out)
    return out
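
A minimal, self-contained sketch (not MolDiff code; the transition matrices and probabilities below are made up) contrasting the two choices of x_t in this posterior:

# Toy sketch: compare a probability-valued x_t against a sampled one-hot x_t
# in the categorical posterior q(x_{t-1} | x_t, x_0).
import torch
import torch.nn.functional as F

K = 4                                             # toy number of node types
Qt = torch.full((K, K), 0.1 / (K - 1))            # one-step transition matrix
Qt.fill_diagonal_(0.9)
Qbar_tm1 = torch.linalg.matrix_power(Qt, 5)       # toy cumulative transition up to t-1

x0_prob = torch.tensor([0.70, 0.20, 0.05, 0.05])  # predicted x_0 distribution
xt_prob = torch.tensor([0.50, 0.30, 0.10, 0.10])  # probability carried over from the previous step
xt_onehot = F.one_hot(torch.multinomial(xt_prob, 1).squeeze(), K).float()  # a sampled x_t

def posterior(xt, x0):
    # q(x_{t-1} | x_t, x_0) is proportional to (x_t @ Qt.T) * (x_0 @ Qbar_{t-1})
    unnorm = (xt @ Qt.T) * (x0 @ Qbar_tm1)
    return unnorm / unnorm.sum()

print(posterior(xt_prob, x0_prob))    # x_t as a probability vector (what the quoted code passes)
print(posterior(xt_onehot, x0_prob))  # x_t as a sampled one-hot (what the question expects)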