Skip to content

Commit

Permalink
tests removed copy method
Browse files Browse the repository at this point in the history
  • Loading branch information
SermetPekin committed Dec 5, 2024
1 parent 2548858 commit 0b8f043
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 46 deletions.
15 changes: 3 additions & 12 deletions micrograd/activation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,9 @@ def linear(value: "Value") -> "Value":
@staticmethod
def sigmoid(value: "Value") -> "Value":
from .engine import Value
self = value
out = Value(1 / (1 + (-value).exp()), (self,), "Sigmoid")

# Value(0 if self.data < 0 else self.data, (self,), "Sigmoid")
def _backward():
self.grad += (out.data > 0) * out.grad

out._backward = _backward
return out

# return 1 / (1 + (-value).exp())
return value.sigmoid()

@staticmethod
def tanh(value: "Value") -> "Value":
return (value.exp() - (-value).exp()) / (value.exp() + (-value).exp())
from .engine import Value
return value.tanh()
60 changes: 40 additions & 20 deletions micrograd/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,13 @@ class Value:
"""stores a single scalar value and its gradient"""

def __init__(self, data: Number | 'Value', _children: tuple = (), _op: str = ""):
if isinstance(data, Value):
data.copy(data, _children, _op)
else:
self.data: Number = data
self.grad: float = 0.0
# internal variables used for autograd graph construction
self._backward: Callable[[], None] = lambda: None
self._prev: Set["Value"] = set(_children)
self._op = _op # the op that produced this node, for graphviz / debugging / etc

def copy(self, value: 'Value', _children: tuple = (), _op: str = ""):
self.data = value.data
self.grad = value.grad
self._backward = value._backward
self._prev = value._prev
self._op = value._op
if _children:
self._prev: Set["Value"] = set(_children)
if _op:
self._op = _op

self.data: Number = data
self.grad: float = 0.0
# internal variables used for autograd graph construction
self._backward: Callable[[], None] = lambda: None
self._prev: Set["Value"] = set(_children)
self._op = _op # the op that produced this node, for graphviz / debugging / etc

def __add__(self, other: Number | "Value") -> "Value":
other = other if isinstance(other, Value) else Value(other)
Expand All @@ -41,6 +28,19 @@ def _backward():

return out

def __abs__(self) -> "Value":
self.data = abs(self.data)

return self

def __lt__(self, other: Number | "Value") -> bool:
other = other if isinstance(other, Value) else Value(other)
return self.data < other.data

def __gt__(self, other: Number | "Value") -> bool:
other = other if isinstance(other, Value) else Value(other)
return self.data > other.data

def __mul__(self, other: Number | "Value"):
other = other if isinstance(other, Value) else Value(other)
out = Value(self.data * other.data, (self, other), "*")
Expand Down Expand Up @@ -76,6 +76,26 @@ def _backward():

return out

def sigmoid(self) -> "Value":
out = Value(1 / (1 + (-1 * self).exp()), (self,), "Sigmoid")

def _backward():
self.grad += (out.data > 0) * out.grad

out._backward = _backward

return out

def tanh(self) -> "Value":
out = Value((self.exp() - (-self).exp()) / (self.exp() + (-self).exp()), (self,), "tanh")

def _backward():
self.grad += (out.data > 0) * out.grad

out._backward = _backward

return out

def backward(self) -> None:

# topological order all the children in the graph
Expand Down
30 changes: 16 additions & 14 deletions test/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@


def test_relu():
assert Activation.relu(Value(3.0)).data == 3.0
assert Activation.relu(Value(-2.0)).data == 0.0
assert Activation.relu(Value(3.0)).data == Value(3.0).data
assert Activation.relu(Value(-2.0)).data == Value(0.0).data


def test_linear():
assert Activation.linear(Value(3.0)).data == 3.0
assert Activation.linear(Value(-2.0)).data == -2.0

#
# def test_sigmoid():
# sigmoid_value = Activation.sigmoid(Value(0.0)).data
# assert abs(sigmoid_value - 0.5) < 1e-6
#
#
# def test_tanh():
# tanh_value = Activation.tanh(Value(0.0)).data
# assert abs(tanh_value - 0.0) < 1e-6
assert Activation.linear(Value(3.0)).data == Value(3.0).data
assert Activation.linear(Value(-2.0)).data == Value(-2.0).data


def test_sigmoid():
sigmoid_value = Activation.sigmoid(Value(0.0))
assert isinstance(sigmoid_value, Value)
assert abs(sigmoid_value - 0.5) < 1e-6


def test_tanh():
tanh_value = Activation.tanh(Value(0.0))
assert isinstance(tanh_value, Value)
assert abs(tanh_value) < 1e-6

0 comments on commit 0b8f043

Please sign in to comment.