From 987b2c82101b18504483519eea7ad71cd632b7c6 Mon Sep 17 00:00:00 2001 From: Sergey O Date: Mon, 24 Apr 2023 20:36:47 -0400 Subject: [PATCH] adding plddt coloring to output pdb(s) --- config/inference/base.yaml | 1 + run_inference.py | 27 +++++++++++++++++---------- util.py | 9 ++++++--- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/config/inference/base.yaml b/config/inference/base.yaml index da1bcd7..4ad9da9 100644 --- a/config/inference/base.yaml +++ b/config/inference/base.yaml @@ -20,6 +20,7 @@ inference: deterministic: False trb_save_ckpt_path: null dump_pdb: False + dump_pdb_path: "/tmp" contigmap: contigs: null diff --git a/run_inference.py b/run_inference.py index df96119..72b6cbf 100755 --- a/run_inference.py +++ b/run_inference.py @@ -82,7 +82,7 @@ def main(conf: HydraConfig) -> None: x_t = torch.clone(x_init) seq_t = torch.clone(seq_init) # Loop over number of reverse diffusion time steps. - for t in range(int(sampler.t_step_input), sampler.inf_conf.final_step - 1, -1): + for num,t in enumerate(range(int(sampler.t_step_input), sampler.inf_conf.final_step - 1, -1)): px0, x_t, seq_t, plddt = sampler.sample_step( t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step ) @@ -92,15 +92,16 @@ def main(conf: HydraConfig) -> None: plddt_stack.append(plddt[0]) # remove singleton leading dimension if conf.inference.dump_pdb: + bfacts = plddt.cpu().numpy()[0] protein = px0.cpu().numpy()[:,:4] - pdb_str = [] - ctr = 0 + pdb_str,line_num = "",0 for n,residue in enumerate(protein): for xyz,atom in zip(residue,[" N ", " CA ", " C ", " O "]): - pdb_str.append("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f" - % ("ATOM",ctr,atom,"GLY","A",n,*xyz,1,0)) - ctr += 1 - print(":".join(pdb_str)) + pdb_str += "%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n" % ("ATOM",line_num,atom,"GLY","A",n,*xyz,1,bfacts[n]) + line_num += 1 + dump_pdb_path = os.path.join(conf.inference.dump_pdb_path,f"{num}.pdb") + with open(dump_pdb_path,"w") as handle: + handle.write(pdb_str + "TER") # Flip order for better visualization in pymol @@ -121,6 +122,12 @@ def main(conf: HydraConfig) -> None: # For logging -- don't flip plddt_stack = torch.stack(plddt_stack) + bfact_stack = torch.flip( + plddt_stack, + [ + 0, + ], + ) # Save outputs os.makedirs(os.path.dirname(out_prefix), exist_ok=True) @@ -144,7 +151,7 @@ def main(conf: HydraConfig) -> None: final_seq, sampler.binderlen, chain_idx=sampler.chain_idx, - bfacts=bfacts, + bfacts=bfact_stack[0], ) # run metadata @@ -173,7 +180,7 @@ def main(conf: HydraConfig) -> None: writepdb_multi( out, denoised_xyz_stack, - bfacts, + bfact_stack, final_seq.squeeze(), use_hydrogens=False, backbone_only=False, @@ -184,7 +191,7 @@ def main(conf: HydraConfig) -> None: writepdb_multi( out, px0_xyz_stack, - bfacts, + bfact_stack, final_seq.squeeze(), use_hydrogens=False, backbone_only=False, diff --git a/util.py b/util.py index b1f8ca4..53688e2 100644 --- a/util.py +++ b/util.py @@ -679,11 +679,14 @@ def writepdb_multi( if seq_stack.ndim != 2: T = atoms_stack.shape[0] seq_stack = torch.tile(seq_stack, (T, 1)) + if bfacts.ndim != 2: + T = atoms_stack.shape[0] + bfacts = torch.tile(bfacts, (T, 1)) seq_stack = seq_stack.cpu() - for atoms, scpu in zip(atoms_stack, seq_stack): + for atoms, scpu, bfact in zip(atoms_stack, seq_stack, bfacts): ctr = 1 atomscpu = atoms.cpu() - Bfacts = torch.clamp(bfacts.cpu(), 0, 1) + B = torch.clamp(bfact.cpu(), 0, 1) for i, s in enumerate(scpu): atms = aa2long[s] for j, atm_j in enumerate(atms): @@ -709,7 +712,7 @@ def writepdb_multi( atomscpu[i, j, 1], atomscpu[i, j, 2], 1.0, - Bfacts[i], + B[i], ) ) ctr += 1