Skip to content

Commit

Permalink
[Cleanup] Gathering graph utilities around BasicGraph/MoleculeGraph
Browse files Browse the repository at this point in the history
Before this change, we have many graph-related functions with various
pre-requisites on graphs, plus lots of conversions from pyg to networkx.

This change:

- introduces two classes BasicGraph and MoleculeGraph, depending on
	which invariants hold on the graph;
- attaches the relevant methods to BasicGraph/MoleculeGraph instead of
	pyg;
- makes sure that we have only one single conversion from pyg to networkx.
  • Loading branch information
Yoric committed Jan 7, 2025
1 parent d0482ac commit 03af417
Show file tree
Hide file tree
Showing 5 changed files with 342 additions and 350 deletions.
52 changes: 22 additions & 30 deletions examples/pipeline.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,11 @@
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'torch_geometric'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch_geometric\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpyg_dataset\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch_geometric\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpyg_data\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mqek\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdataset\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mqek_dataset\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_geometric'"
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_782556/4114388664.py:9: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
" from tqdm.autonotebook import tqdm\n"
]
}
],
Expand Down Expand Up @@ -67,17 +64,9 @@
"metadata": {},
"outputs": [],
"source": [
"import qek.data.datatools as qek_datatools\n",
"from qek.utils import compute_register, is_disk_graph"
"import qek.data.datatools as qek_datatools"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -101,7 +90,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "093a63f828d44786afb3c91f8c9dd78f",
"model_id": "8dfa40146d2a40b4b26b378debf4d231",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -114,13 +103,13 @@
}
],
"source": [
"list_of_graph = []\n",
"list_of_graphs = []\n",
"RADIUS = 5.001\n",
"EPS = 0.01\n",
"for graph in tqdm(og_ptcfm):\n",
" graph_with_pos = qek_datatools.add_graph_coord(graph=graph, blockade_radius=RADIUS)\n",
" if is_disk_graph(graph_with_pos, radius=RADIUS+EPS):\n",
" list_of_graph.append((graph_with_pos, graph.y.item()))"
"for data in tqdm(og_ptcfm):\n",
" graph = qek_datatools.MoleculeGraph(data=data, blockade_radius=RADIUS)\n",
" if graph.is_disk_graph(radius=RADIUS+EPS):\n",
" list_of_graphs.append((graph, data.y.item()))"
]
},
{
Expand All @@ -143,9 +132,9 @@
"# Create a sequence:\n",
"\n",
"def create_sequence_from_graph(graph:pyg_data.Data, device=pl.devices.Device)-> pl.Sequence:\n",
" if not qek_datatools.check_compatibility_graph_device(graph, device):\n",
" if not graph.is_embeddable(device):\n",
" raise ValueError(f\"The graph is not compatible with {device}\")\n",
" reg = compute_register(data_graph=graph)\n",
" reg = graph.compute_register()\n",
" seq = pl.Sequence(register=reg, device=device)\n",
" Omega_max = 1.0 * 2 * np.pi\n",
" t_max = 660\n",
Expand All @@ -161,7 +150,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand All @@ -182,19 +171,22 @@
"The graph is not compatible with AnalogDevice\n",
"The graph is not compatible with AnalogDevice\n",
"The graph is not compatible with AnalogDevice\n",
"The graph is not compatible with AnalogDevice\n"
"The graph is not compatible with AnalogDevice\n",
"We may embed 279/294 graphs\n"
]
}
],
"source": [
"dataset_sequence = []\n",
"\n",
"for graph, target in list_of_graph:\n",
" # Some graph are not compatible with the AnalogDevice device\n",
"for graph, target in list_of_graphs:\n",
" # Not all graphs are compatible with AnalogDevice\n",
" try:\n",
" dataset_sequence.append((create_sequence_from_graph(graph, device=pl.AnalogDevice), target))\n",
" except ValueError as err:\n",
" print(f\"{err}\")"
" print(f\"{err}\")\n",
"\n",
"print(f\"We may embed {len(dataset_sequence)}/{len(list_of_graphs)} graphs\")"
]
},
{
Expand Down
Loading

0 comments on commit 03af417

Please sign in to comment.