Hi, I have a question regarding the posterior sampling code.

In my understanding, in the posterior sampling process (during both training and generation)

q(x_{t-1} | x_t, x_0)

x_0 is the predicted type (provided as log-probabilities during generation), and x_t is the current generated state, i.e. a concrete type sampled from the probabilities predicted at the previous step.

However, it seems that the MolDiff code takes x_t to be the probability vector obtained from the previous generation step, rather than a sampled one-hot type.

Could you please clarify whether there are processing steps here that I may have misunderstood?

Thank you.
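For concreteness, here is a minimal sketch of what I would have expected, assuming a D3PM/multinomial-diffusion-style setup. The helper `log_onehot` and the variables `vt_index` / `num_classes` are my own hypothetical names, not MolDiff's:

```python
import torch
import torch.nn.functional as F

def log_onehot(vt_index: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Log of a one-hot vector; clamping avoids log(0) = -inf downstream.
    onehot = F.one_hot(vt_index, num_classes).float()
    return torch.log(onehot.clamp(min=1e-30))

# What I expected to be fed into the posterior as x_t:
#   log_vt = log_onehot(sampled_index, num_classes)   # a hard, sampled type
# What the code below appears to feed in instead:
#   log_vt = the full log-probability vector from the previous step
```

The code I am referring to is the following, from the sampling loop: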
```python
log_node_type = self.node_transition.q_v_posterior(log_node_recon, log_node_type, time_step, batch_node, v0_prob=True)
node_type_prev = log_sample_categorical(log_node_type)
```
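If I understand correctly, `log_sample_categorical` draws a discrete sample via the Gumbel-max trick; a typical implementation looks like the sketch below (my own sketch, not necessarily MolDiff's exact code):

```python
import torch

def log_sample_categorical(log_prob: torch.Tensor) -> torch.Tensor:
    # Gumbel-max trick: adding Gumbel noise to log-probabilities and taking
    # the argmax yields an exact sample from the categorical distribution.
    uniform = torch.rand_like(log_prob)
    gumbel = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
    return (gumbel + log_prob).argmax(dim=-1)
```

So `node_type_prev` is a discrete sample, yet it is `log_node_type` (the full posterior log-probabilities) that is passed back into `q_v_posterior` as `log_vt` at the next step.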
The posterior function itself begins:

```python
def q_v_posterior(self, log_v0, log_vt, t, batch, v0_prob):
    # q(v_{t-1} | v_t, v_0) = q(v_t | v_{t-1}, v_0) * q(v_{t-1} | v_0) / q(v_t | v_0)
    t_minus_1 = t - 1
    # Clamp negative values; t-1 < 0 is never actually used at the final decoding step.
    t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
```
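Based on the comment in `q_v_posterior`, I assume the body computes the standard categorical posterior in log-space. Below is my reconstruction under the assumption of a uniform transition kernel with cached `log_alpha_t` and `log_alpha_bar_tm1` schedules; all names here are hypothetical and only meant to make the question concrete:

```python
import math
import torch

def log_add_exp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Numerically stable log(exp(a) + exp(b)).
    m = torch.max(a, b)
    return m + torch.log(torch.exp(a - m) + torch.exp(b - m))

def q_v_posterior_sketch(log_v0, log_vt, log_alpha_t, log_alpha_bar_tm1, num_classes):
    # Unnormalized log q(v_{t-1} | v_t, v_0)
    #   = log q(v_t | v_{t-1}) + log q(v_{t-1} | v_0),
    # with a uniform kernel q(v_t | v_{t-1}) = alpha_t * v_{t-1} + (1 - alpha_t) / K.
    log_K = math.log(num_classes)
    log_qvt_given_vtm1 = log_add_exp(
        log_vt + log_alpha_t,
        torch.log1p(-log_alpha_t.exp() + 1e-40) - log_K,  # log((1 - alpha_t) / K)
    )
    log_qvtm1_given_v0 = log_add_exp(
        log_v0 + log_alpha_bar_tm1,
        torch.log1p(-log_alpha_bar_tm1.exp() + 1e-40) - log_K,
    )
    unnormed = log_qvt_given_vtm1 + log_qvtm1_given_v0
    # Normalize over the class dimension to get valid log-probabilities.
    return unnormed - torch.logsumexp(unnormed, dim=-1, keepdim=True)
```

With this form, my question reduces to: should `log_vt` be the log one-hot of the sampled `node_type_prev`, or the full `log_node_type` probabilities carried over from the previous step, as the current code does? (I assume `v0_prob=True` means `log_v0` is already a predicted distribution rather than a one-hot, which is why it can be used directly.)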