Skip to content

Commit

Permalink
python3Packages.torcheval: init at 0.0.6
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin Sparks committed Feb 3, 2025
1 parent 45d9208 commit 801454c
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
119 changes: 119 additions & 0 deletions pkgs/development/python-modules/torcheval/default.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
{
lib,
buildPythonPackage,
fetchFromGitHub,
fetchPypi,

# build-system
setuptools,

# dependencies
torchtnt-nightly,
typing-extensions,

# tests
numpy,
pytestCheckHook,
pytest-timeout,
scikit-learn,
torchvision,
cython_0,
}:
let
pname = "torcheval";
version = "0.0.6";

# The torcheval 0.0.6 lib depends on a torchtnt>=0.0.5, however the available versions
# of torchtnt on nixpkgs (0.4.2 at the time of writing) are not compatible due to missing methods.
#
# To remedy this, the torchtnt-nightly release on PyPI that was published on the same day as
# the torcheval 0.0.6 lib is used in lieu of the torchtnt lib available on nixpkgs,
# as it is doubtful that this particular release of torchtnt is useful outside of this package.
torchtnt = torchtnt-nightly.overridePythonAttrs rec {
pname = "torchtnt-nightly";
version = "2023.1.25";
src = fetchPypi {
inherit pname version;
hash = "sha256-eFouZgdVQaqrXIKLY7geLV8GEVPGu6IHKIPK6aLJH8w=";
};
};
in
buildPythonPackage {
inherit pname version;
pyproject = true;

src = fetchFromGitHub {
owner = "pytorch";
repo = "torcheval";
tag = version;
hash = "sha256-FnMSPU8tjXegLH4speeyD8UDrKSvjf8STftt7aXTuJI=";
};

# Patches are only applied to usages of sklearn and numpy within tests,
# which are only used for testing purposes (see dev-requirements.txt)
postPatch =
# sklearn's confusion matrix's `normalize` keyword argument does not support "none".
# However, None and "none" appear twice in this test; The only missing case is "all".
''
substituteInPlace tests/metrics/functional/classification/test_confusion_matrix.py \
--replace-fail 'input, target, num_classes, normalize="none"' 'input, target, num_classes, normalize="all"'
''
# sklearn's mean squared error requires naming `sample_weight` due to the asterisk in
# mean_squared_error(y_true, y_pred, *, sample_weight=None, ...)
# ^^^
+ ''
substituteInPlace tests/metrics/window/test_mean_squared_error.py \
--replace-fail "torch.cat(update_weight[-2:], dim=0)," "sample_weight=torch.cat(update_weight[-2:], dim=0)," \
--replace-fail "torch.cat(update_weight, dim=0)," "sample_weight=torch.cat(update_weight, dim=0),"
''
# numpy's `np.NAN` was changed to `np.nan` when numpy 2 was released
+ ''
substituteInPlace tests/metrics/classification/test_accuracy.py tests/metrics/functional/classification/test_accuracy.py \
--replace-fail "np.NAN" "np.nan"
'';

build-system = [ setuptools ];

dependencies = [
torchtnt
typing-extensions
];

pythonImportsCheck = [ "torcheval" ];

nativeCheckInputs = [
pytestCheckHook
numpy
torchvision
pytest-timeout
cython_0
scikit-learn
];

pytestFlagsArray = [
"tests/"

# -- tests/tools/test_module_summary.py --
# models.alexnet(pretrained=True) -> PermissionError: [Errno 13] Permission denied: '/homeless-shelter'
# Touch filesystem and require network access.
"--deselect=tests/tools/test_module_summary.py::ModuleSummaryTest::test_alexnet_print"
"--deselect=tests/tools/test_module_summary.py::ModuleSummaryTest::test_alexnet_with_input_tensor"
"--deselect=tests/tools/test_module_summary.py::ModuleSummaryTest::test_forward_elapsed_time"
"--deselect=tests/tools/test_module_summary.py::ModuleSummaryTest::test_resnet_max_depth"

# -- tests/metrics/functional/text/test_perplexity.py --
# AssertionError: Scalars are not close!
# Expected 3.537154912949 but got 3.53715443611145
"--deselect=tests/metrics/functional/text/test_perplexity.py::Perplexity::test_perplexity_with_ignore_index"
];

meta = {
description = "Rich collection of performant PyTorch model metrics and tools for PyTorch model evaluations";
homepage = "https://pytorch.org/torcheval";
changelog = "https://github.com/pytorch/torcheval/releases/tag/${version}";

platforms = lib.platforms.linux;
license = with lib.licenses; [ bsd3 ];
maintainers = with lib.maintainers; [ bengsparks ];
};
}
2 changes: 2 additions & 0 deletions pkgs/top-level/python-packages.nix
Original file line number Diff line number Diff line change
Expand Up @@ -16438,6 +16438,8 @@ self: super: with self; {

torchdiffeq = callPackage ../development/python-modules/torchdiffeq { };

torcheval = callPackage ../development/python-modules/torcheval { };

torchmetrics = callPackage ../development/python-modules/torchmetrics { };

torchio = callPackage ../development/python-modules/torchio { };
Expand Down

0 comments on commit 801454c

Please sign in to comment.