From f1ae02b4b7cfb2ded7318d0acb0a556d6ae7f0e1 Mon Sep 17 00:00:00 2001 From: Mikko Kotila Date: Thu, 14 Apr 2022 22:49:36 +0300 Subject: [PATCH] update to support latest tf.keras --- kerasplotlib/traininglog.py | 43 ++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/kerasplotlib/traininglog.py b/kerasplotlib/traininglog.py index 68794d5..40992f2 100644 --- a/kerasplotlib/traininglog.py +++ b/kerasplotlib/traininglog.py @@ -1,52 +1,36 @@ -from __future__ import division -import warnings - -import matplotlib -import matplotlib.pyplot as plt -from IPython.display import clear_output - from tensorflow.keras.callbacks import Callback -def loss2name(loss): - if hasattr(loss, '__call__'): - # if passed as a function - return loss.__name__ - else: - # if passed as a string - return loss class TrainingLog(Callback): + def __init__(self, figsize=None, cell_size=(6, 4), dynamic_x_axis=False, max_cols=2, - y_max=1): + y_max=1, + metrics=None): self.figsize = figsize self.cell_size = cell_size self.dynamic_x_axis = dynamic_x_axis self.max_cols = max_cols self.y_max = y_max + self.metrics = metrics + + self.logs = [] def on_train_begin(self, logs={}): - self.base_metrics = [metric for metric in self.params['metrics'] if not metric.startswith('val_')] + + self.base_metrics = [metric for metric in self.metrics if not metric.startswith('val_')] if self.figsize is None: self.figsize = ( self.max_cols * self.cell_size[0], ((len(self.base_metrics) + 1) // self.max_cols + 1) * self.cell_size[1] ) - if isinstance(self.model.loss, list): - losses = self.model.loss - elif isinstance(self.model.loss, dict): - losses = self.model.loss.values() - else: - losses = [self.model.loss] - self.max_epoch = self.params['epochs'] if not self.dynamic_x_axis else None - self.logs = [] def on_epoch_end(self, epoch, logs={}): @@ -60,6 +44,7 @@ def on_epoch_end(self, epoch, logs={}): y_max=self.y_max, validation_fmt="val_{}") + def draw_plot(logs, metrics, figsize=None, @@ -69,10 +54,15 @@ def draw_plot(logs, validation_fmt="val_{}", metric2title={}): + import matplotlib + import matplotlib.pyplot as plt + from IPython.display import clear_output + plt.figure(figsize=figsize) clear_output(wait=True) for metric_id, metric in enumerate(metrics): + plt.subplot((len(metrics) + 1) // max_cols + 1, max_cols, metric_id + 1) if max_epoch is not None: @@ -81,14 +71,19 @@ def draw_plot(logs, plt.plot(range(1, len(logs) + 1), [log[metric] for log in logs], label="training", color='#1B2F33', linestyle='dashed') + plt.ylim(0, y_max) + if validation_fmt.format(metric) in logs[0]: plt.plot(range(1, len(logs) + 1), [log[validation_fmt.format(metric)] for log in logs], label="validation", color='#A72608') + plt.ylim(0, y_max) + plt.title(metric2title.get(metric, metric), pad=15) plt.xlabel('epoch', color='grey') + plt.legend(loc=1, ncol=1, bbox_to_anchor=(1.35, 1.0)) plt.tight_layout()