Skip to content

Commit

Permalink
add additional hetatm info to pyg parser (#397)
Browse files Browse the repository at this point in the history
* add additional hetatm info to parser

* Update changelog

---------

Co-authored-by: Arian Jamasb <[email protected]>
  • Loading branch information
a-r-j and Arian Jamasb authored Aug 4, 2024
1 parent 4d8dc64 commit aee6638
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* Add missing modified residue `AYA` to constants [#390](https://github.com/a-r-j/graphein/pull/390)
* Fix bug where the `deprotonate` argument is not wired up to `graphein.protein.graphs.construct_graphs` [#375](https://github.com/a-r-j/graphein/pull/375)
* Fix cluster file loading bug in `pdb_data.py` [#396](https://github.com/a-r-j/graphein/pull/396)
* Improves storage of hetatm data in `graphein.protein.tensor.io.protein_to_pyg` [#397](https://github.com/a-r-j/graphein/pull/397).

#### Misc
* set logging to false by default and added mmcif support [#402](https://github.com/a-r-j/graphein/pull/402)
Expand Down
18 changes: 15 additions & 3 deletions graphein/protein/tensor/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Project Website: https://github.com/a-r-j/graphein
# Code Repository: https://github.com/a-r-j/graphein

import collections
import os
from typing import List, Optional, Union

Expand Down Expand Up @@ -218,13 +219,24 @@ def protein_to_pyg(
if store_het:
hetatms = df.loc[df.record_name == "HETATM"]
all_hets = list(set(hetatms.residue_name))
het_coords = {}
het_data = collections.defaultdict(dict)
for het in all_hets:
het_coords[het] = torch.tensor(
het_data[het]["coords"] = torch.tensor(
hetatms.loc[hetatms.residue_name == het][
["x_coord", "y_coord", "z_coord"]
].values
)
het_data[het]["atoms"] = hetatms.loc[hetatms.residue_name == het][
"atom_name"
].values
het_data[het]["residue_number"] = torch.tensor(
hetatms.loc[hetatms.residue_name == het][
"residue_number"
].values
)
het_data[het]["element_symbol"] = hetatms.loc[
hetatms.residue_name == het
]["element_symbol"].values

df = df.loc[df.record_name == "ATOM"]
if remove_nonstandard:
Expand Down Expand Up @@ -263,7 +275,7 @@ def protein_to_pyg(
)

if store_het:
out.hetatms = [het_coords]
out.hetatms = [het_data]

if store_bfactor:
# group by residue_id and average b_factor per residue
Expand Down

0 comments on commit aee6638

Please sign in to comment.