diff --git a/micrograd/activation_functions.py b/micrograd/activation_functions.py
index 3f682de6..ee31b4a1 100644
--- a/micrograd/activation_functions.py
+++ b/micrograd/activation_functions.py
@@ -17,18 +17,9 @@ def linear(value: "Value") -> "Value":
     @staticmethod
     def sigmoid(value: "Value") -> "Value":
         from .engine import Value
-        self = value
-        out = Value(1 / (1 + (-value).exp()), (self,), "Sigmoid")
-
-        # Value(0 if self.data < 0 else self.data, (self,), "Sigmoid")
-        def _backward():
-            self.grad += (out.data > 0) * out.grad
-
-        out._backward = _backward
-        return out
-
-        # return 1 / (1 + (-value).exp())
+        return value.sigmoid()
 
     @staticmethod
     def tanh(value: "Value") -> "Value":
-        return (value.exp() - (-value).exp()) / (value.exp() + (-value).exp())
+        from .engine import Value
+        return value.tanh()
diff --git a/micrograd/engine.py b/micrograd/engine.py
index 2c03e8cb..67b0d407 100644
--- a/micrograd/engine.py
+++ b/micrograd/engine.py
@@ -8,26 +8,13 @@ class Value:
     """stores a single scalar value and its gradient"""
 
     def __init__(self, data: Number | 'Value', _children: tuple = (), _op: str = ""):
-        if isinstance(data, Value):
-            data.copy(data, _children, _op)
-        else:
-            self.data: Number = data
-            self.grad: float = 0.0
-            # internal variables used for autograd graph construction
-            self._backward: Callable[[], None] = lambda: None
-            self._prev: Set["Value"] = set(_children)
-            self._op = _op  # the op that produced this node, for graphviz / debugging / etc
-
-    def copy(self, value: 'Value', _children: tuple = (), _op: str = ""):
-        self.data = value.data
-        self.grad = value.grad
-        self._backward = value._backward
-        self._prev = value._prev
-        self._op = value._op
-        if _children:
-            self._prev: Set["Value"] = set(_children)
-        if _op:
-            self._op = _op
+
+        self.data: Number = data
+        self.grad: float = 0.0
+        # internal variables used for autograd graph construction
+        self._backward: Callable[[], None] = lambda: None
+        self._prev: Set["Value"] = set(_children)
+        self._op = _op  # the op that produced this node, for graphviz / debugging / etc
 
     def __add__(self, other: Number | "Value") -> "Value":
         other = other if isinstance(other, Value) else Value(other)
@@ -41,6 +28,19 @@ def _backward():
 
         return out
 
+    def __abs__(self) -> "Value":
+        # return a new node rather than mutating this one in place
+        out = Value(abs(self.data), (self,), "abs")
+        return out
+
+    def __lt__(self, other: Number | "Value") -> bool:
+        other = other if isinstance(other, Value) else Value(other)
+        return self.data < other.data
+
+    def __gt__(self, other: Number | "Value") -> bool:
+        other = other if isinstance(other, Value) else Value(other)
+        return self.data > other.data
+
     def __mul__(self, other: Number | "Value"):
         other = other if isinstance(other, Value) else Value(other)
         out = Value(self.data * other.data, (self, other), "*")
@@ -76,6 +76,26 @@ def _backward():
 
         return out
 
+    def sigmoid(self) -> "Value":
+        out = Value(1 / (1 + (-1 * self).exp().data), (self,), "Sigmoid")
+
+        def _backward():
+            self.grad += out.data * (1 - out.data) * out.grad
+
+        out._backward = _backward
+
+        return out
+
+    def tanh(self) -> "Value":
+        out = Value((self.exp().data - (-self).exp().data) / (self.exp().data + (-self).exp().data), (self,), "tanh")
+
+        def _backward():
+            self.grad += (1 - out.data ** 2) * out.grad
+
+        out._backward = _backward
+
+        return out
+
     def backward(self) -> None:
 
         # topological order all the children in the graph
diff --git a/test/test_trainer.py b/test/test_trainer.py
index 64556f77..5d81d260 100644
--- a/test/test_trainer.py
+++ b/test/test_trainer.py
@@ -2,20 +2,22 @@
 
 
 def test_relu():
-    assert Activation.relu(Value(3.0)).data == 3.0
-    assert Activation.relu(Value(-2.0)).data == 0.0
+    assert Activation.relu(Value(3.0)).data == Value(3.0).data
+    assert Activation.relu(Value(-2.0)).data == Value(0.0).data
 
 
 def test_linear():
-    assert Activation.linear(Value(3.0)).data == 3.0
-    assert Activation.linear(Value(-2.0)).data == -2.0
-
-#
-# def test_sigmoid():
-#     sigmoid_value = Activation.sigmoid(Value(0.0)).data
-#     assert abs(sigmoid_value - 0.5) < 1e-6
-#
-#
-# def test_tanh():
-#     tanh_value = Activation.tanh(Value(0.0)).data
-#     assert abs(tanh_value - 0.0) < 1e-6
+    assert Activation.linear(Value(3.0)).data == Value(3.0).data
+    assert Activation.linear(Value(-2.0)).data == Value(-2.0).data
+
+
+def test_sigmoid():
+    sigmoid_value = Activation.sigmoid(Value(0.0))
+    assert isinstance(sigmoid_value, Value)
+    assert abs(sigmoid_value - 0.5) < 1e-6
+
+
+def test_tanh():
+    tanh_value = Activation.tanh(Value(0.0))
+    assert isinstance(tanh_value, Value)
+    assert abs(tanh_value) < 1e-6
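A quick way to sanity-check the new `Value.sigmoid()` and `Value.tanh()` methods beyond the unit tests above is a finite-difference gradient check. The sketch below is illustrative only: it assumes the package layout shown in the diff paths (`micrograd.engine.Value`, `micrograd.activation_functions.Activation`) is importable as-is, and the helper names are made up for the example.

```python
import math

from micrograd.engine import Value
from micrograd.activation_functions import Activation


def numerical_grad(fn, x: float, eps: float = 1e-6) -> float:
    """Central-difference estimate of d fn(x) / dx."""
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


def check(name: str, value_method, reference_fn, x0: float = 0.3) -> None:
    v = Value(x0)
    out = value_method(v)   # Value.sigmoid(v) is equivalent to v.sigmoid()
    out.backward()
    expected = numerical_grad(reference_fn, x0)
    assert abs(v.grad - expected) < 1e-4, f"{name}: {v.grad} vs {expected}"


if __name__ == "__main__":
    check("sigmoid", Value.sigmoid, lambda x: 1 / (1 + math.exp(-x)))
    check("tanh", Value.tanh, math.tanh)

    # the Activation wrappers should agree with the underlying Value methods
    assert abs(Activation.sigmoid(Value(0.0)).data - 0.5) < 1e-6
    assert abs(Activation.tanh(Value(0.0)).data) < 1e-6
    print("gradient checks passed")
```

If the analytic and numerical gradients disagree, the `_backward` closures in `engine.py` are the first place to look: the sigmoid derivative is `out.data * (1 - out.data)` and the tanh derivative is `1 - out.data ** 2`.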