Skip to content

Commit

Permalink
fix symmetry
Browse files Browse the repository at this point in the history
disable centering when symmetry is defined
  • Loading branch information
sokrypton committed Mar 31, 2023
1 parent 642e364 commit c042c0e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
7 changes: 6 additions & 1 deletion inference/model_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def sample_init(self, return_forward_trajectory=False):
xyz_motif_prealign = xyz_mapped.clone()
motif_prealign_com = xyz_motif_prealign[0,0,:,1].mean(dim=0)
self.motif_com = xyz_27[contig_map.ref_idx0,1].mean(dim=0)
xyz_mapped = get_init_xyz(xyz_mapped).squeeze()
xyz_mapped = get_init_xyz(xyz_mapped, center=self.symmetry is None).squeeze()
# adjust the size of the input atom map
atom_mask_mapped = torch.full((L_mapped, 27), False)
atom_mask_mapped[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0]
Expand Down Expand Up @@ -352,6 +352,11 @@ def sample_init(self, return_forward_trajectory=False):
seq_t = torch.nn.functional.one_hot(seq_t, num_classes=22).float() # [L,22]
seq_orig = torch.nn.functional.one_hot(seq_orig, num_classes=22).float() # [L,22]

############################################################################
if self.symmetry is not None:
xyz_mapped, seq_t = self.symmetry.apply_symmetry(xyz_mapped, seq_t)
############################################################################

fa_stack, xyz_true = self.diffuser.diffuse_pose(
xyz_mapped,
torch.clone(seq_t),
Expand Down
8 changes: 5 additions & 3 deletions kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def c6d_to_bins2(c6d, same_chain, negative=False, params=PARAMS):

return torch.stack([db,ob,tb,pb],axis=-1).long()

def get_init_xyz(xyz_t):
def get_init_xyz(xyz_t, center=True):
# input: xyz_t (B, T, L, 14, 3)
# ouput: xyz (B, T, L, 14, 3)
B, T, L = xyz_t.shape[:3]
Expand All @@ -292,8 +292,10 @@ def get_init_xyz(xyz_t):

mask = torch.isnan(xyz_t[:,:,:,:3]).any(dim=-1).any(dim=-1) # (B, T, L)
#
center_CA = ((~mask[:,:,:,None]) * torch.nan_to_num(xyz_t[:,:,:,1,:])).sum(dim=2) / ((~mask[:,:,:,None]).sum(dim=2)+1e-4) # (B, T, 3)
xyz_t = xyz_t - center_CA.view(B,T,1,1,3)

if center:
center_CA = ((~mask[:,:,:,None]) * torch.nan_to_num(xyz_t[:,:,:,1,:])).sum(dim=2) / ((~mask[:,:,:,None]).sum(dim=2)+1e-4) # (B, T, 3)
xyz_t = xyz_t - center_CA.view(B,T,1,1,3)
#
idx_s = list()
for i_b in range(B):
Expand Down

0 comments on commit c042c0e

Please sign in to comment.