Add log_cosh and huber loss (keras-team#67)
* Add log_cosh and huber loss

* Docstring standardization

* Format

* Standardize wrapper function docstrings
grasskin authored May 2, 2023
1 parent b1b1a4b commit 8ce74e1
Showing 2 changed files with 348 additions and 1 deletion.
150 changes: 149 additions & 1 deletion keras_core/losses/losses.py
@@ -183,6 +183,63 @@ def __init__(
)


@keras_core_export("keras_core.losses.Huber")
class Huber(LossFunctionWrapper):
"""Computes the Huber loss between `y_true` & `y_pred`.
Formula:
```python
for x in error:
if abs(x) <= delta:
loss.append(0.5 * x^2)
elif abs(x) > delta:
loss.append(delta * abs(x) - 0.5 * delta^2)
loss = mean(loss, axis=-1)
```
See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss).
Args:
delta: A float, the point where the Huber loss function changes from a
quadratic to linear.
reduction: Type of reduction to apply to loss. Options are `"sum"`,
`"sum_over_batch_size"` or `None`. Defaults to
`"sum_over_batch_size"`.
name: Optional name for the instance.
"""

def __init__(
self,
delta=1.0,
reduction="sum_over_batch_size",
name="huber_loss",
):
super().__init__(huber, name=name, reduction=reduction, delta=delta)
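
For orientation, here is a minimal NumPy sketch of the piecewise rule from the docstring above; `huber_reference` is an illustrative helper, not part of this commit, and the numbers reuse the example from the standalone `huber` docstring later in the diff.

```python
import numpy as np

def huber_reference(y_true, y_pred, delta=1.0):
    # Quadratic for |x| <= delta, linear (delta*|x| - 0.5*delta^2) beyond,
    # then mean over the last axis -- the same rule as in the docstring.
    x = np.asarray(y_pred, dtype="float64") - np.asarray(y_true, dtype="float64")
    return np.where(
        np.abs(x) <= delta,
        0.5 * np.square(x),
        delta * np.abs(x) - 0.5 * delta**2,
    ).mean(axis=-1)

y_true = [[0.0, 1.0], [0.0, 0.0]]
y_pred = [[0.6, 0.4], [0.4, 0.6]]
print(huber_reference(y_true, y_pred))  # [0.18 0.13]; their mean is 0.155
```

The `Huber` class then applies the `reduction` argument described above on top of this per-sample value.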


@keras_core_export("keras_core.losses.LogCosh")
class LogCosh(LossFunctionWrapper):
"""Computes the logarithm of the hyperbolic cosine of the prediction error.
Formula:
```python
error = y_pred - y_true
logcosh = log((exp(error) + exp(-error))/2)`
```
where x is the error `y_pred - y_true`.
Args:
reduction: Type of reduction to apply to loss. Options are `"sum"`,
`"sum_over_batch_size"` or `None`. Defaults to
`"sum_over_batch_size"`.
name: Optional name for the instance.
"""

def __init__(self, reduction="sum_over_batch_size", name="log_cosh"):
super().__init__(log_cosh, name=name, reduction=reduction)
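
As a quick numeric cross-check of the formula above (illustrative only, not part of the commit), the expression `log((exp(error) + exp(-error)) / 2)` is just `log(cosh(error))`:

```python
import numpy as np

# error = y_pred - y_true for y_true = [[0., 1.], [0., 0.]], y_pred = [[1., 1.], [0., 0.]]
error = np.array([[1.0, 0.0], [0.0, 0.0]])
logcosh = np.log((np.exp(error) + np.exp(-error)) / 2)
assert np.allclose(logcosh, np.log(np.cosh(error)))
print(logcosh.mean(axis=-1))  # approx [0.2169, 0.]; overall mean approx 0.108, as quoted in the log_cosh example below
```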


@keras_core_export("keras_core.losses.Hinge")
class Hinge(LossFunctionWrapper):
"""Computes the hinge loss between `y_true` & `y_pred`.
@@ -1063,7 +1120,7 @@ def mean_squared_error(y_true, y_pred):
loss = mean(square(y_true - y_pred), axis=-1)
```
Standalone usage:
Example:
>>> y_true = np.random.randint(0, 2, size=(2, 3))
>>> y_pred = np.random.random(size=(2, 3))
@@ -1237,6 +1294,97 @@ def cosine_similarity(y_true, y_pred, axis=-1):
    return -ops.sum(y_true * y_pred, axis=axis)


@keras_core_export(["keras_core.losses.huber", "keras_core.metrics.huber"])
def huber(y_true, y_pred, delta=1.0):
"""Computes Huber loss value.
Formula:
```python
for x in error:
if abs(x) <= delta:
loss.append(0.5 * x^2)
elif abs(x) > delta:
loss.append(delta * abs(x) - 0.5 * delta^2)
loss = mean(loss, axis=-1)
```
See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss).
Example:
>>> y_true = [[0, 1], [0, 0]]
>>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
>>> loss = keras_core.losses.huber(y_true, y_pred)
0.155
Args:
y_true: tensor of true targets.
y_pred: tensor of predicted targets.
delta: A float, the point where the Huber loss function changes from a
quadratic to linear. Defaults to 1.
Returns:
Tensor with one scalar loss entry per sample.
"""
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
delta = ops.convert_to_tensor(delta)
error = ops.subtract(y_pred, y_true)
abs_error = ops.abs(error)
half = ops.convert_to_tensor(0.5, dtype=abs_error.dtype)
return ops.mean(
ops.where(
abs_error <= delta,
half * ops.square(error),
delta * abs_error - half * ops.square(delta),
),
axis=-1,
)
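
The `ops.where` branch above and the `minimum`/subtract formulation used by the test file's reference helper further down are algebraically the same Huber loss; a small NumPy check, offered only as an illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(scale=3.0, size=1000)  # synthetic errors, many beyond delta
delta = 1.0

# Piecewise form, as in `huber` above.
piecewise = np.where(
    np.abs(x) <= delta,
    0.5 * np.square(x),
    delta * np.abs(x) - 0.5 * delta**2,
)

# Clipped form, as in the tests' `huber_loss` helper below.
quadratic = np.minimum(np.abs(x), delta)
linear = np.abs(x) - quadratic
clipped = 0.5 * quadratic**2 + delta * linear

assert np.allclose(piecewise, clipped)
```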


@keras_core_export(
    ["keras_core.losses.log_cosh", "keras_core.metrics.log_cosh"]
)
def log_cosh(y_true, y_pred):
    """Logarithm of the hyperbolic cosine of the prediction error.

    Formula:

    ```python
    loss = mean(log(cosh(y_pred - y_true)), axis=-1)
    ```

    Note that `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for
    small `x` and to `abs(x) - log(2)` for large `x`. This means that
    'logcosh' works mostly like the mean squared error, but will not be so
    strongly affected by the occasional wildly incorrect prediction.

    Example:

    >>> y_true = [[0., 1.], [0., 0.]]
    >>> y_pred = [[1., 1.], [0., 0.]]
    >>> loss = keras_core.losses.log_cosh(y_true, y_pred)
    0.108

    Args:
        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.

    Returns:
        Logcosh error values with shape = `[batch_size, d0, .. dN-1]`.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)
    y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
    log2 = ops.convert_to_tensor(ops.log(2.0), dtype=y_pred.dtype)

    def _logcosh(x):
        # Numerically stable form of log(cosh(x)): x + softplus(-2x) - log(2).
        return x + ops.softplus(-2.0 * x) - log2

    return ops.mean(_logcosh(y_pred - y_true), axis=-1)
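
The inner `_logcosh` helper relies on the identity `log(cosh(x)) = x + softplus(-2x) - log(2)` (since `cosh(x) = exp(x) * (1 + exp(-2x)) / 2`), which avoids overflowing `exp` for large errors. A short NumPy sketch of that identity, with an illustrative `softplus` stand-in that is not part of the commit:

```python
import numpy as np

def softplus(z):
    # softplus(z) = log(1 + exp(z)), computed stably via logaddexp
    return np.logaddexp(0.0, z)

x = np.linspace(-20.0, 20.0, 9)
stable = x + softplus(-2.0 * x) - np.log(2.0)
naive = np.log(np.cosh(x))  # direct form; overflows much sooner for very large |x|
assert np.allclose(stable, naive)
```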


@keras_core_export(
    [
        "keras_core.metrics.kl_divergence",
199 changes: 199 additions & 0 deletions keras_core/losses/losses_test.py
@@ -465,6 +465,205 @@ def test_axis(self):
        self.assertAlmostEqual(loss, expected_loss, 3)


class HuberLossTest(testing.TestCase):
    def huber_loss(self, y_true, y_pred, delta=1.0):
        # Reference implementation: quadratic part capped at `delta`,
        # linear part for the remainder beyond `delta`.
        error = y_pred - y_true
        abs_error = np.abs(error)

        quadratic = np.minimum(abs_error, delta)
        linear = np.subtract(abs_error, quadratic)
        return np.add(
            np.multiply(0.5, np.multiply(quadratic, quadratic)),
            np.multiply(delta, linear),
        )

    def setup(self, delta=1.0):
        self.np_y_pred = np.array([[0.9, 0.2, 0.2], [0.8, 0.4, 0.6]])
        self.np_y_true = np.array([[1.0, 0.0, 1.0], [1.0, 0.0, 0.0]])

        self.batch_size = 6
        self.expected_losses = self.huber_loss(
            self.np_y_true, self.np_y_pred, delta
        )

        self.y_pred = self.np_y_pred
        self.y_true = self.np_y_true

    def test_config(self):
        h_obj = losses.Huber(reduction="sum", name="huber")
        self.assertEqual(h_obj.name, "huber")
        self.assertEqual(h_obj.reduction, "sum")

    def test_all_correct(self):
        self.setup()
        h_obj = losses.Huber()
        loss = h_obj(self.y_true, self.y_true)
        self.assertAlmostEqual(loss, 0.0, 3)

    def test_unweighted(self):
        self.setup()
        h_obj = losses.Huber()
        loss = h_obj(self.y_true, self.y_pred)
        actual_loss = np.sum(self.expected_losses) / self.batch_size
        self.assertAlmostEqual(loss, actual_loss, 3)

    def test_scalar_weighted(self):
        self.setup()
        h_obj = losses.Huber()
        sample_weight = 2.3
        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
        actual_loss = (
            sample_weight * np.sum(self.expected_losses) / self.batch_size
        )
        self.assertAlmostEqual(loss, actual_loss, 3)

        # Verify we get the same output when the same input is given
        loss_2 = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
        self.assertAlmostEqual(loss, loss_2, 3)

    def test_sample_weighted(self):
        self.setup()
        h_obj = losses.Huber()
        sample_weight = np.array([[1.2], [3.4]])

        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
        actual_loss = np.multiply(
            self.expected_losses,
            np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)),
        )
        actual_loss = np.sum(actual_loss) / self.batch_size
        self.assertAlmostEqual(loss, actual_loss, 3)

    def test_timestep_weighted(self):
        self.setup()
        h_obj = losses.Huber()
        y_pred = self.np_y_pred.reshape((2, 3, 1))
        y_true = self.np_y_true.reshape((2, 3, 1))
        expected_losses = self.huber_loss(y_true, y_pred)

        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3, 1))
        loss = h_obj(
            y_true,
            y_pred,
            sample_weight=sample_weight,
        )
        actual_loss = np.multiply(expected_losses, sample_weight)
        actual_loss = np.sum(actual_loss) / self.batch_size
        self.assertAlmostEqual(loss, actual_loss, 3)

    def test_zero_weighted(self):
        self.setup()
        h_obj = losses.Huber()
        sample_weight = 0
        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
        self.assertAlmostEqual(loss, 0.0, 3)

    def test_non_default_delta(self):
        self.setup(delta=0.8)
        h_obj = losses.Huber(delta=0.8)
        sample_weight = 2.3
        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
        actual_loss = (
            sample_weight * np.sum(self.expected_losses) / self.batch_size
        )
        self.assertAlmostEqual(loss, actual_loss, 3)

    def test_loss_with_non_default_dtype(self):
        # Test case for GitHub issue:
        # https://github.com/tensorflow/tensorflow/issues/39004
        # TODO
        pass


class LogCoshTest(testing.TestCase):
    def setup(self):
        y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32)
        y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32)

        self.batch_size = 6
        error = y_pred - y_true
        self.expected_losses = np.log((np.exp(error) + np.exp(-error)) / 2)

        self.y_true = y_true
        self.y_pred = y_pred

    def test_config(self):
        logcosh_obj = losses.LogCosh(reduction="sum", name="logcosh_loss")
        self.assertEqual(logcosh_obj.name, "logcosh_loss")
        self.assertEqual(logcosh_obj.reduction, "sum")

    def test_unweighted(self):
        self.setup()
        logcosh_obj = losses.LogCosh()

        loss = logcosh_obj(self.y_true, self.y_pred)
        expected_loss = np.sum(self.expected_losses) / self.batch_size
        self.assertAlmostEqual(loss, expected_loss, 3)

    def test_scalar_weighted(self):
        self.setup()
        logcosh_obj = losses.LogCosh()
        sample_weight = 2.3

        loss = logcosh_obj(
            self.y_true, self.y_pred, sample_weight=sample_weight
        )
        expected_loss = (
            sample_weight * np.sum(self.expected_losses) / self.batch_size
        )
        self.assertAlmostEqual(loss, expected_loss, 3)

        # Verify we get the same output when the same input is given
        loss_2 = logcosh_obj(
            self.y_true, self.y_pred, sample_weight=sample_weight
        )
        self.assertAlmostEqual(loss, loss_2, 3)

    def test_sample_weighted(self):
        self.setup()
        logcosh_obj = losses.LogCosh()

        sample_weight = np.asarray([1.2, 3.4])
        loss = logcosh_obj(
            self.y_true, self.y_pred, sample_weight=sample_weight
        )

        expected_loss = np.multiply(
            self.expected_losses,
            np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)),
        )
        expected_loss = np.sum(expected_loss) / self.batch_size
        self.assertAlmostEqual(loss, expected_loss, 3)

    def test_timestep_weighted(self):
        self.setup()
        logcosh_obj = losses.LogCosh()
        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
        y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
        error = y_pred - y_true
        expected_losses = np.log((np.exp(error) + np.exp(-error)) / 2)
        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3, 1))

        loss = logcosh_obj(
            y_true,
            y_pred,
            sample_weight=sample_weight,
        )
        expected_loss = (
            np.sum(expected_losses * sample_weight) / self.batch_size
        )
        self.assertAlmostEqual(loss, expected_loss, 3)

    def test_zero_weighted(self):
        self.setup()
        logcosh_obj = losses.LogCosh()
        sample_weight = 0
        loss = logcosh_obj(
            self.y_true, self.y_pred, sample_weight=sample_weight
        )
        self.assertAlmostEqual(loss, 0.0, 3)


class KLDivergenceTest(testing.TestCase):
    def setup(self):
        self.y_pred = np.asarray(
