Skip to content

Commit

Permalink
update to support latest tf.keras
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikko Kotila committed Apr 14, 2022
1 parent 8475c6c commit f1ae02b
Showing 1 changed file with 19 additions and 24 deletions.
43 changes: 19 additions & 24 deletions kerasplotlib/traininglog.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,36 @@
from __future__ import division
import warnings

import matplotlib
import matplotlib.pyplot as plt
from IPython.display import clear_output

from tensorflow.keras.callbacks import Callback

def loss2name(loss):
if hasattr(loss, '__call__'):
# if passed as a function
return loss.__name__
else:
# if passed as a string
return loss

class TrainingLog(Callback):

def __init__(self,
figsize=None,
cell_size=(6, 4),
dynamic_x_axis=False,
max_cols=2,
y_max=1):
y_max=1,
metrics=None):

self.figsize = figsize
self.cell_size = cell_size
self.dynamic_x_axis = dynamic_x_axis
self.max_cols = max_cols
self.y_max = y_max
self.metrics = metrics

self.logs = []

def on_train_begin(self, logs={}):
self.base_metrics = [metric for metric in self.params['metrics'] if not metric.startswith('val_')]

self.base_metrics = [metric for metric in self.metrics if not metric.startswith('val_')]
if self.figsize is None:
self.figsize = (
self.max_cols * self.cell_size[0],
((len(self.base_metrics) + 1) // self.max_cols + 1) * self.cell_size[1]
)

if isinstance(self.model.loss, list):
losses = self.model.loss
elif isinstance(self.model.loss, dict):
losses = self.model.loss.values()
else:
losses = [self.model.loss]

self.max_epoch = self.params['epochs'] if not self.dynamic_x_axis else None

self.logs = []

def on_epoch_end(self, epoch, logs={}):

Expand All @@ -60,6 +44,7 @@ def on_epoch_end(self, epoch, logs={}):
y_max=self.y_max,
validation_fmt="val_{}")


def draw_plot(logs,
metrics,
figsize=None,
Expand All @@ -69,10 +54,15 @@ def draw_plot(logs,
validation_fmt="val_{}",
metric2title={}):

import matplotlib
import matplotlib.pyplot as plt
from IPython.display import clear_output

plt.figure(figsize=figsize)
clear_output(wait=True)

for metric_id, metric in enumerate(metrics):

plt.subplot((len(metrics) + 1) // max_cols + 1, max_cols, metric_id + 1)

if max_epoch is not None:
Expand All @@ -81,14 +71,19 @@ def draw_plot(logs,
plt.plot(range(1, len(logs) + 1),
[log[metric] for log in logs],
label="training", color='#1B2F33', linestyle='dashed')

plt.ylim(0, y_max)

if validation_fmt.format(metric) in logs[0]:
plt.plot(range(1, len(logs) + 1),
[log[validation_fmt.format(metric)] for log in logs],
label="validation", color='#A72608')

plt.ylim(0, y_max)

plt.title(metric2title.get(metric, metric), pad=15)
plt.xlabel('epoch', color='grey')

plt.legend(loc=1, ncol=1, bbox_to_anchor=(1.35, 1.0))

plt.tight_layout()
Expand Down

0 comments on commit f1ae02b

Please sign in to comment.