From f8ee2f974383611b307abcdcb6edb0e375ce86e0 Mon Sep 17 00:00:00 2001 From: David Teller Date: Wed, 8 Jan 2025 12:46:14 +0100 Subject: [PATCH] [Cleanup] Gathering graph utilities around BasicGraph/MoleculeGraph, moving code from notebook to library --- examples/pipeline.ipynb | 189 ++++++++------------- qek/data/conversion_data.py | 37 ---- qek/data/datatools.py | 326 +++++++++++++++++++++++++----------- qek/kernel/__init__.py | 4 +- qek/kernel/kernel.py | 2 +- qek/utils.py | 70 +------- tests/test_datatools.py | 159 ++++++++++++++++-- tests/test_utils.py | 128 -------------- 8 files changed, 449 insertions(+), 466 deletions(-) delete mode 100644 qek/data/conversion_data.py delete mode 100644 tests/test_utils.py diff --git a/examples/pipeline.ipynb b/examples/pipeline.ipynb index 8558ee0..34c7a64 100644 --- a/examples/pipeline.ipynb +++ b/examples/pipeline.ipynb @@ -5,7 +5,8 @@ "metadata": {}, "source": [ "# QEK from A to Z\n", - "This notebook reproduces the results of [QEK](https://journals.aps.org/pra/abstract/10.1103/PhysRevA.107.042615) by running the jupyter notebook.\n", + "\n", + "This notebook reproduces the results of the [QEK paper](https://journals.aps.org/pra/abstract/10.1103/PhysRevA.107.042615).\n", "\n", "At the end, you will be able to:\n", "1. Find the embeddings of a molecular dataset\n", @@ -15,18 +16,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_160993/4114388664.py:9: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", - " from tqdm.autonotebook import tqdm\n" - ] - } - ], + "outputs": [], "source": [ "from __future__ import annotations\n", "\n", @@ -34,7 +26,6 @@ "\n", "import numpy as np\n", "import pulser as pl\n", - "import torch_geometric.data as pyg_data\n", "import torch_geometric.datasets as pyg_dataset\n", "from tqdm.autonotebook import tqdm\n" ] @@ -50,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -60,12 +51,11 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ - "import qek.data.datatools as qek_datatools\n", - "from qek.utils import compute_register, is_disk_graph" + "import qek.data.datatools as qek_datatools" ] }, { @@ -74,24 +64,26 @@ "source": [ "# Graph embedding\n", "\n", - "A graph, molecular or otherwise, does not have coordinates in space. We therefore need to find a regular embedding as possible. \n", - "In addition to this regularity, we must be sure that the embedding obtained by the function `add_graph_coord` is an **unit-disk graph embedding**.\n", - "\n", - "Here, we want that the distance between two connected qubits is equal to `RADIUS=5.001` $\\mu m$.\n", - "The `is_disk_graph` function check if the found embedding is an embedding of unit-disk graph, i.e. the distance between two connected nodes should be less than `RADIUS` and the distance between two disconnected nodes should be greater than `RADIUS`.\n", + "QEK lets researchers embed _graphs_ on Quantum Devices. To do this, we need to give these graphs a geometry (positions in\n", + "space) and to confirm that the geometry is compatible with a Quantum Device. Here, our dataset consists in molecules (represented\n", + "as graphs). To simplify things, QEK comes with a dedicated class `qek_datatools.MoleculeGraph` that adds a geometry to the graphs.\n", "\n", + "One of the core ideas behind QEK is that each nodes (aka atoms) in a graph (aka molecule) from the dataset is represented by one\n", + "cold atom on the Device and if two nodes are joined by an edge, their cold atoms must be close to each other. In geometrical terms,\n", + "this means that the `MoleculeGraph` must be a _disk graph_, with a radius of 5.001 $\\mu m$. In this notebook, for the sake of\n", + "simplicity, we simply discard graphs that are not disk graphs.\n", " " ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fd00873529984a7ea9c20e87bc59b4c2", + "model_id": "9deae7c65fae4de9b557c916e3df32bc", "version_major": 2, "version_minor": 0 }, @@ -104,13 +96,13 @@ } ], "source": [ - "list_of_graph = []\n", + "list_of_graphs = []\n", "RADIUS = 5.001\n", "EPS = 0.01\n", - "for graph in tqdm(og_ptcfm):\n", - " graph_with_pos = qek_datatools.add_graph_coord(graph=graph, blockade_radius=RADIUS)\n", - " if is_disk_graph(graph_with_pos, radius=RADIUS+EPS):\n", - " list_of_graph.append((graph_with_pos, graph.y.item()))" + "for data in tqdm(og_ptcfm):\n", + " graph = qek_datatools.MoleculeGraph(data=data, blockade_radius=RADIUS)\n", + " if graph.is_disk_graph(radius=RADIUS+EPS):\n", + " list_of_graphs.append((graph, graph.pyg.y.item()))" ] }, { @@ -119,87 +111,54 @@ "source": [ "## Create a Pulser sequence\n", "\n", - "Once the embedding is found, we will create a pulser sequence that can be interpreted by the QPU or a Pasqal emulator. A sequence consists of a **register**, which means the position of qubits in a device and a **pulse** sequence.\n", - "\n", - "The `create_sequence_from_graph` function is responsible for doing this. It checks if the positions of the qubits respect the constraints of the device (number of qubits, minimum and maximum distance between qubits, etc.) and create a register if the embedding pass all the tests. Finally, the pulse sequence is the same as that in the scientific paper, which is a constant pulse with values $\\Omega = 2\\pi$, $\\delta = 0$ and a duration of $660 ns$.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "# Create a sequence:\n", - "\n", - "def create_sequence_from_graph(graph:pyg_data.Data, device=pl.devices.Device)-> pl.Sequence:\n", - " if not qek_datatools.check_compatibility_graph_device(graph, device):\n", - " raise ValueError(f\"The graph is not compatible with {device}\")\n", - " reg = compute_register(data_graph=graph)\n", - " seq = pl.Sequence(register=reg, device=device)\n", - " Omega_max = 1.0 * 2 * np.pi\n", - " t_max = 660\n", - " pulse = pl.Pulse.ConstantAmplitude(\n", - " amplitude=Omega_max,\n", - " detuning=pl.waveforms.RampWaveform(t_max, 0, 0),\n", - " phase=0.0,\n", - " )\n", - " seq.declare_channel(\"ising\", \"rydberg_global\")\n", - " seq.add(pulse, \"ising\")\n", - " return seq" + "Once the embedding is found, we create a Pulser Sequence that can be interpreted by a Quantum Device. A Sequence consists of a **register** (i.e. a geometry of cold atoms on the device) and **pulse**s. Sequences need to be designed for a specific device, so our graph object offers a method `compute_sequence` that does exactly that." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 17, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f3e5e9dc5b864843ba85a4281951072b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/294 [00:00= MAX_NUMBER_OF_DATASETS:\n", - " break\n", - " except ValueError as err:\n", - " print(f\"{err}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from pulser_simulation import QutipEmulator" + "for graph, target in tqdm(list_of_graphs):\n", + " # Some graph are not compatible with AnalogDevice\n", + " if graph.is_embeddable(device=pl.AnalogDevice):\n", + " dataset_sequence.append((graph.compute_sequence(device=pl.AnalogDevice), target))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "A pulser sequence is all you need for a quantum calculation on a Pasqal QPU! Before submitting the calculation to an actual quantum computer, we must first verify that everything works by emulation. For this, Pasqal has developed `pulser_simulation`.\n", - "\n", - "The code below allows us to emulate the entire \"quantum compatible\" PTC-FM dataset (i.e., whose embeddings are unit-disk and compatible with the device). However, we advise against running it for time reasons.\n", - "\n", - "Fortunately, we have already emulated the entire PTC-FM compatible dataset. You just need to load it up.\n" + "A pulser sequence is all you need for a quantum calculation on a Pasqal QPU! Before submitting the calculation to an actual quantum computer, let's verify that everything works on our machine. For this, Pasqal has developed several simulators, which you may find in `pulser_simulation`. Of course, quantum simulators are much slower than a real quantum computer, so we're not going to run all these embeddings on our simulator." ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "479ba694fb0540718becf9971c5ef27e", + "model_id": "36483a40288c431eb0482e66325b375a", "version_major": 2, "version_minor": 0 }, @@ -212,7 +171,13 @@ } ], "source": [ - "for seq, target in tqdm(dataset_sequence):\n", + "from pulser_simulation import QutipEmulator\n", + "\n", + "# In this tutorial, to make things faster, we'll only run the first compatible entry in the dataset.\n", + "# If you wish to run more entries, feel free to increase this value.\n", + "MAX_NUMBER_OF_DATASETS = 1\n", + "\n", + "for seq, target in tqdm(dataset_sequence[0:MAX_NUMBER_OF_DATASETS]):\n", " simul = QutipEmulator.from_sequence(sequence=seq)\n", " results = simul.run()\n" ] @@ -221,28 +186,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Loading the already existing dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "processed_dataset = qek_datatools.load_dataset(file_path=\"ptcfm_processed_dataset.json\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Some properties of the newly created dataset:" + "## Loading the already existing dataset\n", + "\n", + "For this notebook, instead of spending ours running the simulator on your computer, we're going to skip\n", + "this step and load on we're going to cheat and load the results, which are conveniently stored in a file." ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -254,6 +206,7 @@ } ], "source": [ + "processed_dataset = qek_datatools.load_dataset(file_path=\"ptcfm_processed_dataset.json\")\n", "print(f\"Size of the quantum compatible dataset = {len(processed_dataset)}\")" ] }, @@ -261,12 +214,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Cherry picked register and pulse sequence:" + "Let's look at a the sequence and register for one of these samples:" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -286,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -323,16 +276,16 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ - "from qek.kernel import Kernel\n" + "from qek.kernel import QuantumEvolutionKernel as QEK\n" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -415,11 +368,11 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ - "kernel = Kernel(mu=2.)" + "kernel = QEK(mu=2.)" ] }, { @@ -431,7 +384,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -504,7 +457,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -514,7 +467,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -538,7 +491,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 28, "metadata": {}, "outputs": [ { diff --git a/qek/data/conversion_data.py b/qek/data/conversion_data.py deleted file mode 100644 index 93ddcdf..0000000 --- a/qek/data/conversion_data.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import annotations - -from typing import Final - -import rdkit.Chem as Chem - -# Constants used to decode the PTC-FM data set. - -# Node labels -PTCFM_NODES_MAP: Final[dict[int, str]] = { - 0: "In", - 1: "P", - 2: "C", - 3: "O", - 4: "N", - 5: "Cl", - 6: "S", - 7: "Br", - 8: "Na", - 9: "F", - 10: "As", - 11: "K", - 12: "Cu", - 13: "I", - 14: "Ba", - 15: "Sn", - 16: "Pb", - 17: "Ca", -} - -# Edges labels -PTCFM_EDGES_MAP: Final[dict[int, Chem.BondType]] = { - 0: Chem.BondType.TRIPLE, - 1: Chem.BondType.SINGLE, - 2: Chem.BondType.DOUBLE, - 3: Chem.BondType.AROMATIC, -} diff --git a/qek/data/datatools.py b/qek/data/datatools.py index 56eb562..6fecef1 100644 --- a/qek/data/datatools.py +++ b/qek/data/datatools.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from typing import Final import networkx as nx import numpy as np @@ -12,64 +13,10 @@ import torch_geometric.utils as pyg_utils from rdkit.Chem import AllChem -from qek.data.conversion_data import PTCFM_EDGES_MAP, PTCFM_NODES_MAP from qek.data.dataset import ProcessedData from qek.utils import graph_to_mol -def add_graph_coord( - graph: pyg_data.Data, - blockade_radius: float, - node_mapping: dict[int, str] = PTCFM_NODES_MAP, - edge_mapping: dict[int, Chem.BondType] = PTCFM_EDGES_MAP, -) -> pyg_data.Data: - """ - Take a molecule described as a graph with only nodes and edges, - add 2D coordinates. - - This function: - 1. Converts the graph into a molecule (using `node_mapping` and - `edge_mapping` to determine the types of atoms and bonds). - 2. Uses the molecule to determine coordinates. - 3. Injects the coordinates into the graph. - - Args: - graph: A homogeneous graph, in PyTorch Geometric format. Unchanged. - blockade_radius: The radius of the Rydberg Blockade. Two - connected nodes should be at a distance < blockade_radius, while - two disconnected nodes should be at a distance > blockade_radius. - node_mapping: A mapping of node labels from numbers to strings, - e.g. `5 => "Cl"`. Used when building molecules, e.g. to compute - distances between nodes. - edge_mapping: A mapping of edge labels from number to chemical - bond types, e.g. `2 => DOUBLE`. Used when building molecules, e.g. - to compute distances between nodes. - - Returns: - A clone of `graph` augmented with 2D coordinates. - """ - graph = graph.clone() - nx_graph = pyg_utils.to_networkx( - data=graph, - node_attrs=["x"], - edge_attrs=["edge_attr"], - to_undirected=True, - ) - tmp_mol = graph_to_mol( - graph=nx_graph, - node_mapping=node_mapping, - edge_mapping=edge_mapping, - ) - AllChem.Compute2DCoords(tmp_mol, useRingTemplates=True) - pos = tmp_mol.GetConformer().GetPositions()[..., :2] # Convert to 2D - dist_list = [] - for start, end in nx_graph.edges(): - dist_list.append(np.linalg.norm(pos[start] - pos[end])) - norm_factor = np.max(dist_list) - graph.pos = pos * blockade_radius / norm_factor - return graph - - def split_train_test( dataset: torch_data.Dataset, lengths: list[float], @@ -96,54 +43,6 @@ def split_train_test( return train, val -def check_compatibility_graph_device(graph: pyg_data.Data, device: pl.devices.Device) -> bool: - """Given the characteristics of a graph, return True if the graph can be embedded in - the device, False if not. - - Args: - graph (pyg_data): The graph to embeded - device (pulser.devices.Device): the device - - Returns: - bool: True if possible, False if not - """ - pos_graph = graph.pos - # check the number of atoms - if graph.num_nodes > device.max_atom_num: - return False - # Check the distance from the center - distance_from_center = np.linalg.norm(pos_graph, ord=2, axis=-1) - if any(distance_from_center > device.max_radial_distance): - return False - if _return_min_dist(graph) < device.min_atom_distance: - return False - return True - - -def _return_min_dist(graph: pyg_data.Data) -> float: - """Calculates the minimum distance between any two nodes in the graph, including - both original and complementary edges. - - Args: - graph (pyg_data.Data): The graph to calculate min distance from. - - Returns: - float: Minimum distance between any two nodes. - """ - nx_graph = pyg_utils.to_networkx(graph) - graph_pos = graph.pos - distances = [] - - # get min distance in the graph - for start, end in nx_graph.edges(): - distances.append(np.linalg.norm(graph_pos[start] - graph_pos[end], ord=2)) - compl_graph = nx.complement(nx_graph) - for start, end in compl_graph.edges(): - distances.append(np.linalg.norm(graph_pos[start] - graph_pos[end], ord=2)) - min_dist: float = min(distances) - return min_dist - - def save_dataset(dataset: list[ProcessedData], file_path: str) -> None: """Saves a dataset to a JSON file. @@ -191,3 +90,226 @@ def load_dataset(file_path: str) -> list[ProcessedData]: ) for item in data ] + + +class BaseGraph: + """ + A graph being prepared for embedding on a quantum device. + """ + + def __init__(self, data: pyg_data.Data): + """ + Create a graph from geometric data. + + Args: + data: A homogeneous graph, in PyTorch Geometric format. Unchanged. + It MUST have attributes 'pos' + """ + if not hasattr(data, "pos"): + raise AttributeError("The graph should have an attribute 'pos'.") + + # The graph in torch geometric format. + self.pyg = data.clone() + + # The graph in networkx format, undirected. + self.nx_graph = pyg_utils.to_networkx( + data=data, + node_attrs=["x"], + edge_attrs=["edge_attr"] if data.edge_attr is not None else None, + to_undirected=True, + ) + + def is_disk_graph(self, radius: float) -> bool: + """ + A predicate to check if `self` is a disk graph with the specified + radius, i.e. `self` is a connected graph and, for every pair of nodes + `A` and `B` within `graph`, there exists an edge between `A` and `B` + if and only if the positions of `A` and `B` within `self` are such + that `|AB| <= radius`. + + Args: + radius: The maximal distance between two nodes of `self` + connected be an edge. + + Returns: + `True` if the graph is a disk graph with the specified radius, + `False` otherwise. + """ + + if self.pyg.num_nodes == 0 or self.pyg.num_nodes is None: + return False + + # Check if the graph is connected. + if len(self.nx_graph) == 0 or not nx.is_connected(self.nx_graph): + return False + + # Check the distances between all pairs of nodes. + pos = self.pyg.pos + for u, v in nx.non_edges(self.nx_graph): + distance = np.linalg.norm(np.array(pos[u]) - np.array(pos[v])) + if distance <= radius: + return False + + for u, v in self.nx_graph.edges(): + distance = np.linalg.norm(np.array(pos[u]) - np.array(pos[v])) + if distance > radius: + return False + + return True + + def is_embeddable(self, device: pl.devices.Device) -> bool: + """ + A predicate to check if the graph can be embedded in the + quantum device. + + For a graph to be embeddable on a device, all the following + criteria must be fulfilled: + - the device must have at least as many atoms as the graph has + nodes; + - the device must be physically large enough to place all the + nodes (device.max_radial_distance); + - the nodes must be distant enough that quantum interactions + may take place (device.min_atom_distance) + + Args: + device (pulser.devices.Device): the device + + Returns: + bool: True if possible, False if not + """ + + # Check the number of atoms + if self.pyg.num_nodes > device.max_atom_num: + return False + + # Check the distance from the center + pos_graph = self.pyg.pos + distance_from_center = np.linalg.norm(pos_graph, ord=2, axis=-1) + if any(distance_from_center > device.max_radial_distance): + return False + + # Check the distance between nodes. + nodes = list(self.nx_graph.nodes) + for i in range(0, len(nodes)): + for j in range(i + 1, len(nodes)): + dist = np.linalg.norm(pos_graph[i] - pos_graph[j], ord=2) + if dist < device.min_atom_distance: + return False + + return True + + def compute_register(self) -> pl.Register: + """Create a Quantum Register based on a graph. + + Returns: + pulser.Register: register + """ + return pl.Register.from_coordinates(coords=self.pyg.pos) + + def compute_sequence(self, device: pl.devices.Device) -> pl.Sequence: + """ + Compile a Quantum Sequence from a graph for a specific device. + + Raises: + ValueError if the graph cannot be embedded on the given device. + """ + if not self.is_embeddable(device): + raise ValueError(f"The graph is not compatible with {device}") + reg = self.compute_register() + seq = pl.Sequence(register=reg, device=device) + + # See the companion paper for an explanation on these constants. + Omega_max = 1.0 * 2 * np.pi + t_max = 660 + pulse = pl.Pulse.ConstantAmplitude( + amplitude=Omega_max, + detuning=pl.waveforms.RampWaveform(t_max, 0, 0), + phase=0.0, + ) + seq.declare_channel("ising", "rydberg_global") + seq.add(pulse, "ising") + return seq + + +class MoleculeGraph(BaseGraph): + """ + A graph based on molecular data, being prepared for embedding on a + quantum device. + """ + + # Constants used to decode the PTC-FM dataset, mapping + # integers (used as node attributes) to atom names. + PTCFM_ATOM_NAMES: Final[dict[int, str]] = { + 0: "In", + 1: "P", + 2: "C", + 3: "O", + 4: "N", + 5: "Cl", + 6: "S", + 7: "Br", + 8: "Na", + 9: "F", + 10: "As", + 11: "K", + 12: "Cu", + 13: "I", + 14: "Ba", + 15: "Sn", + 16: "Pb", + 17: "Ca", + } + + # Constants used to decode the PTC-FM dataset, mapping + # integers (used as edge attributes) to bond types. + PTCFM_BOND_TYPES: Final[dict[int, Chem.BondType]] = { + 0: Chem.BondType.TRIPLE, + 1: Chem.BondType.SINGLE, + 2: Chem.BondType.DOUBLE, + 3: Chem.BondType.AROMATIC, + } + + def __init__( + self, + data: pyg_data.Data, + blockade_radius: float, + node_mapping: dict[int, str] = PTCFM_ATOM_NAMES, + edge_mapping: dict[int, Chem.BondType] = PTCFM_BOND_TYPES, + ): + """ + Compute the geometry for a molecule graph. + + Args: + data: A homogeneous graph, in PyTorch Geometric format. Unchanged. + blockade_radius: The radius of the Rydberg Blockade. Two + connected nodes should be at a distance < blockade_radius, + while two disconnected nodes should be at a + distance > blockade_radius. + node_mapping: A mapping of node labels from numbers to strings, + e.g. `5 => "Cl"`. Used when building molecules, e.g. to compute + distances between nodes. + edge_mapping: A mapping of edge labels from number to chemical + bond types, e.g. `2 => DOUBLE`. Used when building molecules, + e.g. to compute distances between nodes. + """ + pyg = data.clone() + pyg.pos = None # Placeholder + super().__init__(pyg) + + # Reconstruct the molecule. + tmp_mol = graph_to_mol( + graph=self.nx_graph, + node_mapping=node_mapping, + edge_mapping=edge_mapping, + ) + + # Extract the geometry. + AllChem.Compute2DCoords(tmp_mol, useRingTemplates=True) + pos = tmp_mol.GetConformer().GetPositions()[..., :2] # Convert to 2D + dist_list = [] + for start, end in self.nx_graph.edges(): + dist_list.append(np.linalg.norm(pos[start] - pos[end])) + norm_factor = np.max(dist_list) + + # Finally, store the geometry. + self.pyg.pos = pos * blockade_radius / norm_factor diff --git a/qek/kernel/__init__.py b/qek/kernel/__init__.py index 1a9a586..93e6f61 100644 --- a/qek/kernel/__init__.py +++ b/qek/kernel/__init__.py @@ -1,3 +1,3 @@ -from .kernel import Kernel +from .kernel import QuantumEvolutionKernel -__all__ = ["Kernel"] +__all__ = ["QuantumEvolutionKernel"] diff --git a/qek/kernel/kernel.py b/qek/kernel/kernel.py index 649f8e4..f09bd74 100644 --- a/qek/kernel/kernel.py +++ b/qek/kernel/kernel.py @@ -9,7 +9,7 @@ from qek.data.dataset import ProcessedData -class Kernel: +class QuantumEvolutionKernel: def __init__(self, mu: float): self.mu = mu diff --git a/qek/utils.py b/qek/utils.py index d3ef5fb..90dca85 100644 --- a/qek/utils.py +++ b/qek/utils.py @@ -3,10 +3,7 @@ import networkx as nx import numpy as np import numpy.typing as npt -import pulser as pl import rdkit.Chem as Chem -import torch_geometric.data as pyg_data -import torch_geometric.utils as pyg_utils def graph_to_mol( @@ -18,7 +15,8 @@ def graph_to_mol( Args: graph (nx.Graph): Networkx graph of a molecule. - mapping (MolMapping): Object containing dicts for edges and nodes attributes. + mapping (MolMapping): Object containing dicts for edges and nodes + attributes. Returns: Chem.Mol: The generated rdkit molecule. @@ -56,67 +54,3 @@ def inverse_one_hot(array: npt.ArrayLike, dim: int) -> np.ndarray: """ tmp_array = np.asarray(array) return np.nonzero(tmp_array == 1.0)[dim] - - -def is_disk_graph(graph: pyg_data.Data, radius: float) -> bool: - """ - Check if `graph` is a disk graph with the specified radius, i.e. - `graph` is a connected graph and, for every pair of nodes `A` and `B` - within `graph`, there exists there exists an edge between `A` and `B` - if and only if the positions of `A` and `B` within `graph` are such - that `|AB| <= radius`. - - Args: - graph: A homogeneous, undirected, graph, in PyTorch - Geometric format. This graph MUST have an attribute `pos`, as - provided e.g. by `datatools.add_graph_coord`. - radius: The maximal distance between two nodes of `graph` - connected be an edge. - - Returns: - `True` if the graph is a disk graph with the specified radius, - `False` otherwise. - """ - - if hasattr(graph, "pos"): - pos = graph.pos - else: - raise AttributeError("Graph object does not have a 'pos' attribute") - - if graph.num_nodes == 0 or graph.num_nodes is None: - return False - - # Molecule are undirected Graphs. - nx_graph = pyg_utils.to_networkx(graph, to_undirected=True) - - # Check if the graph is connected. - if len(nx_graph) == 0 or not nx.is_connected(nx_graph): - return False - - # Check the distances between all pairs of nodes. - for u, v in nx.non_edges(nx_graph): - distance = np.linalg.norm(np.array(pos[u]) - np.array(pos[v])) - if distance <= radius: - return False - - for u, v in nx_graph.edges(): - distance = np.linalg.norm(np.array(pos[u]) - np.array(pos[v])) - if distance > radius: - return False - - return True - - -def compute_register(data_graph: pyg_data.Data) -> pl.Register: - """Create a register based on a graph using pulser. - - Args: - data_graph (pyg_data.Data): graph. It should have a node attribute named UD_pos - - Returns: - pulser.Register: register - """ - if not hasattr(data_graph, "pos"): - raise AttributeError("Graph should have a pos attribute") - position = data_graph.pos - return pl.Register.from_coordinates(coords=position) diff --git a/tests/test_datatools.py b/tests/test_datatools.py index 06ec739..1f7139c 100644 --- a/tests/test_datatools.py +++ b/tests/test_datatools.py @@ -1,14 +1,15 @@ from __future__ import annotations import networkx as nx +import torch import torch_geometric.datasets as pyg_dataset import torch_geometric.utils as pyg_utils +from torch_geometric.data import Data -from qek.data.datatools import add_graph_coord -from qek.utils import is_disk_graph +from qek.data.datatools import BaseGraph, MoleculeGraph -def test_add_graph_coord() -> None: +def test_graph_init() -> None: # Load dataset original_ptcfm_data = pyg_dataset.TUDataset(root="dataset", name="PTC_FM") @@ -16,23 +17,161 @@ def test_add_graph_coord() -> None: RADIUS = 5.001 EPS = 0.01 - for graph in original_ptcfm_data: - augmented_graph = add_graph_coord(graph=graph, blockade_radius=RADIUS) + for data in original_ptcfm_data: + graph = MoleculeGraph(data=data, blockade_radius=RADIUS) # Make sure that the graph has been augmented with "pos". - assert hasattr(augmented_graph, "pos") + assert hasattr(graph.pyg, "pos") # Confirm that the augmented graph is isomorphic to the original graph. nx_graph = pyg_utils.to_networkx( - data=graph, + data=data, node_attrs=["x"], edge_attrs=["edge_attr"], to_undirected=True, ) - nx_reconstruct = pyg_utils.to_networkx(augmented_graph).to_undirected() + nx_reconstruct = pyg_utils.to_networkx(graph.pyg).to_undirected() assert nx.is_isomorphic(nx_graph, nx_reconstruct) # The first graph from the dataset is known to be a disk graph. - augmented_graph = add_graph_coord(graph=original_ptcfm_data[0], blockade_radius=RADIUS) - assert is_disk_graph(augmented_graph, RADIUS + EPS) + graph = MoleculeGraph(data=original_ptcfm_data[0], blockade_radius=RADIUS) + assert graph.is_disk_graph(RADIUS + EPS) + + +def test_is_disk_graph_false() -> None: + """ + Testing is_disk_graph: these graphs are *not* disk graphs + """ + # The empty graph is not a disk graph. + graph_empty = BaseGraph( + Data( + x=torch.tensor([], dtype=torch.float), + edge_index=torch.tensor([], dtype=torch.int), + pos=torch.tensor([], dtype=torch.float), + ) + ) + assert not graph_empty.is_disk_graph(radius=1.0) + + # This graph has three nodes, each pair of nodes is closer than + # the diameter, but it's not a disk graph because one of the nodes + # is not connected. + graph_disconnected_close = BaseGraph( + Data( + x=torch.tensor([[0], [1], [2]], dtype=torch.float), + edge_index=torch.tensor( + [ + [0, 1], # edge 0 -> 1 + [1, 0], # edge 1 -> 0 + ], + dtype=torch.int, + ), + pos=torch.tensor([[0], [1], [2]], dtype=torch.float), + ) + ) + assert not graph_disconnected_close.is_disk_graph(radius=10.0) + + # This graph has three nodes, all nodes are connected, but it's + # not a disk graph because one of the edges is longer than the + # diameter. + graph_connected_far = BaseGraph( + Data( + x=torch.tensor([[0], [1], [2]], dtype=torch.float), + edge_index=torch.tensor( + [ + [ + 0, + 1, # edge 0 -> 1 + 1, + 2, # edge 1 -> 2 + 0, + 2, # edge 0 -> 2 + ], + [ + 1, + 0, # edge 1 -> 0 + 2, + 1, # edge 2 -> 1 + 2, + 0, # edge 2 -> 0 + ], + ], + dtype=torch.int, + ), + pos=torch.tensor([[0], [1], [12]], dtype=torch.float), + ) + ) + assert not graph_connected_far.is_disk_graph(radius=10.0) + + # This graph has three nodes, each pair of nodes is within + # the disk's diameter, but it's not a disk graph because + # one of the pairs does not have an edge. + graph_partially_connected_close = BaseGraph( + Data( + x=torch.tensor([[0], [1], [2]], dtype=torch.float), + edge_index=torch.tensor( + [ + [ + 0, + 1, # edge 0 -> 1 + 1, + 2, # edge 1 -> 2 + ], + [ + 1, + 0, # edge 1 -> 0 + 2, + 1, # edge 2 -> 1 + ], + ], + dtype=torch.int, + ), + pos=torch.tensor([[0], [1], [2]], dtype=torch.float), + ) + ) + assert not graph_partially_connected_close.is_disk_graph(radius=10.0) + + +def test_is_disk_graph_true() -> None: + """ + Testing is_disk_graph: these graphs are disk graphs + """ + # Single node + graph_single_node = BaseGraph( + Data( + x=torch.tensor([0], dtype=torch.float), + edge_index=torch.tensor([]), + ) + ) + assert graph_single_node.is_disk_graph(radius=1.0) + + # A complete graph with three nodes, each of the edges + # is shorter than the disk's diameter. + graph_connected_close = BaseGraph( + Data( + x=torch.tensor([[0], [1], [2]], dtype=torch.float), + edge_index=torch.tensor( + [ + [ + 0, + 1, # edge 0 -> 1 + 1, + 2, # edge 1 -> 2 + 0, + 2, # edge 0 -> 2 + ], + [ + 1, + 0, # edge 1 -> 0 + 2, + 1, # edge 2 -> 1 + 2, + 0, # edge 2 -> 0 + ], + ], + dtype=torch.int, + ), + pos=torch.tensor([[0], [1], [2]], dtype=torch.float), + ) + ) + assert graph_connected_close.is_disk_graph(radius=10.0) diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index babd140..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,128 +0,0 @@ -from __future__ import annotations - -import torch -from torch_geometric.data import Data as Graph - -from qek.utils import is_disk_graph - - -def test_is_disk_graph_false() -> None: - """ - Testing is_disk_graph: these graphs are *not* disk graphs - """ - # The empty graph is not a disk graph. - graph_empty = Graph() - assert not is_disk_graph(graph_empty, radius=1.0) - - # This graph has three nodes, each pair of nodes is closer than - # the diameter, but it's not a disk graph because one of the nodes - # is not connected. - graph_disconnected_close = Graph( - x=torch.tensor([[0], [1], [2]], dtype=torch.float), - edge_index=torch.tensor( - [ - [0, 1], # edge 0 -> 1 - [1, 0], # edge 1 -> 0 - ], - dtype=torch.int, - ), - pos=torch.tensor([[0], [1], [2]], dtype=torch.float), - ) - assert not is_disk_graph(graph_disconnected_close, radius=10.0) - - # This graph has three nodes, all nodes are connected, but it's - # not a disk graph because one of the edges is longer than the - # diameter. - graph_connected_far = Graph( - x=torch.tensor([[0], [1], [2]], dtype=torch.float), - edge_index=torch.tensor( - [ - [ - 0, - 1, # edge 0 -> 1 - 1, - 2, # edge 1 -> 2 - 0, - 2, # edge 0 -> 2 - ], - [ - 1, - 0, # edge 1 -> 0 - 2, - 1, # edge 2 -> 1 - 2, - 0, # edge 2 -> 0 - ], - ], - dtype=torch.int, - ), - pos=torch.tensor([[0], [1], [12]], dtype=torch.float), - ) - assert not is_disk_graph(graph_connected_far, radius=10.0) - - # This graph has three nodes, each pair of nodes is within - # the disk's diameter, but it's not a disk graph because - # one of the pairs does not have an edge. - graph_partially_connected_close = Graph( - x=torch.tensor([[0], [1], [2]], dtype=torch.float), - edge_index=torch.tensor( - [ - [ - 0, - 1, # edge 0 -> 1 - 1, - 2, # edge 1 -> 2 - ], - [ - 1, - 0, # edge 1 -> 0 - 2, - 1, # edge 2 -> 1 - ], - ], - dtype=torch.int, - ), - pos=torch.tensor([[0], [1], [2]], dtype=torch.float), - ) - assert not is_disk_graph(graph_partially_connected_close, radius=10.0) - - -def test_is_disk_graph_true() -> None: - """ - Testing is_disk_graph: these graphs are disk graphs - """ - # Single node - graph_single_node = Graph( - x=torch.tensor([0], dtype=torch.float), - edge_index=torch.tensor([]), - ) - assert is_disk_graph(graph_single_node, radius=1.0) - - # A complete graph with three nodes, each of the edges - # is shorter than the disk's diameter. - graph_connected_close = Graph( - x=torch.tensor([[0], [1], [2]], dtype=torch.float), - edge_index=torch.tensor( - [ - [ - 0, - 1, # edge 0 -> 1 - 1, - 2, # edge 1 -> 2 - 0, - 2, # edge 0 -> 2 - ], - [ - 1, - 0, # edge 1 -> 0 - 2, - 1, # edge 2 -> 1 - 2, - 0, # edge 2 -> 0 - ], - ], - dtype=torch.int, - ), - pos=torch.tensor([[0], [1], [2]], dtype=torch.float), - ) - assert is_disk_graph(graph_connected_close, radius=10.0)