Skip to content

Commit

Permalink
adding plddt coloring to output pdb(s)
Browse files Browse the repository at this point in the history
  • Loading branch information
sokrypton committed Apr 25, 2023
1 parent b1c9e9f commit 987b2c8
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
1 change: 1 addition & 0 deletions config/inference/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ inference:
deterministic: False
trb_save_ckpt_path: null
dump_pdb: False
dump_pdb_path: "/tmp"

contigmap:
contigs: null
Expand Down
27 changes: 17 additions & 10 deletions run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def main(conf: HydraConfig) -> None:
x_t = torch.clone(x_init)
seq_t = torch.clone(seq_init)
# Loop over number of reverse diffusion time steps.
for t in range(int(sampler.t_step_input), sampler.inf_conf.final_step - 1, -1):
for num,t in enumerate(range(int(sampler.t_step_input), sampler.inf_conf.final_step - 1, -1)):
px0, x_t, seq_t, plddt = sampler.sample_step(
t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step
)
Expand All @@ -92,15 +92,16 @@ def main(conf: HydraConfig) -> None:
plddt_stack.append(plddt[0]) # remove singleton leading dimension

if conf.inference.dump_pdb:
bfacts = plddt.cpu().numpy()[0]
protein = px0.cpu().numpy()[:,:4]
pdb_str = []
ctr = 0
pdb_str,line_num = "",0
for n,residue in enumerate(protein):
for xyz,atom in zip(residue,[" N ", " CA ", " C ", " O "]):
pdb_str.append("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f"
% ("ATOM",ctr,atom,"GLY","A",n,*xyz,1,0))
ctr += 1
print(":".join(pdb_str))
pdb_str += "%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n" % ("ATOM",line_num,atom,"GLY","A",n,*xyz,1,bfacts[n])
line_num += 1
dump_pdb_path = os.path.join(conf.inference.dump_pdb_path,f"{num}.pdb")
with open(dump_pdb_path,"w") as handle:
handle.write(pdb_str + "TER")


# Flip order for better visualization in pymol
Expand All @@ -121,6 +122,12 @@ def main(conf: HydraConfig) -> None:

# For logging -- don't flip
plddt_stack = torch.stack(plddt_stack)
bfact_stack = torch.flip(
plddt_stack,
[
0,
],
)

# Save outputs
os.makedirs(os.path.dirname(out_prefix), exist_ok=True)
Expand All @@ -144,7 +151,7 @@ def main(conf: HydraConfig) -> None:
final_seq,
sampler.binderlen,
chain_idx=sampler.chain_idx,
bfacts=bfacts,
bfacts=bfact_stack[0],
)

# run metadata
Expand Down Expand Up @@ -173,7 +180,7 @@ def main(conf: HydraConfig) -> None:
writepdb_multi(
out,
denoised_xyz_stack,
bfacts,
bfact_stack,
final_seq.squeeze(),
use_hydrogens=False,
backbone_only=False,
Expand All @@ -184,7 +191,7 @@ def main(conf: HydraConfig) -> None:
writepdb_multi(
out,
px0_xyz_stack,
bfacts,
bfact_stack,
final_seq.squeeze(),
use_hydrogens=False,
backbone_only=False,
Expand Down
9 changes: 6 additions & 3 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,11 +679,14 @@ def writepdb_multi(
if seq_stack.ndim != 2:
T = atoms_stack.shape[0]
seq_stack = torch.tile(seq_stack, (T, 1))
if bfacts.ndim != 2:
T = atoms_stack.shape[0]
bfacts = torch.tile(bfacts, (T, 1))
seq_stack = seq_stack.cpu()
for atoms, scpu in zip(atoms_stack, seq_stack):
for atoms, scpu, bfact in zip(atoms_stack, seq_stack, bfacts):
ctr = 1
atomscpu = atoms.cpu()
Bfacts = torch.clamp(bfacts.cpu(), 0, 1)
B = torch.clamp(bfact.cpu(), 0, 1)
for i, s in enumerate(scpu):
atms = aa2long[s]
for j, atm_j in enumerate(atms):
Expand All @@ -709,7 +712,7 @@ def writepdb_multi(
atomscpu[i, j, 1],
atomscpu[i, j, 2],
1.0,
Bfacts[i],
B[i],
)
)
ctr += 1
Expand Down

0 comments on commit 987b2c8

Please sign in to comment.