diff --git a/inference/model_runners.py b/inference/model_runners.py index a2e0f38..2b55323 100644 --- a/inference/model_runners.py +++ b/inference/model_runners.py @@ -324,7 +324,7 @@ def sample_init(self, return_forward_trajectory=False): xyz_motif_prealign = xyz_mapped.clone() motif_prealign_com = xyz_motif_prealign[0,0,:,1].mean(dim=0) self.motif_com = xyz_27[contig_map.ref_idx0,1].mean(dim=0) - xyz_mapped = get_init_xyz(xyz_mapped).squeeze() + xyz_mapped = get_init_xyz(xyz_mapped, center=self.symmetry is None).squeeze() # adjust the size of the input atom map atom_mask_mapped = torch.full((L_mapped, 27), False) atom_mask_mapped[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0] @@ -352,6 +352,11 @@ def sample_init(self, return_forward_trajectory=False): seq_t = torch.nn.functional.one_hot(seq_t, num_classes=22).float() # [L,22] seq_orig = torch.nn.functional.one_hot(seq_orig, num_classes=22).float() # [L,22] + ############################################################################ + if self.symmetry is not None: + xyz_mapped, seq_t = self.symmetry.apply_symmetry(xyz_mapped, seq_t) + ############################################################################ + fa_stack, xyz_true = self.diffuser.diffuse_pose( xyz_mapped, torch.clone(seq_t), diff --git a/kinematics.py b/kinematics.py index 1fbc015..76a6781 100644 --- a/kinematics.py +++ b/kinematics.py @@ -282,7 +282,7 @@ def c6d_to_bins2(c6d, same_chain, negative=False, params=PARAMS): return torch.stack([db,ob,tb,pb],axis=-1).long() -def get_init_xyz(xyz_t): +def get_init_xyz(xyz_t, center=True): # input: xyz_t (B, T, L, 14, 3) # ouput: xyz (B, T, L, 14, 3) B, T, L = xyz_t.shape[:3] @@ -292,8 +292,10 @@ def get_init_xyz(xyz_t): mask = torch.isnan(xyz_t[:,:,:,:3]).any(dim=-1).any(dim=-1) # (B, T, L) # - center_CA = ((~mask[:,:,:,None]) * torch.nan_to_num(xyz_t[:,:,:,1,:])).sum(dim=2) / ((~mask[:,:,:,None]).sum(dim=2)+1e-4) # (B, T, 3) - xyz_t = xyz_t - center_CA.view(B,T,1,1,3) + + if center: + center_CA = ((~mask[:,:,:,None]) * torch.nan_to_num(xyz_t[:,:,:,1,:])).sum(dim=2) / ((~mask[:,:,:,None]).sum(dim=2)+1e-4) # (B, T, 3) + xyz_t = xyz_t - center_CA.view(B,T,1,1,3) # idx_s = list() for i_b in range(B):