
Commit

Merge pull request #3 from rbturnbull/tensor
Tensor
rbturnbull authored May 21, 2024
2 parents cfe5b94 + 9b274cd commit 2a6c033
Showing 10 changed files with 291 additions and 136 deletions.
4 changes: 2 additions & 2 deletions hierarchicalsoftmax/dotexporter.py
@@ -55,7 +55,7 @@ def _default_edgeattrfunc(
     def exclude_node(self, node):
         return not node.is_root and node not in self.greedy_nodes and self.probabilities[node.index_in_softmax_layer] < self.threshold
 
-    def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc):
+    def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc, *args, **kwargs):
         for node in PreOrderIter(self.node, maxlevel=self.maxlevel):
             if self.exclude_node(node):
                 continue
@@ -64,7 +64,7 @@ def _DotExporter__iter_nodes(self, indent, nodenamefunc, nodeattrfunc):
             nodeattr = " [%s]" % nodeattr if nodeattr is not None else ""
             yield '%s"%s"%s;' % (indent, DotExporter.esc(nodename), nodeattr)
 
-    def _DotExporter__iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc):
+    def _DotExporter__iter_edges(self, indent, nodenamefunc, edgeattrfunc, edgetypefunc, *args, **kwargs):
         maxlevel = self.maxlevel - 1 if self.maxlevel else None
         for node in PreOrderIter(self.node, maxlevel=maxlevel):
             nodename = nodenamefunc(node)
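
Both overrides above work through Python's name mangling: anytree's DotExporter calls its private __iter_nodes and __iter_edges methods internally, and the interpreter compiles those calls to _DotExporter__iter_nodes and _DotExporter__iter_edges, so a subclass can shadow them by defining the mangled names directly. Accepting *args and **kwargs keeps the overrides compatible when anytree passes extra arguments. A minimal sketch of the trick, using hypothetical Base and Child classes rather than the real exporter:

    class Base:
        def render(self):
            # self.__iter_items is compiled to self._Base__iter_items
            return list(self.__iter_items())

        def __iter_items(self):
            yield "base"

    class Child(Base):
        # Defining the mangled name shadows the parent's "private" method.
        def _Base__iter_items(self, *args, **kwargs):
            yield "child"

    assert Child().render() == ["child"]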
2 changes: 0 additions & 2 deletions hierarchicalsoftmax/inference.py
@@ -2,7 +2,6 @@
 import torch
 from pathlib import Path
 from anytree import PreOrderIter
-from functools import partial
 
 from . import nodes
 from .dotexporter import ThresholdDotExporter
@@ -52,7 +51,6 @@ def leaf_probabilities(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode) -
     return torch.index_select(probabilities, 1, root.leaf_indexes_in_softmax_layer)
 
 
-
 def greedy_predictions(prediction_tensor:torch.Tensor, root:nodes.SoftmaxNode, max_depth:Optional[int]=None, threshold:Optional[float]=None) -> List[nodes.SoftmaxNode]:
     """
     Takes the prediction scores for a number of samples and converts it to a list of predictions of nodes in the tree.
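
For context, the index_select call in leaf_probabilities keeps only the columns of the probability matrix that correspond to leaf nodes. A self-contained sketch with made-up shapes and leaf positions:

    import torch

    # Stand-ins: 4 samples over a softmax layer of width 10,
    # with leaves assumed to sit at columns 2, 5 and 9.
    probabilities = torch.rand(4, 10)
    leaf_indexes = torch.tensor([2, 5, 9])

    leaf_probs = torch.index_select(probabilities, 1, leaf_indexes)
    assert leaf_probs.shape == (4, 3)  # one column per leaf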
4 changes: 4 additions & 0 deletions hierarchicalsoftmax/layers.py
@@ -1,6 +1,7 @@
 from torch import nn
 
 from .nodes import SoftmaxNode
+from .tensors import LazyLinearTensor
 
 class HierarchicalSoftmaxLayerError(RuntimeError):
     pass
@@ -17,6 +18,9 @@ def __init__(self, root:SoftmaxNode, out_features=None, **kwargs):
 
         super().__init__(out_features=self.root.layer_size, **kwargs)
 
+    def forward(self, x) -> LazyLinearTensor:
+        return LazyLinearTensor(x, weight=self.weight, bias=self.bias)
+
 
 class HierarchicalSoftmaxLinear(HierarchicalSoftmaxLayerMixin, nn.Linear):
     """
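
This new forward method is the heart of the pull request: instead of computing the full linear projection up front, the layer wraps its input, weight and bias in a LazyLinearTensor, deferring the matrix multiplication until a result is actually read. A quick equivalence check, as a sketch (it assumes the new tensors module is importable as hierarchicalsoftmax.tensors, per the import in the diff above):

    import torch
    from torch import nn
    from hierarchicalsoftmax.tensors import LazyLinearTensor

    linear = nn.Linear(8, 5)
    x = torch.rand(3, 8)

    # What the new forward returns: no matmul has happened yet.
    lazy = LazyLinearTensor(x, weight=linear.weight, bias=linear.bias)

    # Reading .result triggers (and caches) the F.linear computation.
    assert torch.allclose(lazy.result, linear(x))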
16 changes: 16 additions & 0 deletions hierarchicalsoftmax/metrics.py
@@ -137,3 +137,19 @@ def greedy_accuracy_parent(prediction_tensor, target_tensor, root, max_depth=Non
     return (prediction_parent_ids.to(target_parent_ids.device) == target_parent_ids).float().mean()
 
 
+class GreedyAccuracy():
+    name:str = "greedy"
+
+    def __init__(self, root:nodes.SoftmaxNode, name="greedy_accuracy", max_depth=None):
+        self.max_depth = max_depth
+        self.name = name
+        self.root = root
+
+    @property
+    def __name__(self):
+        """ For using as a FastAI metric. """
+        return self.name
+
+    def __call__(self, predictions, targets):
+        return greedy_accuracy(predictions, targets, self.root, max_depth=self.max_depth)
+
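
GreedyAccuracy wraps the existing greedy_accuracy function in a callable object, and its __name__ property lets frameworks such as FastAI report the metric under a readable label. A hypothetical usage sketch (the tree construction and import paths are assumptions, not shown in this diff):

    from hierarchicalsoftmax.nodes import SoftmaxNode
    from hierarchicalsoftmax.metrics import GreedyAccuracy

    root = SoftmaxNode("root")
    SoftmaxNode("a", parent=root)
    SoftmaxNode("b", parent=root)

    metric = GreedyAccuracy(root, max_depth=2)
    assert metric.__name__ == "greedy_accuracy"  # the label FastAI would display

    # In a validation loop it is called like a plain function:
    # score = metric(prediction_tensor, target_tensor)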

100 changes: 100 additions & 0 deletions hierarchicalsoftmax/tensors.py
@@ -0,0 +1,100 @@
+from torch import Tensor, Size
+from functools import cached_property
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+
+class LazyLinearTensor(Tensor):
+    """
+    A tensor that is designed to be used with HierarchicalSoftmaxLazyLinear layers.
+    """
+    @staticmethod
+    def __new__(cls, x, weight:Parameter, bias:Parameter, *args, **kwargs):
+        return super().__new__(cls, x, *args, **kwargs)
+
+    def __init__(self, x:Tensor, weight:Parameter, bias:Parameter, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input = x
+        self.weight = weight
+        self.bias = bias
+
+    @cached_property
+    def result(self):
+        return F.linear(self.input, self.weight, self.bias)
+
+    def __add__(self, other):
+        return self.result + other
+
+    def __sub__(self, other):
+        return self.result - other
+
+    def __mul__(self, other):
+        return self.result * other
+
+    def __truediv__(self, other):
+        return self.result / other
+
+    def __matmul__(self, other):
+        return self.result @ other
+
+    def __radd__(self, other):
+        return other + self.result
+
+    def __rsub__(self, other):
+        return other - self.result
+
+    def __rmul__(self, other):
+        return other * self.result
+
+    def __rtruediv__(self, other):
+        return other / self.result
+
+    def __rmatmul__(self, other):
+        return other @ self.result
+
+    def __getitem__(self, index):
+        assert isinstance(index, int) or isinstance(index, slice) or isinstance(index, tuple)
+        if not isinstance(index, tuple) or isinstance(index, slice):
+            index = (index,)
+
+        my_shape = self.shape
+        if len(index) < len(my_shape):
+            return LazyLinearTensor(self.input[index], weight=self.weight, bias=self.bias)
+        if len(index) > len(my_shape):
+            raise IndexError(f"Cannot get index '{index}' for LazyLinearTensor of shape {len(my_shape)}")
+
+        input = self.input[index[:-1]]
+        weight = self.weight[index[-1]]
+        bias = self.bias[index[-1]]
+        return F.linear(input, weight, bias)
+
+    @property
+    def shape(self) -> Size:
+        return Size( self.input.shape[:-1] + (self.weight.shape[0],) )
+
+    def __str__(self) -> str:
+        return f"LazyLinearTensor (shape={tuple(self.shape)})"
+
+    def __repr__(self) -> str:
+        return str(self)
+
+    def __len__(self) -> int:
+        return self.shape[0]
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield self[i]
+
+    def float(self):
+        x = super().float()
+        x.input = self.input.float()
+        x.weight = self.weight.float()
+        x.bias = self.bias.float()
+        return x
+
+    def half(self):
+        x = super().half()
+        x.input = self.input.half()
+        x.weight = self.weight.half()
+        x.bias = self.bias.half()
+        return x
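
The payoff of LazyLinearTensor is in __getitem__: a fully specified index computes F.linear against a single weight row instead of the whole layer, which matters when the softmax layer is wide and only a few outputs are needed. A small demonstration sketch (shapes chosen arbitrarily):

    import torch
    from torch import nn
    from hierarchicalsoftmax.tensors import LazyLinearTensor

    linear = nn.Linear(4, 6)
    x = torch.rand(2, 4)
    lazy = LazyLinearTensor(x, weight=linear.weight, bias=linear.bias)

    assert lazy.shape == (2, 6)  # shape is derived without any matmul

    row = lazy[0]      # still lazy: wraps x[0] with the same weight and bias
    one = lazy[0, 3]   # computes F.linear with weight row 3 only
    assert torch.allclose(one, linear(x)[0, 3])

    full = lazy + 0.0  # any arithmetic operator forces the full product via .result
    assert torch.allclose(full, linear(x))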