Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional rosetta repacking and fastrelax #91

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions docs/prediction.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ As an example, to predict a structure using 10 recycling steps and 25 samples (t
| `--msa_pairing_strategy` | str | `greedy` | Pairing strategy to use. Used only if --use_msa_server is set. Options are 'greedy' and 'complete' |
| `--write_full_pae` | `FLAG` | `False` | Whether to save the full PAE matrix as a file. |
| `--write_full_pde` | `FLAG` | `False` | Whether to save the full PDE matrix as a file. |
| `--rosetta_relax` | `FLAG` | `False` | Whether to perform Rosetta repacking and FastRelax. Installation of PyRosetta and a valid license are required. |
| `--relax_cores` | `INTEGER` | `8` | Number of cores for Rosetta relaxation. |

## Output

Expand Down
185 changes: 185 additions & 0 deletions src/boltz/data/write/rosetta_relax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from pyrosetta import *
from pyrosetta.rosetta.core.scoring import get_score_function
from pyrosetta.rosetta.protocols.minimization_packing import PackRotamersMover
from pyrosetta.rosetta.protocols.relax import FastRelax
from pyrosetta.rosetta.core.pack.task import TaskFactory

from pathlib import Path
import json
import pandas as pd
from tqdm.auto import tqdm

import sys
from contextlib import redirect_stdout, redirect_stderr


def ensure_pyrosetta_initialized():
    """Initialize PyRosetta exactly once per process.

    The first call runs `init` and marks the function object with an
    `_initialized` attribute as a sentinel; later calls are no-ops.
    """
    already_done = getattr(ensure_pyrosetta_initialized, "_initialized", False)
    if not already_done:
        init(silent=False, set_logging_handler="logging")
        ensure_pyrosetta_initialized._initialized = True


def get_rosetta_score(input_path):
    """Score a structure with Rosetta's default score function.

    Args:
        input_path (str): Path to the CIF/PDB file.

    Returns:
        float: Total Rosetta energy of the loaded pose.
    """
    ensure_pyrosetta_initialized()

    loaded_pose = pose_from_file(str(input_path))
    # Default score function (ref2015 per PyRosetta's default weights).
    score_fn = get_score_function()
    return score_fn(loaded_pose)


def repack_sidechains(input_path, out_path, **kwargs):
    """Perform Rosetta side-chain repacking and return the final score.

    Args:
        input_path (str): Path to the CIF/PDB file.
        out_path (str): Path to the output PDB file.
        **kwargs: Ignored; accepted so callers can forward shared
            relaxation options to all stages.

    Returns:
        float: Rosetta score of the repacked pose.
    """
    # NOTE(review): `from pyrosetta import *` does not bind the name
    # `pyrosetta` itself, so the original's fully-qualified
    # `pyrosetta.rosetta.core.pack.task.operation...` calls risk a
    # NameError. Import the task-operation module explicitly instead.
    from pyrosetta.rosetta.core.pack.task import operation

    ensure_pyrosetta_initialized()

    pose = pose_from_file(str(input_path))
    scorefxn = get_score_function()
    tf = TaskFactory()
    tf.push_back(operation.InitializeFromCommandline())
    # Repack only: keep the amino-acid identities, sample rotamers.
    tf.push_back(operation.RestrictToRepacking())
    packer = PackRotamersMover()
    packer.task_factory(tf)
    packer.score_function(scorefxn)
    packer.apply(pose)
    pose.dump_pdb(str(out_path))
    return scorefxn(pose)


def fastrelax(input_path, out_path, constrain_relax_to_start_coords=True, **kwargs):
    """Run Rosetta FastRelax on a structure and return the relaxed score.

    Args:
        input_path (str): Path to the CIF/PDB file.
        out_path (str): Path to the output PDB file.
        constrain_relax_to_start_coords (bool): Restrain backbone atoms to
            their starting coordinates during relaxation.
        **kwargs: Ignored; accepted for signature compatibility.

    Returns:
        float: Rosetta score of the relaxed pose.
    """
    ensure_pyrosetta_initialized()

    working_pose = pose_from_file(str(input_path))
    score_fn = get_score_function()

    mover = FastRelax()
    mover.set_scorefxn(score_fn)
    mover.constrain_relax_to_start_coords(constrain_relax_to_start_coords)
    mover.apply(working_pose)

    working_pose.dump_pdb(str(out_path))
    return score_fn(working_pose)


def relax(
    input_path,
    output_dir=None,
    override=False,
    save_logs=True,
    save_energies=True,
    **kwargs,
):
    """Score, repack, and FastRelax one structure with Rosetta.

    Runs three stages, recording the Rosetta energy after each:
    1. score the input as-is,
    2. side-chain repacking (written to ``repacked_<stem>.pdb``),
    3. FastRelax of the repacked pose (``fastrelaxed_<stem>.pdb``).

    Existing stage outputs are re-scored rather than recomputed unless
    ``override`` is True.

    Args:
        input_path (str | Path): Input CIF/PDB file.
        output_dir (str | Path | None): Directory for outputs; defaults to
            the input file's own directory.
        override (bool): Recompute stage outputs even if they exist.
        save_logs (bool): Capture Rosetta stdout/stderr to per-input log
            files; otherwise discard them.
        save_energies (bool): Also write the result dict to a JSON file.
        **kwargs: Forwarded to `repack_sidechains` and `fastrelax`.

    Returns:
        dict: Input/output paths and Rosetta energies for every stage.
    """
    import os

    input_path = Path(input_path)
    # Accept both str and Path for output_dir (the original broke on str).
    output_dir = input_path.parent if output_dir is None else Path(output_dir)

    if save_logs:
        stdout_target = output_dir / f"rosetta-relax_{input_path.stem}.stdout"
        stderr_target = output_dir / f"rosetta-relax_{input_path.stem}.stderr"
    else:
        # os.devnull is portable; the original hard-coded "/dev/null",
        # which fails on Windows.
        stdout_target = stderr_target = os.devnull

    # Context-manage the log handles so they are always closed; the
    # original passed bare `open(...)` objects and leaked them.
    with open(stdout_target, "w") as out_log, open(stderr_target, "w") as err_log:
        with redirect_stdout(out_log), redirect_stderr(err_log):
            init_energy = get_rosetta_score(input_path)
            ret = dict(init_energy=init_energy, input_path=str(input_path))

            repacked_path = output_dir / f"repacked_{input_path.stem}.pdb"
            repacked_energy = (
                repack_sidechains(input_path, repacked_path, **kwargs)
                if not repacked_path.exists() or override
                else get_rosetta_score(repacked_path)
            )
            ret.update(
                dict(repacked_energy=repacked_energy, repacked_path=str(repacked_path))
            )

            relax_path = output_dir / f"fastrelaxed_{input_path.stem}.pdb"
            fastrelaxed_energy = (
                fastrelax(repacked_path, relax_path, **kwargs)
                if not relax_path.exists() or override
                else get_rosetta_score(relax_path)
            )
            ret.update(
                dict(
                    fastrelaxed_energy=fastrelaxed_energy,
                    fastrelaxed_path=str(relax_path),
                )
            )

            if save_energies:
                json_path = output_dir / f"energies_{input_path.stem}.json"
                with open(json_path, "w") as f:
                    # `indent=4` keeps the file human-readable
                    json.dump(ret, f, indent=4)

            # Flush the redirected streams before the log files close.
            sys.stdout.flush()
            sys.stderr.flush()

    return ret


def parallel_relax(
    input_paths,
    output_dir=None,
    override=False,
    save_logs=False,
    save_energies=True,
    cores=8,
    **kwargs,
):
    """Run `relax` over many structures in parallel worker processes.

    Args:
        input_paths (Sequence[str | Path]): Structure files to relax.
        output_dir (str | Path | None): Forwarded to `relax`.
        override (bool): Forwarded to `relax`.
        save_logs (bool): Forwarded to `relax`.
        save_energies (bool): Forwarded to `relax`.
        cores (int): Maximum number of worker processes.
        **kwargs: Extra options forwarded to `relax`.

    Returns:
        pandas.DataFrame: One row per input with the energies/paths from
        `relax`, plus a `name` column taken from each input's parent
        directory name.
    """
    import multiprocessing
    from functools import partial

    input_paths = list(input_paths)
    # Guard the empty case: the original raised AttributeError when
    # accessing `ret.input_path` on an empty DataFrame.
    if not input_paths:
        return pd.DataFrame(columns=["input_path", "name"])

    if not all(Path(p).suffix.lower() in (".cif", ".mmcif") for p in input_paths):
        print(
            "WARNING: If your structure contains a ligand, use CCD ligands in Boltz input and choose CIF as the output format. "
            "This will ensure that the atom/residue naming is compatible with Rosetta."
        )

    worker = partial(
        relax,
        output_dir=output_dir,
        override=override,
        save_logs=save_logs,
        save_energies=save_energies,
        **kwargs,
    )
    # Never spawn more workers than there are inputs; "spawn" avoids
    # fork-related issues with PyRosetta/CUDA state.
    n_workers = min(cores, len(input_paths))
    with multiprocessing.get_context("spawn").Pool(n_workers) as pool:
        ret = pd.DataFrame(
            tqdm(
                pool.imap_unordered(worker, input_paths),
                total=len(input_paths),
                desc="Rosetta Relaxation",
            )
        )
    ret["name"] = [Path(p).parent.name for p in ret.input_path]
    return ret
10 changes: 7 additions & 3 deletions src/boltz/data/write/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from boltz.data.write.mmcif import to_mmcif
from boltz.data.write.pdb import to_pdb


class BoltzWriter(BasePredictionWriter):
"""Custom writer for predictions."""

Expand All @@ -43,6 +42,7 @@ def __init__(
self.data_dir = Path(data_dir)
self.output_dir = Path(output_dir)
self.output_format = output_format
self.paths_to_relax = []
self.failed = 0

# Create the output directories
Expand Down Expand Up @@ -145,10 +145,12 @@ def write_on_batch_end(
path = struct_dir / f"{outname}.pdb"
with path.open("w") as f:
f.write(to_pdb(new_structure, plddts=plddts))
self.paths_to_relax.append(Path(path).resolve())
elif self.output_format == "mmcif":
path = struct_dir / f"{outname}.cif"
with path.open("w") as f:
f.write(to_mmcif(new_structure, plddts=plddts))
self.paths_to_relax.append(Path(path).resolve())
else:
path = struct_dir / f"{outname}.npz"
np.savez_compressed(path, **asdict(new_structure))
Expand Down Expand Up @@ -225,5 +227,7 @@ def on_predict_epoch_end(
pl_module: LightningModule, # noqa: ARG002
) -> None:
"""Print the number of failed examples."""
# Print number of failed examples
print(f"Number of failed examples: {self.failed}") # noqa: T201
print(f"\nNumber of failed examples: {self.failed}") # noqa: T201

def get_paths_to_relax(self):
    """Return the structure file paths recorded during prediction writing
    (resolved paths appended in ``write_on_batch_end``) that are candidates
    for Rosetta relaxation."""
    return self.paths_to_relax
47 changes: 47 additions & 0 deletions src/boltz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm
import pandas as pd

from boltz.data import const
from boltz.data.module.inference import BoltzInferenceDataModule
Expand All @@ -22,6 +23,7 @@
from boltz.data.write.writer import BoltzWriter
from boltz.model.model import Boltz1


CCD_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl"
MODEL_URL = (
"https://huggingface.co/boltz-community/boltz-1/resolve/main/boltz1_conf.ckpt"
Expand Down Expand Up @@ -535,6 +537,17 @@ def cli() -> None:
help="Pairing strategy to use. Used only if --use_msa_server is set. Options are 'greedy' and 'complete'",
default="greedy",
)
@click.option(
"--rosetta_relax",
is_flag=True,
help="Whether to perform Rosetta repacking and fastrelax. Installation of pyrosetta and a valid license are required",
)
@click.option(
"--relax_cores",
type=int,
default=8,
help="Number of cores for rosetta relaxation",
)
def predict(
data: str,
out_dir: str,
Expand All @@ -555,9 +568,12 @@ def predict(
use_msa_server: bool = False,
msa_server_url: str = "https://api.colabfold.com",
msa_pairing_strategy: str = "greedy",
rosetta_relax: bool = False,
relax_cores: int = 8,
) -> None:
"""Run predictions with Boltz-1."""
# If cpu, write a friendly warning

if accelerator == "cpu":
msg = "Running on CPU, this will be slow. Consider using a GPU."
click.echo(msg)
Expand Down Expand Up @@ -682,6 +698,37 @@ def predict(
return_predictions=False,
)

paths_to_relax = pred_writer.get_paths_to_relax()
if rosetta_relax and len(paths_to_relax) > 0:
from boltz.data.write.rosetta_relax import parallel_relax

release_resources(trainer, model_module, data_module, pred_writer)
ret = parallel_relax(
paths_to_relax,
override=True,
cores=min(relax_cores, len(paths_to_relax)),
save_energies=True,
)
csv = Path(paths_to_relax[0]).parent.parent / "rosetta_energies.csv"
ret = ret.sort_values(by=["name", "repacked_energy"]).reset_index(drop=True)
if csv.exists():
print(f"`{csv}` exists! appending ...\n")
ret = pd.concat([ret, pd.read_csv(csv)])
ret.to_csv(csv, index=False)


def release_resources(trainer=None, *objects):
    """Drop local references to large objects and free cached GPU memory.

    NOTE(review): `del` here only removes this frame's references; the
    objects are actually freed only once the caller's own references are
    gone too. The call is still useful for triggering a GC pass and
    emptying the CUDA allocator cache.

    Args:
        trainer: Optional Lightning trainer to release. The original
            accepted it but never released it; it is now included.
        *objects: Additional objects to drop before garbage collection.
    """
    import gc

    # `del` of bound locals never raises, so no try/except is needed
    # (the original's handler was dead code).
    del trainer, objects
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == "__main__":
    # Script entry point: dispatch to the click CLI group.
    cli()