diff --git a/hierarchicalsoftmax/dotexporter.py b/hierarchicalsoftmax/dotexporter.py index 701114c..44d96c0 100644 --- a/hierarchicalsoftmax/dotexporter.py +++ b/hierarchicalsoftmax/dotexporter.py @@ -55,7 +55,7 @@ def _default_edgeattrfunc( def exclude_node(self, node): return not node.is_root and node not in self.greedy_nodes and self.probabilities[node.index_in_softmax_layer] < self.threshold - def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc): + def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc, *args, **kwargs): for node in PreOrderIter(self.node, maxlevel=self.maxlevel): if self.exclude_node(node): continue @@ -64,7 +64,7 @@ def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc): nodeattr = " [%s]" % nodeattr if nodeattr is not None else "" yield '%s"%s"%s;' % (indent, DotExporter.esc(nodename), nodeattr) - def _DotExporter__iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc): + def _DotExporter__iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc, *args, **kwargs): maxlevel = self.maxlevel - 1 if self.maxlevel else None for node in PreOrderIter(self.node, maxlevel=maxlevel): nodename = nodenamefunc(node) diff --git a/hierarchicalsoftmax/inference.py b/hierarchicalsoftmax/inference.py index cb15b79..83a0839 100644 --- a/hierarchicalsoftmax/inference.py +++ b/hierarchicalsoftmax/inference.py @@ -2,7 +2,6 @@ import torch from pathlib import Path from anytree import PreOrderIter -from functools import partial from . import nodes from .dotexporter import ThresholdDotExporter @@ -52,7 +51,6 @@ def leaf_probabilities(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode) - return torch.index_select(probabilities, 1, root.leaf_indexes_in_softmax_layer) - def greedy_predictions(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:Optional[int]=None, threshold:Optional[float]=None) -> List[nodes.SoftmaxNode]: """ Takes the prediction scores for a number of samples and converts it to a list of predictions of nodes in the tree. diff --git a/hierarchicalsoftmax/layers.py b/hierarchicalsoftmax/layers.py index bc38b0c..caa1518 100644 --- a/hierarchicalsoftmax/layers.py +++ b/hierarchicalsoftmax/layers.py @@ -1,6 +1,7 @@ from torch import nn from .nodes import SoftmaxNode +from .tensors import LazyLinearTensor class HierarchicalSoftmaxLayerError(RuntimeError): pass @@ -17,6 +18,9 @@ def __init__(self, root:SoftmaxNode, out_features=None, **kwargs): super().__init__(out_features=self.root.layer_size, **kwargs) + def forward(self, x) -> LazyLinearTensor: + return LazyLinearTensor(x, weight=self.weight, bias=self.bias) + class HierarchicalSoftmaxLinear(HierarchicalSoftmaxLayerMixin, nn.Linear): """ diff --git a/hierarchicalsoftmax/metrics.py b/hierarchicalsoftmax/metrics.py index 4b9ec79..8e1689f 100644 --- a/hierarchicalsoftmax/metrics.py +++ b/hierarchicalsoftmax/metrics.py @@ -137,3 +137,19 @@ def greedy_accuracy_parent(prediction_tensor, target_tensor, root, max_depth=Non return (prediction_parent_ids.to(target_parent_ids.device) == target_parent_ids).float().mean() +class GreedyAccuracy(): + name:str = "greedy" + + def __init__(self, root:nodes.SoftmaxNode, name="greedy_accuracy", max_depth=None): + self.max_depth = max_depth + self.name = name + self.root = root + + @property + def __name__(self): + """ For using as a FastAI metric. """ + return self.name + + def __call__(self, predictions, targets): + return greedy_accuracy(predictions, targets, self.root, max_depth=self.max_depth) + diff --git a/hierarchicalsoftmax/tensors.py b/hierarchicalsoftmax/tensors.py new file mode 100644 index 0000000..5ae41e5 --- /dev/null +++ b/hierarchicalsoftmax/tensors.py @@ -0,0 +1,100 @@ +from torch import Tensor, Size +from functools import cached_property +import torch.nn.functional as F +from torch.nn.parameter import Parameter + + +class LazyLinearTensor(Tensor): + """ + A tensor that is designed to be used with HierarchicalSoftmaxLazyLinear layers. + """ + @staticmethod + def __new__(cls, x, weight:Parameter, bias:Parameter, *args, **kwargs): + return super().__new__(cls, x, *args, **kwargs) + + def __init__(self, x:Tensor, weight:Parameter, bias:Parameter, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input = x + self.weight = weight + self.bias = bias + + @cached_property + def result(self): + return F.linear(self.input, self.weight, self.bias) + + def __add__(self, other): + return self.result + other + + def __sub__(self, other): + return self.result - other + + def __mul__(self, other): + return self.result * other + + def __truediv__(self, other): + return self.result / other + + def __matmul__(self, other): + return self.result @ other + + def __radd__(self, other): + return other + self.result + + def __rsub__(self, other): + return other - self.result + + def __rmul__(self, other): + return other * self.result + + def __rtruediv__(self, other): + return other / self.result + + def __rmatmul__(self, other): + return other @ self.result + + def __getitem__(self, index): + assert isinstance(index, int) or isinstance(index, slice) or isinstance(index, tuple) + if not isinstance(index, tuple) or isinstance(index, slice): + index = (index,) + + my_shape = self.shape + if len(index) < len(my_shape): + return LazyLinearTensor(self.input[index], weight=self.weight, bias=self.bias) + if len(index) > len(my_shape): + raise IndexError(f"Cannot get index '{index}' for LazyLinearTensor of shape {len(my_shape)}") + + input = self.input[index[:-1]] + weight = self.weight[index[-1]] + bias = self.bias[index[-1]] + return F.linear(input, weight, bias) + + @property + def shape(self) -> Size: + return Size( self.input.shape[:-1] + (self.weight.shape[0],) ) + + def __str__(self) -> str: + return f"LazyLinearTensor (shape={tuple(self.shape)})" + + def __repr__(self) -> str: + return str(self) + + def __len__(self) -> int: + return self.shape[0] + + def __iter__(self): + for i in range(len(self)): + yield self[i] + + def float(self): + x = super().float() + x.input = self.input.float() + x.weight = self.weight.float() + x.bias = self.bias.float() + return x + + def half(self): + x = super().half() + x.input = self.input.half() + x.weight = self.weight.half() + x.bias = self.bias.half() + return x \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 1452792..e040511 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,10 +1,9 @@ -# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "alabaster" version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -16,7 +15,6 @@ files = [ name = "anytree" version = "2.8.0" description = "Powerful and Lightweight Python Tree Data Structure.." -category = "main" optional = false python-versions = "*" files = [ @@ -35,7 +33,6 @@ test = ["coverage"] name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" -category = "dev" optional = false python-versions = "*" files = [ @@ -47,7 +44,6 @@ files = [ name = "asttokens" version = "2.2.1" description = "Annotate AST trees with source code positions" -category = "dev" optional = false python-versions = "*" files = [ @@ -65,7 +61,6 @@ test = ["astroid", "pytest"] name = "atomicwrites" version = "1.4.1" description = "Atomic file writes." -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -76,7 +71,6 @@ files = [ name = "attrs" version = "21.4.0" description = "Classes Without Boilerplate" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -94,7 +88,6 @@ tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy" name = "autopep8" version = "1.7.0" description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" -category = "dev" optional = false python-versions = "*" files = [ @@ -110,7 +103,6 @@ toml = "*" name = "babel" version = "2.12.1" description = "Internationalization utilities" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -125,7 +117,6 @@ pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""} name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" -category = "dev" optional = false python-versions = "*" files = [ @@ -137,7 +128,6 @@ files = [ name = "beautifulsoup4" version = "4.12.2" description = "Screen-scraping library" -category = "dev" optional = false python-versions = ">=3.6.0" files = [ @@ -156,7 +146,6 @@ lxml = ["lxml"] name = "black" version = "21.12b0" description = "The uncompromising code formatter." -category = "dev" optional = false python-versions = ">=3.6.2" files = [ @@ -171,8 +160,8 @@ pathspec = ">=0.9.0,<1" platformdirs = ">=2" tomli = ">=0.2.6,<2.0.0" typing-extensions = [ - {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}, {version = ">=3.10.0.0,<3.10.0.1 || >3.10.0.1", markers = "python_version >= \"3.10\""}, + {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}, ] [package.extras] @@ -186,7 +175,6 @@ uvloop = ["uvloop (>=0.15.2)"] name = "bleach" version = "6.0.0" description = "An easy safelist-based HTML-sanitizing tool." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -205,7 +193,6 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] name = "certifi" version = "2023.5.7" description = "Python package for providing Mozilla's CA Bundle." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -217,7 +204,6 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." -category = "dev" optional = false python-versions = "*" files = [ @@ -294,7 +280,6 @@ pycparser = "*" name = "cfgv" version = "3.3.1" description = "Validate configuration and produce human readable error messages." -category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -306,7 +291,6 @@ files = [ name = "charset-normalizer" version = "3.1.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -391,7 +375,6 @@ files = [ name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -406,7 +389,6 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -418,7 +400,6 @@ files = [ name = "comm" version = "0.1.3" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -438,7 +419,6 @@ typing = ["mypy (>=0.990)"] name = "commonmark" version = "0.9.1" description = "Python parser for the CommonMark Markdown spec" -category = "main" optional = false python-versions = "*" files = [ @@ -453,7 +433,6 @@ test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] name = "coverage" version = "5.5" description = "Code coverage measurement for Python" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" files = [ @@ -518,7 +497,6 @@ toml = ["toml"] name = "debugpy" version = "1.6.7" description = "An implementation of the Debug Adapter Protocol for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -546,7 +524,6 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -558,7 +535,6 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -570,7 +546,6 @@ files = [ name = "distlib" version = "0.3.6" description = "Distribution utilities" -category = "dev" optional = false python-versions = "*" files = [ @@ -582,7 +557,6 @@ files = [ name = "docutils" version = "0.17.1" description = "Docutils -- Python Documentation Utilities" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -594,7 +568,6 @@ files = [ name = "executing" version = "1.2.0" description = "Get the currently executing AST node of a frame, and other information" -category = "dev" optional = false python-versions = "*" files = [ @@ -609,7 +582,6 @@ tests = ["asttokens", "littleutils", "pytest", "rich"] name = "fastjsonschema" version = "2.17.1" description = "Fastest Python implementation of JSON schema" -category = "dev" optional = false python-versions = "*" files = [ @@ -624,7 +596,6 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.12.0" description = "A platform independent file lock." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -640,7 +611,6 @@ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "p name = "identify" version = "2.5.24" description = "File identification library for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -655,7 +625,6 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -667,7 +636,6 @@ files = [ name = "imagesize" version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -679,7 +647,6 @@ files = [ name = "importlib-metadata" version = "6.6.0" description = "Read metadata from Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -699,7 +666,6 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "importlib-resources" version = "5.12.0" description = "Read resources from Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -718,7 +684,6 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -730,7 +695,6 @@ files = [ name = "ipykernel" version = "6.23.1" description = "IPython Kernel for Jupyter" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -744,7 +708,7 @@ comm = ">=0.1.1" debugpy = ">=1.6.5" ipython = ">=7.23.1" jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" matplotlib-inline = ">=0.1" nest-asyncio = "*" packaging = "*" @@ -764,7 +728,6 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" name = "ipython" version = "8.12.2" description = "IPython: Productive Interactive Computing" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -804,7 +767,6 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pa name = "jedi" version = "0.18.2" description = "An autocompletion tool for Python that can be used for text editors." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -824,7 +786,6 @@ testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -842,7 +803,6 @@ i18n = ["Babel (>=2.7)"] name = "joblib" version = "1.2.0" description = "Lightweight pipelining with Python functions" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -854,7 +814,6 @@ files = [ name = "jsonschema" version = "4.17.3" description = "An implementation of JSON Schema validation for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -876,7 +835,6 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jupyter-client" version = "8.2.0" description = "Jupyter protocol implementation and client libraries" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -886,7 +844,7 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" tornado = ">=6.2" @@ -900,7 +858,6 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-core" version = "5.3.0" description = "Jupyter core package. A base package on which Jupyter projects rely." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -921,7 +878,6 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -933,7 +889,6 @@ files = [ name = "livereload" version = "2.6.3" description = "Python LiveReload is an awesome tool for web developers" -category = "dev" optional = false python-versions = "*" files = [ @@ -949,7 +904,6 @@ tornado = {version = "*", markers = "python_version > \"2.7\""} name = "markdown-it-py" version = "1.1.0" description = "Python port of markdown-it. Markdown parsing, done right!" -category = "dev" optional = false python-versions = "~=3.6" files = [ @@ -972,7 +926,6 @@ testing = ["coverage", "psutil", "pytest (>=3.6,<4)", "pytest-benchmark (>=3.2,< name = "markupsafe" version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -996,6 +949,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -1032,7 +995,6 @@ files = [ name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1047,7 +1009,6 @@ traitlets = "*" name = "mdit-py-plugins" version = "0.2.8" description = "Collection of plugins for markdown-it-py" -category = "dev" optional = false python-versions = "~=3.6" files = [ @@ -1067,7 +1028,6 @@ testing = ["coverage", "pytest (>=3.6,<4)", "pytest-cov", "pytest-regressions"] name = "mistune" version = "2.0.5" description = "A sane Markdown parser with useful plugins and renderers" -category = "dev" optional = false python-versions = "*" files = [ @@ -1079,7 +1039,6 @@ files = [ name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1091,7 +1050,6 @@ files = [ name = "myst-parser" version = "0.15.2" description = "An extended commonmark compliant parser, with bridges to docutils & sphinx." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1117,7 +1075,6 @@ testing = ["beautifulsoup4", "coverage", "docutils (>=0.17.0,<0.18.0)", "pytest name = "nbclient" version = "0.8.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." -category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -1127,7 +1084,7 @@ files = [ [package.dependencies] jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" nbformat = ">=5.1" traitlets = ">=5.4" @@ -1140,7 +1097,6 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= name = "nbconvert" version = "7.4.0" description = "Converting Jupyter Notebooks" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1179,7 +1135,6 @@ webpdf = ["pyppeteer (>=1,<1.1)"] name = "nbformat" version = "5.9.0" description = "The Jupyter Notebook format" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1201,7 +1156,6 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] name = "nbsphinx" version = "0.8.12" description = "Jupyter Notebook Tools for Sphinx" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1221,7 +1175,6 @@ traitlets = ">=5" name = "nest-asyncio" version = "1.5.6" description = "Patch asyncio to allow nested event loops" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1233,7 +1186,6 @@ files = [ name = "nodeenv" version = "1.8.0" description = "Node.js virtual environment builder" -category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ @@ -1248,7 +1200,6 @@ setuptools = "*" name = "numpy" version = "1.24.3" description = "Fundamental package for array computing in Python" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1286,7 +1237,6 @@ files = [ name = "nvidia-cublas-cu11" version = "11.10.3.66" description = "CUBLAS native runtime libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -1302,7 +1252,6 @@ wheel = "*" name = "nvidia-cuda-nvrtc-cu11" version = "11.7.99" description = "NVRTC native runtime libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -1319,7 +1268,6 @@ wheel = "*" name = "nvidia-cuda-runtime-cu11" version = "11.7.99" description = "CUDA Runtime native Libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -1335,7 +1283,6 @@ wheel = "*" name = "nvidia-cudnn-cu11" version = "8.5.0.96" description = "cuDNN runtime libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -1351,7 +1298,6 @@ wheel = "*" name = "packaging" version = "23.1" description = "Core utilities for Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1363,7 +1309,6 @@ files = [ name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1375,7 +1320,6 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1391,7 +1335,6 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pathspec" version = "0.11.1" description = "Utility library for gitignore style pattern matching of file paths." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1403,7 +1346,6 @@ files = [ name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." -category = "dev" optional = false python-versions = "*" files = [ @@ -1418,7 +1360,6 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" -category = "dev" optional = false python-versions = "*" files = [ @@ -1430,7 +1371,6 @@ files = [ name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1442,7 +1382,6 @@ files = [ name = "platformdirs" version = "3.5.1" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1458,7 +1397,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest- name = "pluggy" version = "1.0.0" description = "plugin and hook calling mechanisms for python" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1474,7 +1412,6 @@ testing = ["pytest", "pytest-benchmark"] name = "pre-commit" version = "2.21.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1493,7 +1430,6 @@ virtualenv = ">=20.10.0" name = "prompt-toolkit" version = "3.0.38" description = "Library for building powerful interactive command lines in Python" -category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -1508,7 +1444,6 @@ wcwidth = "*" name = "psutil" version = "5.9.5" description = "Cross-platform lib for process and system monitoring in Python." -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1535,7 +1470,6 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" -category = "dev" optional = false python-versions = "*" files = [ @@ -1547,7 +1481,6 @@ files = [ name = "pure-eval" version = "0.2.2" description = "Safely evaluate AST nodes without side effects" -category = "dev" optional = false python-versions = "*" files = [ @@ -1562,7 +1495,6 @@ tests = ["pytest"] name = "py" version = "1.11.0" description = "library with cross-python path, ini-parsing, io, code, log facilities" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -1574,7 +1506,6 @@ files = [ name = "pycodestyle" version = "2.10.0" description = "Python style guide checker" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1586,7 +1517,6 @@ files = [ name = "pycparser" version = "2.21" description = "C parser in Python" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1598,7 +1528,6 @@ files = [ name = "pygments" version = "2.15.1" description = "Pygments is a syntax highlighting package written in Python." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1613,7 +1542,6 @@ plugins = ["importlib-metadata"] name = "pyrsistent" version = "0.19.3" description = "Persistent/Functional/Immutable data structures" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1650,7 +1578,6 @@ files = [ name = "pytest" version = "6.2.5" description = "pytest: simple powerful testing with Python" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1675,7 +1602,6 @@ testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xm name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -1690,7 +1616,6 @@ six = ">=1.5" name = "pytz" version = "2023.3" description = "World timezone definitions, modern and historical" -category = "dev" optional = false python-versions = "*" files = [ @@ -1702,7 +1627,6 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" -category = "dev" optional = false python-versions = "*" files = [ @@ -1726,7 +1650,6 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1776,7 +1699,6 @@ files = [ name = "pyzmq" version = "25.1.0" description = "Python bindings for 0MQ" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1866,7 +1788,6 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "requests" version = "2.31.0" description = "Python HTTP for Humans." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1888,7 +1809,6 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "rich" version = "10.16.2" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -category = "main" optional = false python-versions = ">=3.6.2,<4.0.0" files = [ @@ -1908,7 +1828,6 @@ jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] name = "scikit-learn" version = "1.2.2" description = "A set of python modules for machine learning and data mining" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1951,7 +1870,6 @@ tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy ( name = "scipy" version = "1.9.3" description = "Fundamental algorithms for scientific computing in Python" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1990,7 +1908,6 @@ test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "sciki name = "setuptools" version = "67.8.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2007,7 +1924,6 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -2019,7 +1935,6 @@ files = [ name = "snowballstemmer" version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." -category = "dev" optional = false python-versions = "*" files = [ @@ -2031,7 +1946,6 @@ files = [ name = "soupsieve" version = "2.4.1" description = "A modern CSS selector implementation for Beautiful Soup." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2043,7 +1957,6 @@ files = [ name = "sphinx" version = "4.5.0" description = "Python documentation generator" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2079,7 +1992,6 @@ test = ["cython", "html5lib", "pytest", "pytest-cov", "typed-ast"] name = "sphinx-autobuild" version = "2021.3.14" description = "Rebuild Sphinx documentation on changes, with live-reload in the browser." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2099,7 +2011,6 @@ test = ["pytest", "pytest-cov"] name = "sphinx-click" version = "3.1.1.dev4" description = "Sphinx extension that automatically documents click applications" -category = "dev" optional = false python-versions = ">=3.6" files = [] @@ -2121,7 +2032,6 @@ resolved_reference = "ec700fc5864f42cfb71f143f0ec077c89e1102eb" name = "sphinx-copybutton" version = "0.4.0" description = "Add a copy button to each of your code cells." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2140,7 +2050,6 @@ rtd = ["ipython", "sphinx", "sphinx-book-theme"] name = "sphinx-rtd-theme" version = "1.2.2" description = "Read the Docs theme for Sphinx" -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -2160,7 +2069,6 @@ dev = ["bump2version", "sphinxcontrib-httpdomain", "transifex-client", "wheel"] name = "sphinxcontrib-applehelp" version = "1.0.4" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2176,7 +2084,6 @@ test = ["pytest"] name = "sphinxcontrib-devhelp" version = "1.0.2" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2192,7 +2099,6 @@ test = ["pytest"] name = "sphinxcontrib-htmlhelp" version = "2.0.1" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2208,7 +2114,6 @@ test = ["html5lib", "pytest"] name = "sphinxcontrib-jquery" version = "4.1" description = "Extension to include jQuery on newer Sphinx releases" -category = "dev" optional = false python-versions = ">=2.7" files = [ @@ -2223,7 +2128,6 @@ Sphinx = ">=1.8" name = "sphinxcontrib-jsmath" version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2238,7 +2142,6 @@ test = ["flake8", "mypy", "pytest"] name = "sphinxcontrib-qthelp" version = "1.0.3" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2254,7 +2157,6 @@ test = ["pytest"] name = "sphinxcontrib-serializinghtml" version = "1.1.5" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2270,7 +2172,6 @@ test = ["pytest"] name = "stack-data" version = "0.6.2" description = "Extract data from python stack frames and tracebacks for informative displays" -category = "dev" optional = false python-versions = "*" files = [ @@ -2290,7 +2191,6 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "threadpoolctl" version = "3.1.0" description = "threadpoolctl" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2302,7 +2202,6 @@ files = [ name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2321,7 +2220,6 @@ test = ["flake8", "isort", "pytest"] name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" -category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -2333,7 +2231,6 @@ files = [ name = "tomli" version = "1.2.3" description = "A lil' TOML parser" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2345,7 +2242,6 @@ files = [ name = "torch" version = "1.13.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -2386,7 +2282,6 @@ opt-einsum = ["opt-einsum (>=3.3)"] name = "tornado" version = "6.3.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." -category = "dev" optional = false python-versions = ">= 3.8" files = [ @@ -2407,7 +2302,6 @@ files = [ name = "traitlets" version = "5.9.0" description = "Traitlets Python configuration system" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2423,7 +2317,6 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] name = "typer" version = "0.9.0" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2445,7 +2338,6 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6. name = "typing-extensions" version = "4.6.3" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2457,7 +2349,6 @@ files = [ name = "urllib3" version = "2.0.3" description = "HTTP library with thread-safe connection pooling, file post, and more." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2475,7 +2366,6 @@ zstd = ["zstandard (>=0.18.0)"] name = "virtualenv" version = "20.23.0" description = "Virtual Python Environment builder" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2496,7 +2386,6 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "coverage-enable-subprocess name = "wcwidth" version = "0.2.6" description = "Measures the displayed width of unicode strings in a terminal" -category = "dev" optional = false python-versions = "*" files = [ @@ -2508,7 +2397,6 @@ files = [ name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" -category = "dev" optional = false python-versions = "*" files = [ @@ -2520,7 +2408,6 @@ files = [ name = "wheel" version = "0.40.0" description = "A built-package format for Python" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2535,7 +2422,6 @@ test = ["pytest (>=6.0.0)"] name = "zipp" version = "3.15.0" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2550,4 +2436,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "2e038764d33f447f1d07a7eaead4c012234dc0c42d7ca196a141942acf0f9bf0" +content-hash = "97397878342455c36eac58f7f3f8f4ddf50a1f79f30ec130e49eb32a78f03a4a" diff --git a/pyproject.toml b/pyproject.toml index 81fb793..de30692 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "hierarchicalsoftmax" -version = "1.0.3" +version = "1.1.0" description = "A Hierarchical Softmax Framework for PyTorch." authors = ["Robert Turnbull "] license = "Apache-2.0" diff --git a/tests/test_layers.py b/tests/test_layers.py index d45f8e6..cc195b2 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -2,9 +2,13 @@ from torch import nn from hierarchicalsoftmax import HierarchicalSoftmaxLinear, HierarchicalSoftmaxLazyLinear from hierarchicalsoftmax.layers import HierarchicalSoftmaxLayerError +from hierarchicalsoftmax.tensors import LazyLinearTensor + +import torch from .util import depth_two_tree, assert_multiline_strings + def test_linear_layer(): layer = HierarchicalSoftmaxLinear(in_features=100, root=depth_two_tree()) @@ -36,4 +40,14 @@ def test_linear_model(): (1): ReLU() (2): HierarchicalSoftmaxLinear(in_features=100, out_features=6, bias=True) ) - """) \ No newline at end of file + """) + + +def test_forward(): + layer = HierarchicalSoftmaxLinear(in_features=100, root=depth_two_tree()) + x = torch.rand(10, 100) + result = layer(x) + assert result.shape == (10, 6) + assert isinstance(result, LazyLinearTensor) + assert result.result.shape == (10, 6) + \ No newline at end of file diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 5ff7ab7..6bc9c01 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -8,6 +8,7 @@ greedy_accuracy_parent, greedy_precision, greedy_recall, + GreedyAccuracy, ) from torch.testing import assert_allclose @@ -33,6 +34,9 @@ def test_greedy_accuracy(): assert_allclose(greedy_accuracy(predictions, target_tensor, root=root), 0.75) + metric = GreedyAccuracy(root=root) + assert_allclose(metric(predictions, target_tensor), 0.75) + def test_greedy_f1_score(): root, targets = depth_two_tree_and_targets_three_children() @@ -119,6 +123,18 @@ def test_greedy_accuracy_max_depth_simple(): assert greedy_accuracy_depth_two(predictions_rearranged, target_tensor, root=root) < 0.01 assert greedy_accuracy(predictions_rearranged, target_tensor, root=root) < 0.01 + depth_one = GreedyAccuracy(root=root, max_depth=1, name="depth_one") + assert 0.99 < depth_one(predictions_rearranged, target_tensor) + depth_two = GreedyAccuracy(root=root, max_depth=2, name="depth_two") + assert depth_two(predictions_rearranged, target_tensor) < 0.01 + + assert depth_one.name == "depth_one" + assert depth_one.__name__ == "depth_one" + assert depth_two.name == "depth_two" + assert depth_two.__name__ == "depth_two" + + + def test_greedy_accuracy_max_depth_complex(): root, targets = depth_three_tree_and_targets() diff --git a/tests/test_tensors.py b/tests/test_tensors.py new file mode 100644 index 0000000..5de6c03 --- /dev/null +++ b/tests/test_tensors.py @@ -0,0 +1,121 @@ +from hierarchicalsoftmax.tensors import LazyLinearTensor +import torch +import unittest + + +class TestLazyLinearTensor(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.in_features = 5 + self.out_features = 11 + self.weight = torch.nn.Parameter(torch.arange(self.out_features * self.in_features).reshape(self.out_features, self.in_features).float()) + self.bias = torch.nn.Parameter(torch.arange(self.out_features).float()) + self.input = torch.arange(self.batch_size * self.in_features).reshape(self.batch_size, self.in_features).float() + self.tensor = LazyLinearTensor(self.input, self.weight, self.bias) + + def test_add(self): + result = self.tensor + 1 + expected = torch.matmul(self.input, self.weight.t()) + self.bias + 1 + assert torch.allclose(result, expected) + + def test_mul(self): + result = self.tensor * 2 + expected = torch.matmul(self.input, self.weight.t()) * 2 + self.bias * 2 + assert torch.allclose(result, expected) + + def test_sub(self): + result = self.tensor - 1 + expected = torch.matmul(self.input, self.weight.t()) + self.bias - 1 + assert torch.allclose(result, expected) + + def test_shape(self): + assert self.tensor.shape == (self.batch_size, self.out_features) + + def test_get_item(self): + result = self.tensor[0] + assert isinstance(result, LazyLinearTensor) + assert result.shape == (self.out_features,) + + def test_slice(self): + result = self.tensor[:,:2] + assert not isinstance(result, LazyLinearTensor) + + def test_get_item_slice(self): + result = self.tensor[0] + assert torch.allclose(result[:7], self.tensor.result[0,:7]) + assert torch.allclose(result[7:], self.tensor.result[0,7:]) + + def test_get_item(self): + assert len(self.tensor) == self.batch_size + assert len(self.tensor[0]) == self.out_features + + def test_iter_slice(self): + for i, tensor in enumerate(self.tensor): + assert torch.allclose(tensor[:2], self.tensor[i][:2]) + + def test_str(self): + assert str(self.tensor) == "LazyLinearTensor (shape=(2, 11))" + + def test_repr(self): + assert repr(self.tensor) == "LazyLinearTensor (shape=(2, 11))" + + def test_truediv(self): + result = self.tensor / 2 + expected = torch.matmul(self.input, self.weight.t()) / 2 + self.bias / 2 + assert torch.allclose(result, expected) + + def test_matmul(self): + result = self.tensor @ torch.arange(self.out_features).float() + expected = torch.matmul(self.tensor.result, torch.arange(self.out_features).float()) + assert torch.allclose(result, expected) + + def test_radd(self): + result = 1 + self.tensor + expected = torch.matmul(self.input, self.weight.t()) + self.bias + 1 + assert torch.allclose(result, expected) + + def test_rsub(self): + result = 1 - self.tensor + expected = 1 - (torch.matmul(self.input, self.weight.t()) + self.bias) + assert torch.allclose(result, expected) + + def test_rmul(self): + result = 2 * self.tensor + expected = 2 * (torch.matmul(self.input, self.weight.t()) + self.bias) + assert torch.allclose(result, expected) + + def test_rtruediv(self): + result = 2 / self.tensor + expected = 2 / (torch.matmul(self.input, self.weight.t()) + self.bias) + assert torch.allclose(result, expected) + + def test_rmatmul(self): + size = 9 + matrix = torch.arange(self.batch_size*size).reshape(size, self.batch_size).float() + result = matrix @ self.tensor + expected = matrix @ (torch.matmul(self.input, self.weight.t()) + self.bias) + assert torch.allclose(result, expected) + + def test_index_error(self): + with self.assertRaises(IndexError): + self.tensor[0,0,5] + + def test_is_floating_point(self): + assert torch.is_floating_point(self.tensor) == True + assert self.tensor.is_floating_point() == True + + def test_tensor_float(self): + weight = torch.nn.Parameter(torch.arange(self.out_features * self.in_features).reshape(self.out_features, self.in_features).double()) + bias = torch.nn.Parameter(torch.arange(self.out_features).double()) + input = torch.arange(self.batch_size * self.in_features).reshape(self.batch_size, self.in_features).double() + tensor = LazyLinearTensor(input, weight, bias) + assert tensor.dtype == torch.float64 + assert tensor.float().dtype == torch.float32 + + def test_tensor_half(self): + weight = torch.nn.Parameter(torch.arange(self.out_features * self.in_features).reshape(self.out_features, self.in_features).double()) + bias = torch.nn.Parameter(torch.arange(self.out_features).double()) + input = torch.arange(self.batch_size * self.in_features).reshape(self.batch_size, self.in_features).double() + tensor = LazyLinearTensor(input, weight, bias) + assert tensor.dtype == torch.float64 + assert tensor.half().dtype == torch.float16