Tensor #3

Merged 14 commits on May 21, 2024
4 changes: 2 additions & 2 deletions hierarchicalsoftmax/dotexporter.py
@@ -55,7 +55,7 @@ def _default_edgeattrfunc(
     def exclude_node(self, node):
         return not node.is_root and node not in self.greedy_nodes and self.probabilities[node.index_in_softmax_layer] < self.threshold

-    def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc):
+    def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc, *args, **kwargs):
         for node in PreOrderIter(self.node, maxlevel=self.maxlevel):
             if self.exclude_node(node):
                 continue
@@ -64,7 +64,7 @@ def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc):
             nodeattr = " [%s]" % nodeattr if nodeattr is not None else ""
             yield '%s"%s"%s;' % (indent, DotExporter.esc(nodename), nodeattr)

-    def _DotExporter__iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc):
+    def _DotExporter__iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc, *args, **kwargs):
         maxlevel = self.maxlevel - 1 if self.maxlevel else None
         for node in PreOrderIter(self.node, maxlevel=maxlevel):
             nodename = nodenamefunc(node)
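For context: the _DotExporter__iter_nodes spelling is how a subclass overrides a name-mangled (double-underscore) method of anytree's DotExporter, and adding *args, **kwargs keeps the overrides compatible if the upstream methods grow extra parameters. A minimal, self-contained sketch of the pattern (the class and method names below are illustrative, not part of this package):

class Base:
    def render(self):
        # Inside Base, self.__iter_nodes resolves to self._Base__iter_nodes.
        return list(self.__iter_nodes())

    def __iter_nodes(self):
        yield "base"

class Child(Base):
    # Overriding a mangled method requires spelling out the mangled name.
    def _Base__iter_nodes(self, *args, **kwargs):
        yield "child"

print(Child().render())  # ['child']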
2 changes: 0 additions & 2 deletions hierarchicalsoftmax/inference.py
@@ -2,7 +2,6 @@
 import torch
 from pathlib import Path
 from anytree import PreOrderIter
-from functools import partial

 from . import nodes
 from .dotexporter import ThresholdDotExporter
@@ -52,7 +51,6 @@ def leaf_probabilities(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode) -> torch.Tensor:
     return torch.index_select(probabilities, 1, root.leaf_indexes_in_softmax_layer)


-
 def greedy_predictions(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:Optional[int]=None, threshold:Optional[float]=None) -> List[nodes.SoftmaxNode]:
     """
     Takes the prediction scores for a number of samples and converts them to a list of predictions of nodes in the tree.
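For readers of the diff, a hedged usage sketch of greedy_predictions: given the raw scores from the softmax layer, it walks the tree greedily from the root. The set_indexes call is an assumption about how the tree is finalised before use.

import torch
from hierarchicalsoftmax.nodes import SoftmaxNode
from hierarchicalsoftmax.inference import greedy_predictions

root = SoftmaxNode("root")
SoftmaxNode("a", parent=root)
SoftmaxNode("b", parent=root)
root.set_indexes()  # assumption: finalises node indexes and layer_size

scores = torch.randn(4, root.layer_size)  # 4 samples, one score per node
predicted_nodes = greedy_predictions(scores, root, max_depth=None, threshold=None)
print([str(node) for node in predicted_nodes])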
4 changes: 4 additions & 0 deletions hierarchicalsoftmax/layers.py
@@ -1,6 +1,7 @@
 from torch import nn

 from .nodes import SoftmaxNode
+from .tensors import LazyLinearTensor

 class HierarchicalSoftmaxLayerError(RuntimeError):
     pass
@@ -17,6 +18,9 @@ def __init__(self, root:SoftmaxNode, out_features=None, **kwargs):

         super().__init__(out_features=self.root.layer_size, **kwargs)

+    def forward(self, x) -> LazyLinearTensor:
+        return LazyLinearTensor(x, weight=self.weight, bias=self.bias)
+

 class HierarchicalSoftmaxLinear(HierarchicalSoftmaxLayerMixin, nn.Linear):
     """
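With this change, calling the layer no longer performs the full matrix multiplication up front; it returns a LazyLinearTensor that records the input, weight and bias and defers the projection. A hedged sketch of the effect (tree setup abbreviated; set_indexes is an assumption about the API):

import torch
from hierarchicalsoftmax.nodes import SoftmaxNode
from hierarchicalsoftmax.layers import HierarchicalSoftmaxLinear

root = SoftmaxNode("root")
SoftmaxNode("a", parent=root)
SoftmaxNode("b", parent=root)
root.set_indexes()  # assumption: finalises node indexes and layer_size

layer = HierarchicalSoftmaxLinear(in_features=64, root=root)
output = layer(torch.randn(8, 64))    # no projection computed yet
print(type(output).__name__)          # LazyLinearTensor
value = output[0, 0]                  # computes just one output feature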
16 changes: 16 additions & 0 deletions hierarchicalsoftmax/metrics.py
@@ -137,3 +137,19 @@ def greedy_accuracy_parent(prediction_tensor, target_tensor, root, max_depth=None
     return (prediction_parent_ids.to(target_parent_ids.device) == target_parent_ids).float().mean()


+class GreedyAccuracy():
+    name:str = "greedy"
+
+    def __init__(self, root:nodes.SoftmaxNode, name="greedy_accuracy", max_depth=None):
+        self.max_depth = max_depth
+        self.name = name
+        self.root = root
+
+    @property
+    def __name__(self):
+        """ For using as a FastAI metric. """
+        return self.name
+
+    def __call__(self, predictions, targets):
+        return greedy_accuracy(predictions, targets, self.root, max_depth=self.max_depth)
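The __name__ property is what lets callback-style frameworks (FastAI is the one named in the docstring) display the metric under a custom name. A small sketch of standalone use; the form of the target tensor is an assumption, so the actual call is left commented:

import torch
from hierarchicalsoftmax.nodes import SoftmaxNode
from hierarchicalsoftmax.metrics import GreedyAccuracy

root = SoftmaxNode("root")
SoftmaxNode("a", parent=root)
SoftmaxNode("b", parent=root)
root.set_indexes()  # assumption: finalises node indexes and layer_size

metric = GreedyAccuracy(root=root, name="accuracy_depth_1", max_depth=1)
print(metric.__name__)  # accuracy_depth_1
# score = metric(predictions, targets)  # mean greedy accuracy, as in greedy_accuracy above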

100 changes: 100 additions & 0 deletions hierarchicalsoftmax/tensors.py
@@ -0,0 +1,100 @@
from torch import Tensor, Size
from functools import cached_property
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class LazyLinearTensor(Tensor):
    """
    A tensor that is designed to be used with HierarchicalSoftmaxLazyLinear layers.
    """
    @staticmethod
    def __new__(cls, x, weight:Parameter, bias:Parameter, *args, **kwargs):
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x:Tensor, weight:Parameter, bias:Parameter, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input = x
        self.weight = weight
        self.bias = bias

    @cached_property
    def result(self):
        return F.linear(self.input, self.weight, self.bias)

    # Binary operations force evaluation of the full linear projection.
    def __add__(self, other):
        return self.result + other

    def __sub__(self, other):
        return self.result - other

    def __mul__(self, other):
        return self.result * other

    def __truediv__(self, other):
        return self.result / other

    def __matmul__(self, other):
        return self.result @ other

    def __radd__(self, other):
        return other + self.result

    def __rsub__(self, other):
        return other - self.result

    def __rmul__(self, other):
        return other * self.result

    def __rtruediv__(self, other):
        return other / self.result

    def __rmatmul__(self, other):
        return other @ self.result

    def __getitem__(self, index):
        assert isinstance(index, (int, slice, tuple))
        if not isinstance(index, tuple):
            index = (index,)

        my_shape = self.shape
        if len(index) < len(my_shape):
            # Indexing only batch dimensions keeps the projection lazy.
            return LazyLinearTensor(self.input[index], weight=self.weight, bias=self.bias)
        if len(index) > len(my_shape):
            raise IndexError(f"Cannot get index '{index}' for LazyLinearTensor with {len(my_shape)} dimensions")

        # Indexing the final dimension computes only the selected output features.
        input = self.input[index[:-1]]
        weight = self.weight[index[-1]]
        bias = self.bias[index[-1]]
        return F.linear(input, weight, bias)

    @property
    def shape(self) -> Size:
        return Size(self.input.shape[:-1] + (self.weight.shape[0],))

    def __str__(self) -> str:
        return f"LazyLinearTensor (shape={tuple(self.shape)})"

    def __repr__(self) -> str:
        return str(self)

    def __len__(self) -> int:
        return self.shape[0]

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def float(self):
        x = super().float()
        x.input = self.input.float()
        x.weight = self.weight.float()
        x.bias = self.bias.float()
        return x

    def half(self):
        x = super().half()
        x.input = self.input.half()
        x.weight = self.weight.half()
        x.bias = self.bias.half()
        return x
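To make the laziness concrete, a hedged sketch using LazyLinearTensor directly with raw tensors (the sizes are arbitrary): the shape is known without any computation, a full index computes a single output feature, and ordinary arithmetic falls back to the cached full projection.

import torch
from torch.nn.parameter import Parameter
from hierarchicalsoftmax.tensors import LazyLinearTensor

weight = Parameter(torch.randn(1000, 64))  # 1000 output features
bias = Parameter(torch.zeros(1000))
x = torch.randn(32, 64)

lazy = LazyLinearTensor(x, weight=weight, bias=bias)
print(lazy.shape)   # torch.Size([32, 1000]), no projection computed yet
one = lazy[0, 5]    # x[0] @ weight[5] + bias[5], a single feature
full = lazy + 0.0   # materialises the full 32 x 1000 result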