Skip to content

Commit

Permalink
Fix pre-commits and mypy warnings.
Browse files Browse the repository at this point in the history
  • Loading branch information
RolandMacDoland committed Dec 10, 2024
1 parent 0b0718f commit e8c83d2
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 25 deletions.
2 changes: 2 additions & 0 deletions qek/data/conversion_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Final

import rdkit.Chem as Chem
Expand Down
6 changes: 4 additions & 2 deletions qek/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
from dataclasses import dataclass

Expand All @@ -18,10 +20,10 @@ class ProcessedData:
state_dict: dict[str, int]
target: int

def __post_init__(self):
def __post_init__(self) -> None:
self.state_dict = _convert_np_int64_to_int(data=self.state_dict)

def save_to_file(self, file_path: str):
def save_to_file(self, file_path: str) -> None:
with open(file_path, "w") as file:
tmp_dict = {
"sequence": self.sequence.to_abstract_repr(),
Expand Down
18 changes: 12 additions & 6 deletions qek/data/datatools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json

import networkx as nx
Expand All @@ -8,11 +10,10 @@
import torch.utils.data as torch_data
import torch_geometric.data as pyg_data
import torch_geometric.utils as pyg_utils
from rdkit.Chem import AllChem

from qek_os.data_io.conversion_data import PTCFM_EDGES_MAP, PTCFM_NODES_MAP
from qek_os.data_io.dataset import ProcessedData
from qek_os.utils import graph_to_mol
from rdkit.Chem import AllChem


def add_graph_coord(
Expand Down Expand Up @@ -66,11 +67,15 @@ def split_train_test(
generator = torch.Generator().manual_seed(seed)
else:
generator = torch.Generator()
train, val = torch_data.random_split(dataset=dataset, lengths=lengths, generator=generator)
train, val = torch_data.random_split(
dataset=dataset, lengths=lengths, generator=generator
)
return train, val


def check_compatibility_graph_device(graph: pyg_data.Data, device: pl.devices.Device) -> bool:
def check_compatibility_graph_device(
graph: pyg_data.Data, device: pl.devices.Device
) -> bool:
"""Given the characteristics of a graph, return True if the graph can be embedded in
the device, False if not.
Expand Down Expand Up @@ -114,10 +119,11 @@ def _return_min_dist(graph: pyg_data.Data) -> float:
compl_graph = nx.complement(nx_graph)
for start, end in compl_graph.edges():
distances.append(np.linalg.norm(graph_pos[start] - graph_pos[end], ord=2))
return min(distances)
min_dist: float = min(distances)
return min_dist


def save_dataset(dataset: list[ProcessedData], file_path: str):
def save_dataset(dataset: list[ProcessedData], file_path: str) -> None:
"""Saves a dataset to a JSON file.
Args:
Expand Down
40 changes: 24 additions & 16 deletions qek/kernel/kernel.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from __future__ import annotations

import collections
from collections.abc import Sequence

import numpy as np
from scipy.spatial.distance import jensenshannon

from qek_os.data_io.dataset import ProcessedData
from scipy.spatial.distance import jensenshannon


class QekKernel:
def __init__(self, mu: float):
self.mu = mu

def __call__(
self, graph_1: ProcessedData, graph_2: ProcessedData, size_max: float = 100
self, graph_1: ProcessedData, graph_2: ProcessedData, size_max: int = 100
) -> float:
"""Compute the similarity between two graphs using Jensen-Shannon divergence.
Expand Down Expand Up @@ -47,9 +48,11 @@ class from qek_os.data_io.dataset. The size_max parameter controls the maximum
js = (
jensenshannon(p=dist_graph_1, q=dist_graph_2) ** 2
) # Because the divergence is the square root of the distance
return np.exp(-self.mu * js)
return float(np.exp(-self.mu * js))

def create_train_kernel_matrix(self, train_dataset: Sequence[ProcessedData]) -> np.ndarray:
def create_train_kernel_matrix(
self, train_dataset: Sequence[ProcessedData]
) -> np.ndarray:
"""Compute a kernel matrix for a given training dataset.
This method computes a symmetric N x N kernel matrix from the Jensen-Shannon
Expand All @@ -71,23 +74,26 @@ def create_train_kernel_matrix(self, train_dataset: Sequence[ProcessedData]) ->
return kernel_mat

def create_test_kernel_matrix(
self, test_dataset: Sequence[ProcessedData], train_dataset: Sequence[ProcessedData]
self,
test_dataset: Sequence[ProcessedData],
train_dataset: Sequence[ProcessedData],
) -> np.ndarray:
"""Compute a kernel matrix for a given testing dataset and training set.
This method computes an N x M kernel matrix from the Jensen-Shannon
divergences between all pairs of graphs in the input testing dataset and the training dataset.
divergences between all pairs of graphs in the input testing dataset
and the training dataset.
The resulting matrix can be used as a similarity metric for machine learning algorithms,
particularly when evaluating the performance on the test dataset using a trained model.
Args:
test_dataset (Sequence[ProcessedData]): A list of ProcessedData objects representing the
testing dataset.
train_dataset (Sequence[ProcessedData]): A list of ProcessedData objects representing the
training set.
test_dataset (Sequence[ProcessedData]): A list of ProcessedData
objects representing the testing dataset.
train_dataset (Sequence[ProcessedData]): A list of ProcessedData
objects representing the training set.
Returns:
np.ndarray: An M x N matrix where the entry at row i and column j represents
the similarity between the graph in position i of the test dataset and the graph in position j
of the training set.
the similarity between the graph in position i of the test dataset
and the graph in position j of the training set.
"""
N_train = len(train_dataset)
N_test = len(test_dataset)
Expand All @@ -110,7 +116,9 @@ def count_occupation_from_bitstring(bitstring: str) -> int:
return sum(int(bit) for bit in bitstring)


def dist_excitation_and_vec(count_bitstring: dict[str, int], size_max: int) -> np.ndarray:
def dist_excitation_and_vec(
count_bitstring: dict[str, int], size_max: int
) -> np.ndarray:
"""Calculates the distribution of excitation energies from a dictionary of
bitstrings to their respective counts, and then creates a NumPy vector with the
results.
Expand All @@ -124,7 +132,7 @@ def dist_excitation_and_vec(count_bitstring: dict[str, int], size_max: int) -> n
np.ndarray: A NumPy array where keys are the number of '1' bits
in each binary string and values are the normalized counts.
"""
count_occ = collections.defaultdict(float)
count_occ: dict = collections.defaultdict(float)
total = 0.0
for k, v in count_bitstring.items():
nbr_occ = count_occupation_from_bitstring(k)
Expand All @@ -133,7 +141,7 @@ def dist_excitation_and_vec(count_bitstring: dict[str, int], size_max: int) -> n

numpy_vec = np.zeros(size_max)
for k, v in count_occ.items():
if k <= size_max:
if int(k) <= size_max:
numpy_vec[k] = v / total

return numpy_vec
6 changes: 5 additions & 1 deletion qek/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import networkx as nx
import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -71,7 +73,9 @@ def is_disk_graph(G: pyg_data.Data, radius: float) -> bool:
else:
raise AttributeError("Graph object does not have a pos attribute")

nx_graph = pyg_utils.to_networkx(G, to_undirected=True) # Molecule are unidrected Graphs
nx_graph = pyg_utils.to_networkx(
G, to_undirected=True
) # Molecule are unidrected Graphs

# Check if the graph is connected
if not nx.is_connected(nx_graph):
Expand Down

0 comments on commit e8c83d2

Please sign in to comment.