Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusOrsoe committed Sep 21, 2024
2 parents cbc4228 + 9a0e9c5 commit 8270013
Show file tree
Hide file tree
Showing 19 changed files with 901 additions and 79 deletions.
1 change: 1 addition & 0 deletions .github/actions/install/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ runs:
run: |
echo requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }}
pip install -r requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }}
pip install git+https://github.com/thoglu/jammy_flows.git
shell: bash
10 changes: 10 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ jobs:
uses: ./.github/actions/install
with:
editable: true
- name: Print packages in pip
run: |
pip show torch
pip show torch-geometric
pip show torch-cluster
pip show torch-sparse
pip show torch-scatter
pip show jammy_flows
- name: Run unit tests and generate coverage report
run: |
coverage run --source=graphnet -m pytest tests/ --ignore=tests/examples/04_training --ignore=tests/utilities
Expand Down Expand Up @@ -110,6 +118,8 @@ jobs:
pip show torch-sparse
pip show torch-scatter
pip show numpy
- name: Run unit tests and generate coverage report
run: |
set -o pipefail # To propagate exit code from pytest
Expand Down
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,20 @@ repos:
rev: 4.0.1
hooks:
- id: flake8
language_version: python3
- repo: https://github.com/pycqa/docformatter
rev: v1.5.0
hooks:
- id: docformatter
language_version: python3
- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
hooks:
- id: pydocstyle
language_version: python3
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.982
hooks:
- id: mypy
args: [--follow-imports=silent, --disallow-untyped-defs, --disallow-incomplete-defs, --disallow-untyped-calls]
language_version: python3
4 changes: 2 additions & 2 deletions docs/source/datasets/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ After that, you can construct your :code:`Dataset` from a SQLite database with j

.. code-block:: python
from graphnet.data.sqlite import SQLiteDataset
from graphnet.data.dataset.sqlite.sqlite_dataset import SQLiteDataset
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import NodesAsPulses
Expand All @@ -203,7 +203,7 @@ Or similarly for Parquet files:

.. code-block:: python
from graphnet.data.parquet import ParquetDataset
from graphnet.data.dataset.parquet.parquet_dataset import ParquetDataset
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph
from graphnet.models.graphs.nodes import NodesAsPulses
Expand Down
10 changes: 5 additions & 5 deletions docs/source/installation/quick-start.html
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,20 @@
}

if (os == "linux" && cuda != "cpu" && torch != "no_torch"){
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]`);
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`);
}
else if (os == "linux" && cuda == "cpu" && torch != "no_torch"){
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]`);
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`);
}
else if (os == "linux" && cuda == "cpu" && torch == "no_torch"){
$("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[develop]`);
$("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`);
}

if (os == "macos" && cuda == "cpu" && torch != "no_torch"){
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[torch,develop]`);
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`);
}
if (os == "macos" && cuda == "cpu" && torch == "no_torch"){
$("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[develop]`);
$("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`);
}
}

Expand Down
235 changes: 235 additions & 0 deletions examples/04_training/07_train_normalizing_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""Example of training a conditional NormalizingFlow."""

import os
from typing import Any, Dict, List, Optional

from pytorch_lightning.loggers import WandbLogger
import torch
from torch.optim.adam import Adam

from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.gnn import DynEdge
from graphnet.models.graphs import KNNGraph
from graphnet.training.callbacks import PiecewiseLinearLR
from graphnet.training.utils import make_train_validation_dataloader
from graphnet.utilities.argparse import ArgumentParser
from graphnet.utilities.logging import Logger
from graphnet.utilities.imports import has_jammy_flows_package

# Make sure the jammy flows is installed
try:
assert has_jammy_flows_package()
from graphnet.models import NormalizingFlow
except AssertionError:
raise AssertionError(
"This example requires the package`jammy_flow` "
" to be installed. It appears that the package is "
" not installed. Please install the package."
)

# Constants
features = FEATURES.PROMETHEUS
truth = TRUTH.PROMETHEUS


def main(
path: str,
pulsemap: str,
target: str,
truth_table: str,
gpus: Optional[List[int]],
max_epochs: int,
early_stopping_patience: int,
batch_size: int,
num_workers: int,
wandb: bool = False,
) -> None:
"""Run example."""
# Construct Logger
logger = Logger()

# Initialise Weights & Biases (W&B) run
if wandb:
# Make sure W&B output directory exists
wandb_dir = "./wandb/"
os.makedirs(wandb_dir, exist_ok=True)
wandb_logger = WandbLogger(
project="example-script",
entity="graphnet-team",
save_dir=wandb_dir,
log_model=True,
)

logger.info(f"features: {features}")
logger.info(f"truth: {truth}")

# Configuration
config: Dict[str, Any] = {
"path": path,
"pulsemap": pulsemap,
"batch_size": batch_size,
"num_workers": num_workers,
"target": target,
"early_stopping_patience": early_stopping_patience,
"fit": {
"gpus": gpus,
"max_epochs": max_epochs,
},
}

archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
run_name = "dynedge_{}_example".format(config["target"])
if wandb:
# Log configuration to W&B
wandb_logger.experiment.config.update(config)

# Define graph representation
graph_definition = KNNGraph(detector=Prometheus())

(
training_dataloader,
validation_dataloader,
) = make_train_validation_dataloader(
db=config["path"],
graph_definition=graph_definition,
pulsemaps=config["pulsemap"],
features=features,
truth=truth,
batch_size=config["batch_size"],
num_workers=config["num_workers"],
truth_table=truth_table,
selection=None,
)

# Building model

backbone = DynEdge(
nb_inputs=graph_definition.nb_outputs,
global_pooling_schemes=["min", "max", "mean", "sum"],
)

model = NormalizingFlow(
graph_definition=graph_definition,
backbone=backbone,
optimizer_class=Adam,
target_labels=config["target"],
optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
scheduler_class=PiecewiseLinearLR,
scheduler_kwargs={
"milestones": [
0,
len(training_dataloader) / 2,
len(training_dataloader) * config["fit"]["max_epochs"],
],
"factors": [1e-2, 1, 1e-02],
},
scheduler_config={
"interval": "step",
},
)

# Training model
model.fit(
training_dataloader,
validation_dataloader,
early_stopping_patience=config["early_stopping_patience"],
logger=wandb_logger if wandb else None,
**config["fit"],
)

# Get predictions
additional_attributes = model.target_labels
assert isinstance(additional_attributes, list) # mypy

results = model.predict_as_dataframe(
validation_dataloader,
additional_attributes=additional_attributes + ["event_no"],
gpus=config["fit"]["gpus"],
)

# Save predictions and model to file
db_name = path.split("/")[-1].split(".")[0]
path = os.path.join(archive, db_name, run_name)
logger.info(f"Writing results to {path}")
os.makedirs(path, exist_ok=True)

# Save results as .csv
results.to_csv(f"{path}/results.csv")

# Save full model (including weights) to .pth file - not version safe
# Note: Models saved as .pth files in one version of graphnet
# may not be compatible with a different version of graphnet.
model.save(f"{path}/model.pth")

# Save model config and state dict - Version safe save method.
# This method of saving models is the safest way.
model.save_state_dict(f"{path}/state_dict.pth")
model.save_config(f"{path}/model_config.yml")


if __name__ == "__main__":

# Parse command-line arguments
parser = ArgumentParser(
description="""
Train conditional NormalizingFlow without the use of config files.
"""
)

parser.add_argument(
"--path",
help="Path to dataset file (default: %(default)s)",
default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
)

parser.add_argument(
"--pulsemap",
help="Name of pulsemap to use (default: %(default)s)",
default="total",
)

parser.add_argument(
"--target",
help=(
"Name of feature to use as regression target (default: "
"%(default)s)"
),
default="total_energy",
)

parser.add_argument(
"--truth-table",
help="Name of truth table to be used (default: %(default)s)",
default="mc_truth",
)

parser.with_standard_arguments(
"gpus",
("max-epochs", 1),
"early-stopping-patience",
("batch-size", 50),
"num-workers",
)

parser.add_argument(
"--wandb",
action="store_true",
help="If True, Weights & Biases are used to track the experiment.",
)

args, unknown = parser.parse_known_args()

main(
args.path,
args.pulsemap,
args.target,
args.truth_table,
args.gpus,
args.max_epochs,
args.early_stopping_patience,
args.batch_size,
args.num_workers,
args.wandb,
)
13 changes: 9 additions & 4 deletions src/graphnet/data/dataconverter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains `DataConverter`."""

from typing import List, Union, OrderedDict, Dict, Tuple, Any, Optional, Type
from abc import ABC

Expand Down Expand Up @@ -260,8 +261,8 @@ def _request_event_nos(self, n_ids: int) -> List[int]:
event_nos = np.arange(start_idx, start_idx + n_ids, 1).tolist()
global_index.value += n_ids # type: ignore[name-defined]
else:
starting_index = self._index
event_nos = np.arange(starting_index, starting_index + n_ids, 1).tolist()
start_idx = self._index
event_nos = np.arange(start_idx, start_idx + n_ids, 1).tolist()
self._index += n_ids

return event_nos
Expand Down Expand Up @@ -316,7 +317,9 @@ def _update_shared_variables(
self._output_files.extend(list(sorted(output_files[:])))

@final
def merge_files(self, files: Optional[List[str]] = None, **kwargs: Any) -> None:
def merge_files(
self, files: Optional[Union[List[str], str]] = None, **kwargs: Any
) -> None:
"""Merge converted files.
`DataConverter` will call the `.merge_files` method in the
Expand All @@ -332,7 +335,9 @@ def merge_files(self, files: Optional[List[str]] = None, **kwargs: Any) -> None:
elif files is not None:
# Proceed to merge specified by user.
if isinstance(files, str):
files = [files] # Cast to list if user forgot
# We shouldn't merge a single file?
self.info(f"Got just a single file {files}. Merging skipped.")
return
files_to_merge = files
else:
# Raise error
Expand Down
Loading

0 comments on commit 8270013

Please sign in to comment.