From aee6638b0ade154318bce315fda0b45f8b4edff5 Mon Sep 17 00:00:00 2001 From: Arian Jamasb Date: Sun, 4 Aug 2024 22:20:02 +0200 Subject: [PATCH] add additional hetatm info to pyg parser (#397) * add additional hetatm info to parser * Update changelog --------- Co-authored-by: Arian Jamasb --- CHANGELOG.md | 1 + graphein/protein/tensor/io.py | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f0dcf333..7dc9288c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * Add missing modified residue `AYA` to constants [#390](https://github.com/a-r-j/graphein/pull/390) * Fix bug where the `deprotonate` argument is not wired up to `graphein.protein.graphs.construct_graphs` [#375](https://github.com/a-r-j/graphein/pull/375) * Fix cluster file loading bug in `pdb_data.py` [#396](https://github.com/a-r-j/graphein/pull/396) +* Improves storage of hetatm data in `graphein.protein.tensor.io.protein_to_pyg` [#397](https://github.com/a-r-j/graphein/pull/397). #### Misc * set logging to false by default and added mmcif support [#402](https://github.com/a-r-j/graphein/pull/402) diff --git a/graphein/protein/tensor/io.py b/graphein/protein/tensor/io.py index e938130a..fa5493ce 100644 --- a/graphein/protein/tensor/io.py +++ b/graphein/protein/tensor/io.py @@ -6,6 +6,7 @@ # Project Website: https://github.com/a-r-j/graphein # Code Repository: https://github.com/a-r-j/graphein +import collections import os from typing import List, Optional, Union @@ -218,13 +219,24 @@ def protein_to_pyg( if store_het: hetatms = df.loc[df.record_name == "HETATM"] all_hets = list(set(hetatms.residue_name)) - het_coords = {} + het_data = collections.defaultdict(dict) for het in all_hets: - het_coords[het] = torch.tensor( + het_data[het]["coords"] = torch.tensor( hetatms.loc[hetatms.residue_name == het][ ["x_coord", "y_coord", "z_coord"] ].values ) + het_data[het]["atoms"] = hetatms.loc[hetatms.residue_name == het][ + "atom_name" + ].values + het_data[het]["residue_number"] = torch.tensor( + hetatms.loc[hetatms.residue_name == het][ + "residue_number" + ].values + ) + het_data[het]["element_symbol"] = hetatms.loc[ + hetatms.residue_name == het + ]["element_symbol"].values df = df.loc[df.record_name == "ATOM"] if remove_nonstandard: @@ -263,7 +275,7 @@ def protein_to_pyg( ) if store_het: - out.hetatms = [het_coords] + out.hetatms = [het_data] if store_bfactor: # group by residue_id and average b_factor per residue