From df56f1d961b260724f24e4910f812cafcbcc6f22 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 25 Apr 2021 19:27:37 -0400 Subject: [PATCH 1/6] Initial checkin --- tensorflow-framework/pom.xml | 6 + .../framework/callbacks/CSVLogger.java | 269 ++++++++++ .../framework/callbacks/Callback.java | 262 ++++++++++ .../framework/callbacks/CallbackList.java | 230 ++++++++ .../framework/callbacks/History.java | 91 ++++ .../framework/callbacks/LambdaCallback.java | 233 +++++++++ .../tensorflow/framework/callbacks/Mode.java | 22 + .../framework/callbacks/ProgbarLogger.java | 353 +++++++++++++ .../framework/callbacks/UpdateFreq.java | 23 + .../framework/callbacks/VerboseMode.java | 16 + .../util/PathPlaceholderStringFormat.java | 92 ++++ .../framework/callbacks/util/ProgressBar.java | 489 ++++++++++++++++++ .../framework/callbacks/CSVLoggerTest.java | 113 ++++ .../framework/callbacks/CallbackListTest.java | 52 ++ .../framework/callbacks/HistoryTest.java | 59 +++ .../callbacks/LambdaCallbackTest.java | 63 +++ .../callbacks/ProgbarLoggerTest.java | 176 +++++++ .../util/PathPlaceholderStringFormatTest.java | 48 ++ 18 files changed, 2597 insertions(+) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Mode.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/UpdateFreq.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormat.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CallbackListTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/HistoryTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormatTest.java diff --git a/tensorflow-framework/pom.xml b/tensorflow-framework/pom.xml index af7f47815d5..0ce06282d29 100644 --- a/tensorflow-framework/pom.xml +++ b/tensorflow-framework/pom.xml @@ -43,6 +43,12 @@ org.tensorflow tensorflow-core-api ${project.version} + + + + org.apache.commons + commons-csv + 1.8 org.junit.jupiter diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java new file mode 100644 index 00000000000..f009dfadfc0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java @@ -0,0 +1,269 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVPrinter; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.family.TNumber; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.StringJoiner; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +/** + * Callback that streams epoch results to a CSV file. public + * + *

Supports all values that can be represented as a string + * + * @param the data type for the weights in the model + */ +class CSVLogger extends Callback implements AutoCloseable { + + public static final char DEFAULT_SEPARATOR = ','; + public static final boolean DEFAULT_APPEND = false; + + private final File file; + private final char separator; + private final boolean append; + private List keys; + private boolean appendHeader = true; + + private CSVPrinter writer; + + /** + * Creates a CSVLogger callback using {@link #DEFAULT_SEPARATOR} to separate elements in the csv + * file, and {@link #DEFAULT_APPEND} for the append value. + * + * @param file the csv file + */ + public CSVLogger(File file) { + this(file, DEFAULT_SEPARATOR, DEFAULT_APPEND); + } + + /** + * Creates a CSVLogger callback using {@link #DEFAULT_SEPARATOR} to separate elements in the csv + * file, and {@link #DEFAULT_APPEND} for the append value. + * + * @param filename filename of the csv file + */ + public CSVLogger(String filename) { + this(new File(filename), DEFAULT_SEPARATOR, DEFAULT_APPEND); + } + + /** + * Creates a CSVLogger callback using {@link #DEFAULT_APPEND} for the append value. + * + * @param file the csv file + * @param separator string used to separate elements in the csv file. + */ + public CSVLogger(File file, char separator) { + this(file, separator, false); + } + + /** + * Creates a CSVLogger callback using {@link #DEFAULT_APPEND} for the append value. + * + * @param filename filename of the csv file + * @param separator string used to separate elements in the csv file. + */ + public CSVLogger(String filename, char separator) { + this(new File(filename), separator, false); + } + + /** + * Creates a CSVLogger callback. + * + * @param filename filename of the csv file + * @param separator the character used to separate elements in the csv file. + * @param append if true, append if file exists (useful for continuing training). if false, + * overwrite existing file. + */ + public CSVLogger(String filename, char separator, boolean append) { + this(new File(filename), separator, append); + } + + /** + * Creates a CSVLogger callback. + * + * @param file the csv file + * @param separator the character used to separate elements in the csv file. + * @param append if true, append if file exists (useful for continuing training). if false, + * overwrite existing file. + */ + public CSVLogger(File file, char separator, boolean append) { + this.file = file; + this.separator = separator; + this.append = append; + } + + /** {@inheritDoc} */ + @Override + public void onTrainBegin(Map logs) { + appendHeader = !append || !file.exists(); + } + + // TODO Should we handle Java arrays?? + @SuppressWarnings("unchecked") + private String handleValue(Object val) { + + if (val instanceof String) { + return val.toString(); + } else if (val instanceof NdArray) { // todo + boolean isScalar = ((NdArray) val).rank() == 0; + if (isScalar) { + return ((NdArray) val).getObject().toString(); + } else { + NdArray array = (NdArray) val; + return toString(array); + } + } else if (val instanceof Collection) { + return "[" + + ((Collection) val).stream().map(Object::toString).collect(Collectors.joining(",")) + + "]"; + } else { + return val.toString(); + } + } + + /** + * coverts an NdArray to a printable string + * + * @param ndArray the NdArray + * @return the printable string + */ + private String toString(NdArray ndArray) { + Iterator> iterator = ndArray.scalars().iterator(); + Shape shape = ndArray.shape(); + if (shape.numDimensions() == 0) { + if (!iterator.hasNext()) { + return ""; + } + return valToString(iterator.next().getObject()); + } + return toString(iterator, shape, 0); + } + + private String toString(Iterator> iterator, Shape shape, int dimension) { + if (dimension < shape.numDimensions() - 1) { + StringJoiner joiner = new StringJoiner("", "[", "]"); + for (long i = 0, size = shape.size(dimension); i < size; ++i) { + String element = toString(iterator, shape, dimension + 1); + joiner.add(element); + } + return joiner.toString(); + } else { + StringJoiner joiner = new StringJoiner(", ", "[", "]"); + for (long i = 0, size = shape.size(dimension); i < size; ++i) { + Object element = iterator.next().getObject(); + joiner.add(valToString(element)); + } + return joiner.toString(); + } + } + + /** + * Converts a value to a printable string + * + * @param val the value + * @return tje printable string + */ + private String valToString(Object val) { + if (val instanceof Number) { + Number nVal = (Number) val; + if (nVal instanceof Float || nVal instanceof Double) { + return String.format("%e", nVal.doubleValue()); + } else if (nVal instanceof Byte) { + return String.format("0x%2x", nVal.byteValue()); + } else { + return String.format("%d", nVal.longValue()); + } + } else { + return val.toString(); + } + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public void onEpochEnd(int epoch, Map logs) { + logs = logs == null ? Collections.EMPTY_MAP : logs; + + if (keys == null) { + keys = new ArrayList<>(logs.keySet()); + Collections.sort(this.keys); + } + + if (writer == null) { + try { + List fieldNames = new ArrayList<>(); + fieldNames.add("epoch"); + fieldNames.addAll(this.keys); + CSVFormat csvFormat = + appendHeader + ? CSVFormat.EXCEL + .withHeader(fieldNames.toArray(new String[0])) + .withDelimiter(separator) + : CSVFormat.EXCEL.withDelimiter(separator); + writer = new CSVPrinter(new FileWriter(file, append), csvFormat); + } catch (IOException ex) { + Logger.getLogger(CSVLogger.class.getName()).log(Level.SEVERE, null, ex); + return; + } + } + + /* TODO include when integrated with Model + if (getModel().isStopTraining()) { + final Map flogs = logs; + keys.forEach( + key -> { + if (!flogs.containsKey(key)) { + flogs.put(key, Double.NaN); + } + }); + } + */ + try { + final List values = new ArrayList<>(); + final Map logsFinal = logs; + values.add(String.valueOf(epoch)); + keys.forEach(key -> values.add(handleValue(logsFinal.get(key)))); + writer.printRecord(values); + writer.flush(); + } catch (IOException ex) { + Logger.getLogger(CSVLogger.class.getName()).log(Level.SEVERE, null, ex); + } + } + + /** {@inheritDoc} */ + @Override + public void close() throws IOException { + if (writer != null) { + writer.close(); + writer = null; + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java new file mode 100644 index 00000000000..0ed72e6592c --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java @@ -0,0 +1,262 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import java.util.Collections; +import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Abstract base class used to build new callbacks. + * + *

The logs map that callback methods take as argument will contain keys for quantities relevant + * to the current batch or epoch (see method-specific docstrings). + */ +public abstract class Callback { + protected Map params; + // TODO protected Model model; + + /** Creates a Callback */ + protected Callback() { + this(null); + } + + /** + * Creates a Callback + * + * @param params Training parameters + */ + protected Callback(Map params) { + this.params = params; + } + + /** + * Creates a Callback + * + * @param params Training parameters + * @param model the Model + */ + /* TODO + protected Callback(Map params, Model model) {= + this.params = params; + this.model = model; + } + */ + + /** + * Performs custom processing at the the start of an epoch. This method should only be Performs + * custom processing during TRAIN mode. This method is empty. Extend this class to handle this + * event. + * + * @param epoch index of epoch. + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onEpochBegin(int epoch, Map logs) {} + + /** + * Performs custom processing at the end of an epoch.This method should only be Performs custom + * processing during TRAIN mode. This method is empty. Extend this class to handle this event. + * + * @param epoch index of epoch. + * @param logs metric results for this training epoch, and for the validation epoch if validation + * is performed. Validation result keys are prefixed with `val_`. + */ + @SuppressWarnings("unused") + public void onEpochEnd(int epoch, Map logs) {} + + /** + * Performs custom processing at the beginning of a training batch in `fit` methods. This method + * is empty. Extend this class to handle this event. + * + * @param batch the batch index + * @param logs Has keys `batch` and `size` representing the current batch number and the size of + * the batch. + */ + @SuppressWarnings("unused") + public void onTrainBatchBegin(int batch, Map logs) {} + + /** + * Performs custom processing at the end of a training batch in `fit` methods. This method is + * empty. Extend this class to handle this event. + * + * @param batch index of batch within the current epoch. + * @param logs Metric results for this batch. + */ + @SuppressWarnings("unused") + public void onTrainBatchEnd(int batch, Map logs) {} + + /** + * Performs custom processing at the beginning of training. This method is empty. Extend this + * class to handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onTrainBegin(Map logs) {} + + /** + * Performs custom processing at the end of training. This method is empty. Extend this class to + * handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onTrainEnd(Map logs) {} + + /** + * Performs custom processing at the beginning of a batch in `evaluate` methods. Also Performs + * custom processing at the beginning of a validation batch in the `fit` methods, if validation + * data is provided. This method is empty. Extend this class to handle this event. + * + * @param batch the batch number + * @param logs Has keys `batch` and `size` representing the current batch number and the size of + * the batch. + */ + @SuppressWarnings("unused") + public void onTestBatchBegin(int batch, Map logs) {} + + /** + * Performs custom processing at the end of a batch in `evaluate` methods. Also Performs custom + * processing at the end of a validation batch in the `fit` methods, if validation data is + * provided. + * + *

This method is empty. Extend this class to handle this event. + * + * @param batch the batch number + * @param logs Metric results for this batch. + */ + @SuppressWarnings("unused") + public void onTestBatchEnd(int batch, Map logs) {} + + /** + * Performs custom processing at the beginning of evaluation or validation. This method is empty. + * Extend this class to handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onTestBegin(Map logs) {} + + /** + * Performs custom processing at the end of evaluation or validation. This method is empty. Extend + * this class to handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onTestEnd(Map logs) {} + + /** + * Performs custom processing at the beginning of a batch in `predict` methods. This method is + * empty. Extend this class to handle this event. + * + * @param batch index of batch within the current epoch. + * @param logs Has keys `batch` and `size` representing the current batch number and the size of + * the batch. + */ + @SuppressWarnings("unused") + public void onPredictBatchBegin(int batch, Map logs) {} + + /** + * Performs custom processing at the end of a batch in `predict` methods. This method is empty. + * Extend this class to handle this event. + * + * @param batch index of batch within the current epoch. + * @param logs Metric results for this batch. + */ + @SuppressWarnings("unused") + public void onPredictBatchEnd(int batch, Map logs) {} + + /** + * Performs custom processing at the beginning of prediction. This method is empty. Extend this + * class to handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onPredictBegin(Map logs) {} + + /** + * Performs custom processing at the end of prediction. This method is empty. Extend this class to + * handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onPredictEnd(Map logs) {} + + /** + * Gets a monitor value from the value logs + * + * @param logs the value logs + * @param monitor the monitor to fetch + * @return the monitor value, returns null if the monitor value is not in logs. + */ + @SuppressWarnings("unchecked") + protected Number getMonitorValue(Map logs, String monitor) { + logs = logs == null ? Collections.EMPTY_MAP : logs; + Number monitorValue = logs.get(monitor); + if (monitorValue != null) { + Logger.getLogger(getClass().getName()) + .log( + Level.WARNING, + String.format( + "Early stopping conditioned on metric `%s` which is not available. Available metrics are: %s", + monitor, String.join(",", logs.keySet()))); + } + return monitorValue; + } + + /** + * Gets the params + * + * @return the params + */ + public Map getParams() { + return params; + } + + /** + * Sets the params + * + * @param params the params to set + */ + public void setParams(Map params) { + this.params = params; + } + + /** + * Gets the model + * + * @return the model + */ + /* TODO + public Model getModel() { + return model; + TODO */ + + /** + * Sets the model + * + * @param model the model + */ + /* TODO + public void setModel(Model model) { + this.model = model; + } + TODO */ +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java new file mode 100644 index 00000000000..a1124f1b266 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java @@ -0,0 +1,230 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +// TODO import org.tensorflow.framework.model.Model; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Container for {@link Callback} instances. + * + *

This object wraps a list of {@link Callback} instances, making it possible to call them all at + * once via a single endpoint (e.g. {@code callbackList.onEpochEnd(...)}). + */ +public class CallbackList extends Callback { + private final List callbacks = new ArrayList<>(); + + // TODO private final Model model; + private History history; + + /** Creates a CallbackList */ + public CallbackList() { + super(); + } + + /** + * Creates a CallbackList + * + * @param addHistory Whether a {@link History} callback should be added, if one does not already + * exist in the {@code callbacks} list. + */ + public CallbackList(boolean addHistory) { + addDefaultCallbacks(addHistory); + } + + /** + * Creates a CallbackList + * + * @param callbacks List of {@link Callback} instances. + * @param addHistory Whether a {@link History} callback should be added, if one does not already + * exist * in the {@code callbacks} list. + */ + public CallbackList(List callbacks, boolean addHistory) { + this(addHistory); + this.callbacks.addAll(callbacks); + } + + /* TODO add when integrated with Model + // /** + * Creates a CallbackList + * + * @param model the model these callbacks are used with. + * @param addHistory Whether a {@link History} callback should be added, if one does not already + * exist in the {@code callbacks} list. + // * + + public CallbackList(Model model, boolean addHistory) { + this.model = model; + addDefaultCallbacks(addHistory); + } + TODO */ + + /* TODO add when integrated with Model + // /** + * Creates a CallbackList + * + * @param model the model these callbacks are used with. + * @param callbacks List of {@link Callback} instances. + * @param addHistory Whether a {@link History} callback should be added, if one does not already + * exist in the {@code callbacks} list. + // * + + public CallbackList(Model model, List callbacks, boolean addHistory) { + this(model, addHistory); + this.callbacks.addAll(callbacks); + } + TODO */ + + /** + * Adds Callback's that are always present. + * + * @param addHistory Whether a {@link History} callback should be added, if one does not already + * exist in the {@code callbacks} list. + */ + private void addDefaultCallbacks(boolean addHistory) { + callbacks.forEach( + c -> { + if (c instanceof History) { + history = (History) c; + } + }); + if (history == null && addHistory) { + history = new History(); + addCallback(history); + } + } + + /** + * Adds a callback + * + * @param callback the callback + */ + public void addCallback(Callback callback) { + callbacks.add(callback); + } + + /** + * Adds callbacks + * + * @param callbacks the callbacks + */ + public void addCallbacks(List callbacks) { + this.callbacks.addAll(callbacks); + } + + /** {@inheritDoc } */ + @Override + public void onTrainBegin(Map logs) { + callbacks.forEach(c -> c.onTrainBegin(logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTrainEnd(Map logs) { + callbacks.forEach(c -> c.onTrainEnd(logs)); + } + /** {@inheritDoc } */ + @Override + public void onEpochBegin(int epoch, Map logs) { + callbacks.forEach(c -> c.onEpochBegin(epoch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onEpochEnd(int epoch, Map logs) { + callbacks.forEach(c -> c.onEpochEnd(epoch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTrainBatchBegin(int batch, Map logs) { + callbacks.forEach(c -> c.onTrainBatchBegin(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTrainBatchEnd(int batch, Map logs) { + callbacks.forEach(c -> c.onTrainBatchEnd(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTestBatchBegin(int batch, Map logs) { + callbacks.forEach(c -> c.onTestBatchBegin(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTestBatchEnd(int batch, Map logs) { + callbacks.forEach(c -> c.onTestBatchEnd(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTestBegin(Map logs) { + callbacks.forEach(c -> c.onTestBegin(logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTestEnd(Map logs) { + callbacks.forEach(c -> c.onTestEnd(logs)); + } + + /** {@inheritDoc } */ + @Override + public void onPredictBatchBegin(int batch, Map logs) { + callbacks.forEach(c -> c.onPredictBatchBegin(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onPredictBatchEnd(int batch, Map logs) { + callbacks.forEach(c -> c.onPredictBatchEnd(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onPredictBegin(Map logs) { + callbacks.forEach(c -> c.onPredictBegin(logs)); + } + + /** {@inheritDoc } */ + @Override + public void onPredictEnd(Map logs) { + callbacks.forEach(c -> c.onPredictEnd(logs)); + } + + /** + * Gets the callbacks + * + * @return the callbacks + */ + public List getCallbacks() { + return callbacks; + } + + /** + * Gets the history callback + * + * @return the history callback + */ + public History getHistory() { + return history; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java new file mode 100644 index 00000000000..53ce9373d8f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java @@ -0,0 +1,91 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Callback that records events into a History object. + * + *

This callback is automatically applied to every model. The History object gets returned by the + * fit method of models. + */ +public class History extends Callback { + private final Map> history = new HashMap<>(); + private final List epoch = new ArrayList<>(); + + /** Creates a History Callback */ + public History() { + super(); + } + + /* TODO + * Creates a History Callback + * + * @param params Training parameters + * @param model the Model + * + + public History( Model model) {= + super(null, model); + } + TODO */ + + /** {@inheritDoc} */ + @Override + public void onTrainBegin(Map logs) { + epoch.clear(); + } + + /** {@inheritDoc} */ + @Override + public void onEpochEnd(int epoch, Map logs) { + Map localLogs = logs == null ? Collections.emptyMap() : logs; + this.epoch.add(epoch); + + logs.entrySet() + .forEach( + e -> { + List item = history.get(e.getKey()); + if (item == null) { + item = new ArrayList<>(); + history.put(e.getKey(), item); + } + item.add(e.getValue()); + }); + } + + /** + * Gets the History map for each log value + * + * @return the History map for each log value + */ + public Map> getHistory() { + return history; + } + + /** + * Gets the history of epochs + * + * @return the history of epochs + */ + public List getEpoch() { + return epoch; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java new file mode 100644 index 00000000000..af976eee2ff --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java @@ -0,0 +1,233 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +/** + * Callback for creating simple, custom callbacks on-the-fly. + * + *

Example: + * + *

{@code
+ * LambdaCallbacks batchPrintCallback = new LambdaCallbacks();
+ * batchPrintCallback.setOnTrainBatchBegin((batch, logs)->
+ *     System.out.println("Batch: " + batch + " started");
+ * }
+ * + *

This callback is constructed with anonymous functions that will be called at the appropriate + * time. Note that the callbacks expects positional arguments, as: + * + *

+ */ +public class LambdaCallback extends Callback { + + /** Called at the beginning of every epoch. expect two positional arguments: `epoch`, `logs` */ + private BiConsumer> onEpochBegin; + + /** Called at the end of every epoch. expect two positional arguments: `epoch`, `logs` */ + private BiConsumer> onEpochEnd; + + /** Called at the beginning of every batch. expect two positional arguments: `batch`, `logs` */ + private BiConsumer> onTrainBatchBegin; + + /** called at the end of every batch. expect two positional arguments: `batch`, `logs` */ + private BiConsumer> onTrainBatchEnd; + + /** called at the beginning of model training. expect one positional argument: `logs` */ + private Consumer> onTrainBegin; + + /** called at the end of model training. expect one positional argument: `logs` */ + private Consumer> onTrainEnd; + + /** Creates a LambdaCallbacks callback */ + public LambdaCallback() { + super(); + } + + /** + * Creates a LambdaCallbacks callback + * + * @param params Training parameters + */ + public LambdaCallback(Map params) { + super(params); + } + + /** {@inheritDoc} */ + @Override + public void onEpochBegin(int epoch, Map logs) { + if (this.onEpochBegin != null) { + this.onEpochBegin.accept(epoch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onEpochEnd(int epoch, Map logs) { + if (this.onEpochEnd != null) { + this.onEpochEnd.accept(epoch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTrainBatchBegin(int batch, Map logs) { + if (this.onTrainBatchBegin != null) { + this.onTrainBatchBegin.accept(batch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTrainBatchEnd(int batch, Map logs) { + if (this.onTrainBatchEnd != null) { + this.onTrainBatchEnd.accept(batch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTrainBegin(Map logs) { + if (this.onTrainBegin != null) { + this.onTrainBegin.accept(logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTrainEnd(Map logs) { + if (this.onTrainEnd != null) { + this.onTrainEnd.accept(logs); + } + } + + /** + * Gets the onEpochBegin lambda function + * + * @return the onEpochBegin lambda function + */ + public BiConsumer> getOnEpochBegin() { + return onEpochBegin; + } + + /** + * Sets the onEpochBegin lambda function + * + * @param onEpochBegin lambda function to set + */ + public void setOnEpochBegin(BiConsumer> onEpochBegin) { + this.onEpochBegin = onEpochBegin; + } + + /** + * Gets the onEpochEnd lambda function + * + * @return the onEpochEnd lambda function + */ + public BiConsumer> getOnEpochEnd() { + return onEpochEnd; + } + + /** + * Sets the onEpochEnd lambda function + * + * @param onEpochEnd the lambda function + */ + public void setOnEpochEnd(BiConsumer> onEpochEnd) { + this.onEpochEnd = onEpochEnd; + } + + /** + * Gets the onTrainBatchBegin lambda function + * + * @return the onTrainBatchBegin lambda function + */ + public BiConsumer> getOnTrainBatchBegin() { + return onTrainBatchBegin; + } + + /** + * Sets the onTrainBatchBegin lambda function + * + * @param onTrainBatchBegin the lambda function + */ + public void setOnTrainBatchBegin(BiConsumer> onTrainBatchBegin) { + this.onTrainBatchBegin = onTrainBatchBegin; + } + + /** + * Gets the onTrainBatchEnd lambda function + * + * @return the onTrainBatchEnd lambda function + */ + public BiConsumer> getOnTrainBatchEnd() { + return onTrainBatchEnd; + } + + /** + * Sets the onTrainBatchEnd lambda function + * + * @param onTrainBatchEnd the onTrainBatchEnd lambda function + */ + public void setOnTrainBatchEnd(BiConsumer> onTrainBatchEnd) { + this.onTrainBatchEnd = onTrainBatchEnd; + } + + /** + * Gets the onTrainBegin lambda function + * + * @return the onTrainBegin lambda function + */ + public Consumer> getOnTrainBegin() { + return onTrainBegin; + } + + /** + * Sets the onTrainBegin lambda function + * + * @param onTrainBegin the onTrainBegin lambda function + */ + public void setOnTrainBegin(Consumer> onTrainBegin) { + this.onTrainBegin = onTrainBegin; + } + + /** + * Gets the onTrainEnd lambda function + * + * @return the onTrainEnd lambda function + */ + public Consumer> getOnTrainEnd() { + return onTrainEnd; + } + + /** + * Sets the onTrainEnd lambda function + * + * @param onTrainEnd the onTrainEnd lambda function + */ + public void setOnTrainEnd(Consumer> onTrainEnd) { + this.onTrainEnd = onTrainEnd; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Mode.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Mode.java new file mode 100644 index 00000000000..50035c0731b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Mode.java @@ -0,0 +1,22 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +/** The mode on when a Callback takes action. */ +public enum Mode { + AUTO, + MIN, + MAX +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java new file mode 100644 index 00000000000..112fa8f3cd1 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java @@ -0,0 +1,353 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.tensorflow.framework.callbacks.util.ProgressBar; + +import java.io.PrintWriter; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.IntSupplier; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** Callback that prints metrics to console. */ +public class ProgbarLogger extends Callback { + + private final ProgressBar.CountMode unit; + private Set statefulMetrics; + private int seen = 0; + private ProgressBar progbar = null; + private Integer target = null; + private ProgressBar.VerboseMode verbose = ProgressBar.VerboseMode.VERBOSE; + private int epochs = 1; + private boolean calledInFit = false; + + // TODO wire these up to Model + private final IntSupplier getTrainCounter = null; + private final IntSupplier getTestCounter = null; + private final IntSupplier getPredictCounter = null; + + private PrintWriter writer = null; + + /** Create a ProgbarLogger */ + public ProgbarLogger() { + this(null, null, ProgressBar.CountMode.SAMPLES, (List) null); + } + + /** + * Create a ProgbarLogger + * + * @param mode Whether the progress bar should count SAMPLES seen or STEPS (batches) seen. + */ + public ProgbarLogger(ProgressBar.CountMode mode) { + this(null, null, mode, (List) null); + } + + /** + * Create a ProgbarLogger + * + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger(List statefulMetrics) { + this(null, null, ProgressBar.CountMode.SAMPLES, statefulMetrics); + } + + /** + * Create a ProgbarLogger + * + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger(String... statefulMetrics) { + this(null, null, ProgressBar.CountMode.SAMPLES, Arrays.asList(statefulMetrics)); + } + + /** + * Create a ProgbarLogger + * + * @param mode Whether the progress bar should count SAMPLES seen or STEPS (batches) seen. + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger(ProgressBar.CountMode mode, List statefulMetrics) { + this(null, null, mode, statefulMetrics); + } + + /** + * Create a ProgbarLogger + * + * @param mode Whether the progress bar should count SAMPLES seen or STEPS (batches) seen. + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger(ProgressBar.CountMode mode, String... statefulMetrics) { + this(null, null, mode, Arrays.asList(statefulMetrics)); + } + + /** + * Create a ProgbarLogger + * + * @param params Training parameters + */ + public ProgbarLogger(Map params) { + this(params, null, ProgressBar.CountMode.SAMPLES, (List) null); + } + + /** + * Create a ProgbarLogger + * + * @param params Training parameters + * @param model Reference of the model being trained. + */ + public ProgbarLogger(Map params, Object model) { + this(params, model, ProgressBar.CountMode.SAMPLES, (List) null); + } + + /** + * Create a ProgbarLogger + * + * @param params Training parameters + * @param model Reference of the model being trained. + * @param mode Whether the progress bar should count SAMPLES seen or STEPS (batches) seen. + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger( + Map params, + Object model, + ProgressBar.CountMode mode, + String... statefulMetrics) { + this(params, model, mode, Arrays.asList(statefulMetrics)); + } + + /** + * Create a ProgbarLogger + * + * @param params Training parameters + * @param model Reference of the model being trained. + * @param unit Whether the progress bar should count SAMPLES seen or STEPS (batches) seen. + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger( + Map params, + Object model, + ProgressBar.CountMode unit, + List statefulMetrics) { + // TODO super(params, model); + + this.unit = unit; + this.statefulMetrics = + statefulMetrics != null ? new HashSet<>(statefulMetrics) : new HashSet<>(); + setParams(params); + } + + /** {@inheritDoc} */ + @Override + public final void setParams(Map params) { + if (params == null) { + return; + } + super.setParams(params); + verbose = + ((ProgressBar.VerboseMode) params.getOrDefault("verbose", ProgressBar.VerboseMode.VERBOSE)); + epochs = (Integer) params.getOrDefault("epochs", 1); + target = + unit == ProgressBar.CountMode.STEPS + ? (Integer) params.get("steps") + : (Integer) params.get("samples"); + writer = (PrintWriter) params.get("writer"); + + if (target == null) { + /* TODO wire into Model + getTrainCounter = () -> model.getTrainCounter(); + getTestCounter = () -> model.getTestCounter(); + getPredictCounter = () -> model.getPredictCounter(); + + */ + } + } + + /** {@inheritDoc} */ + @Override + public void onTrainBegin(Map logs) { + // When this logger is called inside fit, validation is silent. + calledInFit = true; + } + + /** {@inheritDoc} */ + @Override + public void onTestBegin(Map logs) { + if (!calledInFit) { + resetProgBar(); + maybeInitProgbar(); + } + } + + /** {@inheritDoc} */ + @Override + public void onPredictBegin(Map logs) { + resetProgBar(); + maybeInitProgbar(); + } + + /** {@inheritDoc} */ + @Override + public void onEpochBegin(int epoch, Map logs) { + resetProgBar(); + maybeInitProgbar(); + + if (verbose != ProgressBar.VerboseMode.SILENT && epochs > 1) { + Logger.getLogger(ProgbarLogger.class.getName()) + .log(Level.INFO, String.format("Epoch %d/%d", (epoch + 1), epochs)); + } + } + + @Override + public void onTrainBatchEnd(int batch, Map logs) { + batchUpdateProgbar(batch, logs); + } + + @Override + public void onTestBatchEnd(int batch, Map logs) { + if (!calledInFit) { + batchUpdateProgbar(batch, logs); + } + } + + @Override + public void onPredictBatchEnd(int batch, Map logs) { + // Don't pass prediction results. + super.onPredictBatchEnd(batch, null); + } + + /** {@inheritDoc} */ + @Override + public void onEpochEnd(int epoch, Map logs) { + finalizeProgbar(logs, getTrainCounter); + } + + /** {@inheritDoc} */ + @Override + public void onTestEnd(Map logs) { + if (!calledInFit) { + finalizeProgbar(logs, getTestCounter); + } + } + + /** {@inheritDoc} */ + @Override + public void onPredictEnd(Map logs) { + finalizeProgbar(logs, getPredictCounter); + } + + /** + * Updates the {@link ProgressBar} + * + * @param batch the batch number + * @param logs loss or metric results + */ + private void batchUpdateProgbar(int batch, Map logs) { + Map llogs = logs == null ? Collections.emptyMap() : logs; + maybeInitProgbar(); + + if (unit == ProgressBar.CountMode.STEPS) { + seen = batch + 1; + } else { + // make shallow copy + llogs = new HashMap<>(llogs); + Number batchSize = llogs.getOrDefault("size", 0); + Number numSteps = llogs.getOrDefault("num_steps", 1); + llogs.remove("batch"); + int addSeen = numSteps.intValue() * batchSize.intValue(); + seen += addSeen; + } + + if (verbose != ProgressBar.VerboseMode.SILENT) { + progbar.update(seen, llogs, false); + } + } + + /** + * Finalizes the Progess Bar + * + * @param logs results to apply + * @param getCounter gets the counter from the model + */ + private void finalizeProgbar(Map logs, IntSupplier getCounter) { + if (progbar != null) { + Integer counter = null; + if (target == null) { + if (getCounter != null) { + int counterValue = getCounter.getAsInt(); + if (unit == ProgressBar.CountMode.SAMPLES) { + Number size = logs.getOrDefault("", 1); + counterValue *= size.intValue(); + } + counter = counterValue; + } + target = counter == null ? seen : counter; + progbar.setTarget(target); + } + progbar.update(seen, logs, true); + } + } + + private void resetProgBar() { + seen = 0; + progbar = null; + } + + private void maybeInitProgbar() { + if (statefulMetrics == null) { + /* TODO - Model + if(model != null) { + statefulMetrics = new ArrayList<>(); + model.getMetrics().forEach(m -> statefulMetrics.add(m.getName)); + }else { + */ + statefulMetrics = new HashSet<>(); + // TODO - Model } + } + if (progbar == null) { + if (writer == null) { + progbar = new ProgressBar(target, verbose, unit, statefulMetrics); + } else { + progbar = + new ProgressBar( + writer, + target, + ProgressBar.DEFAULT_WIDTH, + verbose, + ProgressBar.DEFAULT_INTERVAL, + unit, + statefulMetrics); + } + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/UpdateFreq.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/UpdateFreq.java new file mode 100644 index 00000000000..61713630e9f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/UpdateFreq.java @@ -0,0 +1,23 @@ +/* Copyright 202 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +/** Enum that defines the update frequency */ +public enum UpdateFreq { + /** Update on every epoch */ + EPOCH, + /** Update on every batch */ + BATCH +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java new file mode 100644 index 00000000000..705603b05f1 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java @@ -0,0 +1,16 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormat.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormat.java new file mode 100644 index 00000000000..e2340c4ed73 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormat.java @@ -0,0 +1,92 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks.util; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * These are utilities for generating a file path from an original file path containing named + * formatting options. + */ +public class PathPlaceholderStringFormat { + private static final Pattern PYTHON_MATCH = Pattern.compile("\\{(\\w+):([\\w.]+)}"); + + /** + * Converts a filepath containing named formatting options, which will be filled with the value of + * epoch and keys in logs (passed in onEpochEnd). + * + *

For example: + * + *

if filepath is * weights.{epoch:02d}-{val_loss:.2f}.hdf5, then the + * model checkpoints will be saved with the epoch number and the validation loss in the filename + * (e.g. "weights.561-0.71.hdf5"). + * + * @param filename the filename containing the formatting options + * @param epoch the epoch + * @param logs the logs map that contain the values + * @return the converted file path name + */ + public static String convertFilePath(String filename, int epoch, Map logs) { + List vars = new ArrayList<>(); + String format = getFilePath(filename, vars); + List values = new ArrayList<>(); + vars.forEach( + key -> { + if (key.equals("epoch")) values.add(epoch); + else if (logs.containsKey(key)) values.add(logs.get(key).doubleValue()); + else values.add(0.0); + }); + return String.format(format, values.toArray()); + } + + /** + * Creates a {@link String#format} string for formatting the filepath for including the log values + * identified by the original filepath placeholder names + * + * @param filename the filename with the placeholders. + * @param vars the list is populated with the log names from the placeholder names found in the + * original file path string that will be included in resulting name + * @return the String format for formatting the values identified from the placeholder names. + */ + private static String getFilePath(String filename, List vars) { + Matcher m = PYTHON_MATCH.matcher(filename); + StringBuilder sb = new StringBuilder(); + int beginIndex = 0; + Map indexMap = new HashMap<>(); + int lastIndex = 1; + while (m.find()) { + int start = m.start(0); + int end = m.end(0); + String variable = m.group(1); + vars.add(variable); + String format = m.group(2); + Integer index = indexMap.get(variable); + if (index == null) { + indexMap.put(variable, lastIndex); + index = lastIndex++; + } + sb.append(filename, beginIndex, start); + sb.append('%').append(index).append('$').append(format); + beginIndex = end; + } + sb.append(filename.substring(beginIndex)); + return sb.toString(); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java new file mode 100644 index 00000000000..cf645157c13 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java @@ -0,0 +1,489 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks.util; + +import java.io.Console; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Displays a progress bar. + * + *

Output may be sent to {@link Console} if one is present, otherwise it is sent to {@link + * System#out}. Whether a Java virtual machine has a console is dependent upon the underlying + * platform. + */ +public class ProgressBar { + public static final int DEFAULT_WIDTH = 30; + public static final VerboseMode DEFAULT_VERBOSE = VerboseMode.VERBOSE; + + // NOTE: Console may be null + public static final double DEFAULT_INTERVAL = 50; // msecs + public static final CountMode DEFAULT_COUNT_MODE = CountMode.STEPS; + private static final double MICRO_SECONDS = 1E-3; + private static final long MILLI_SECOND = 1L; + private static final long SECOND = 1000L; + private static final long MINUTE = 60L * SECOND; + private static final long HOUR = 60L * MINUTE; + private final int width; + private final VerboseMode verbose; + private final long interval; + private final Set statefulMetrics; + private final CountMode unit; + private final Map> values = new HashMap<>(); + private final List valuesOrdered = new ArrayList<>(); + private final long start = System.currentTimeMillis(); + // will be null if java system console is not present + private final Console console; + // defaults to stdout if console is null; + private final PrintWriter writer; + private final boolean dynamicDisplay; + private Integer target; + private int totalWidth; + private int seenSoFar; + private long lastUpdate; + private Long timeAfterFirstStep; + /** Create a ProgressBar */ + public ProgressBar() { + this(null, DEFAULT_WIDTH, DEFAULT_VERBOSE, DEFAULT_INTERVAL, DEFAULT_COUNT_MODE, null); + } + /** + * Create a ProgressBar + * + * @param target Total number of steps expected, null if unknown. + */ + public ProgressBar(Integer target) { + this(target, DEFAULT_WIDTH, DEFAULT_VERBOSE, DEFAULT_INTERVAL, DEFAULT_COUNT_MODE, null); + } + + /** + * Create a ProgressBar + * + * @param target Total number of steps expected, null if unknown. + * @param verbose Verbosity mode + * @param unit Display name for step counts, "step" or "sample". + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgressBar( + Integer target, VerboseMode verbose, CountMode unit, Set statefulMetrics) { + this(target, DEFAULT_WIDTH, verbose, DEFAULT_INTERVAL, unit, statefulMetrics); + } + + /** + * Create a ProgressBar + * + * @param target Total number of steps expected, null if unknown. + * @param width Progress bar width on screen. + * @param verbose Verbosity mode, false is silent + * @param interval Minimum visual progress update interval (in milliseconds). + * @param unit Display name for step counts, "step" or "sample". + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgressBar( + Integer target, + int width, + VerboseMode verbose, + double interval, + CountMode unit, + Set statefulMetrics) { + this.target = target; + this.width = width; + this.verbose = verbose; + this.interval = (long) interval; + this.unit = unit; + this.statefulMetrics = statefulMetrics == null ? Collections.emptySet() : statefulMetrics; + switch (verbose) { + case VERBOSE: + console = System.console(); + writer = console != null ? null : new PrintWriter(System.out); + break; + case SEMI_VERBOSE: + writer = new PrintWriter(System.out); + console = null; + break; + default: + writer = null; + console = null; + break; + } + dynamicDisplay = console != null; + } + + /** + * Create a ProgressBar + * + * @param writer the writer to use rather than the defaults of {@link Console} or {@link + * System#out} + * @param target Total number of steps expected, null if unknown. + * @param width Progress bar width on screen. + * @param verbose Verbosity mode, false is silent + * @param interval Minimum visual progress update interval (in milliseconds). + * @param unit Display name for step counts, "step" or "sample". + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgressBar( + PrintWriter writer, + Integer target, + int width, + VerboseMode verbose, + double interval, + CountMode unit, + Set statefulMetrics) { + this.target = target; + this.width = width; + this.verbose = verbose; + this.interval = (long) interval; + this.unit = unit; + this.statefulMetrics = statefulMetrics == null ? Collections.emptySet() : statefulMetrics; + switch (verbose) { + case VERBOSE: + case SEMI_VERBOSE: + console = null; + this.writer = writer; + break; + default: + this.writer = null; + console = null; + break; + } + dynamicDisplay = false; + } + + /** + * Updates the progress bar. + * + * @param seen Index of current step. + * @param logValues map of log values to apply + */ + public void update(int seen, Map logValues) { + update(seen, logValues, null); + } + + /** + * Updates the progress bar. + * + * @param current Index of current step, if null the current step is calculated. + * @param logValues Map of log values + * @param finalize Whether this is the last update for the progress bar. If null, defaults to + * {@code current >= self.target}. + */ + public void update(Integer current, Map logValues, Boolean finalize) { + boolean shouldFinalize; + + int iCurrent = current == null ? 0 : current; + + if (finalize == null) { + if (target == null) { + shouldFinalize = false; + } else { + shouldFinalize = iCurrent >= target; + } + } else { + shouldFinalize = finalize; + } + + Map lValues = logValues == null ? Collections.emptyMap() : logValues; + lValues.forEach( + (key, value) -> { + if (!this.valuesOrdered.contains(key)) { + this.valuesOrdered.add(key); + } + if (!this.statefulMetrics.contains(key)) { + // In the case that progress bar doesn't have a target value in the first + // epoch, both onTrainBatchEnd and onEpochEnd will be called, which will + // cause 'current' and 'seenSoFar' to have the same value. Force + // the minimal value to 1 here, otherwise stateful_metric will be 0s. + int valueBase = Math.max(iCurrent - seenSoFar, 1); + // stores the pair, the value and its base, (10, 10) == 1, 10,100 = .1 + if (!values.containsKey(key)) { + values.put(key, Arrays.asList(value.doubleValue() * valueBase, valueBase)); + } else { + List currentValues = values.get(key); + double v1 = currentValues.get(0).doubleValue() + value.doubleValue() * valueBase; + int b1 = currentValues.get(1).intValue() + valueBase; + values.put(key, Arrays.asList(v1, b1)); + } + } else { + values.put(key, Arrays.asList(value, 1)); + } + }); + this.seenSoFar = iCurrent; + long now = System.currentTimeMillis(); + // convert time to seconds + double timeSinceStart = (double) (now - this.start) / SECOND; + StringBuilder info = new StringBuilder(String.format(" - %.0fs", timeSinceStart)); + + if (this.verbose == VerboseMode.VERBOSE) { + if (now - lastUpdate < this.interval && !shouldFinalize) { + return; + } + + int prevTotalWidth = this.totalWidth; + if (dynamicDisplay) { + // backspace to beginning of line + this.console.printf("%s\r", repeat("\b", prevTotalWidth)); + } else { + writer.print('\n'); + } + StringBuilder bar = new StringBuilder(); + if (target != null) { + int numDigits = target > 0 ? (int) Math.log10(target) + 1 : 1; + String formatStr = String.format("%%%dd/%%d [", numDigits); + bar.append(String.format(formatStr, iCurrent, target)); + double prog = (double) iCurrent / target; + int progWidth = (int) (this.width * prog); + if (progWidth > 0) { + bar.append(repeat("=", progWidth - 1)); + if (iCurrent < target) { + bar.append('>'); + } else { + bar.append('='); + } + bar.append(repeat(".", width - progWidth)); + bar.append(']'); + } + } else { + bar.append(String.format("%7d/Unknown", iCurrent)); + } + this.totalWidth = bar.length(); + print(bar.toString()); + + // NOTE: in millis, Python is in seconds + double timePerUnit = estimateStepDuration(current, now); + if (target == null || shouldFinalize) { + if (timePerUnit >= SECOND || timePerUnit == 0.0) { // seconds + info.append( + String.format(" %.0fs/%s", timePerUnit / SECOND, unit.toString().toLowerCase())); + } else if (timePerUnit >= 1) { // milliseconds + info.append(String.format(" %.0fms/%s", timePerUnit, unit.toString().toLowerCase())); + } else { // microseconds + info.append( + String.format(" %.0fus/%s", timePerUnit * SECOND, unit.toString().toLowerCase())); + } + } else { + double eta = timePerUnit * (target - iCurrent); + String etaFormat; + if (eta > HOUR) { // greater than an hour + + etaFormat = + String.format( + "%d:%02d:%02d", // hh:mm:ss + (int) (eta / HOUR), (int) ((eta % HOUR) / MINUTE), (int) (eta % MINUTE) / SECOND); + } else if (eta > MINUTE) { + etaFormat = + String.format( + "%d:%02d", // mm:ss + (int) (eta / MINUTE), (int) (eta % MINUTE) / SECOND); + } else { + etaFormat = String.format("%ds", (int) (eta / SECOND)); // seconds + } + info.append(" - ETA: ").append(etaFormat); + } + + this.valuesOrdered.forEach( + key -> { + info.append(String.format(" - %s:", key)); + List vals = values.get(key); + double avg = vals.get(0).doubleValue() / Math.max(1.0, vals.get(1).doubleValue()); + if (Math.abs(avg) > 1e-3) { // Normal number + info.append(String.format(" %.4f", avg)); + } else { // Floating point notation. + info.append(String.format(" %.4e", avg)); + } + }); + totalWidth += info.length(); + if (prevTotalWidth > totalWidth) { + info.append(repeat(" ", prevTotalWidth - totalWidth)); + } + if (shouldFinalize) { + info.append('\n'); + } + print(info.toString(), true); + + } else if (verbose == VerboseMode.SEMI_VERBOSE) { + if (shouldFinalize) { + int numDigits = target > 0 ? (int) Math.log10(target) + 1 : 1; + String formatStr = String.format("%%%dd/%%d [", numDigits); + final StringBuilder tmpInfo = + new StringBuilder(String.format(formatStr, iCurrent, target)).append(info); + valuesOrdered.forEach( + k -> { + tmpInfo.append(String.format(" - %s:", k)); + List valEntry = values.get(k); + // TODO average + double avg = + valEntry.get(0).doubleValue() / Math.max(1.0, valEntry.get(1).doubleValue()); + if (avg > 1e-3) { + tmpInfo.append(String.format(" %.4f", avg)); + } else { + tmpInfo.append(String.format(" %.4ef", avg)); + } + }); + print(tmpInfo.toString(), true); + } + } + this.lastUpdate = now; + } + + /** + * Print the string to the output stream without flushing + * + * @param s the string + */ + private void print(String s) { + print(s, false); + } + + /** + * Print the string to the output stream + * + * @param s the string + * @param doFlush whether to flush the output after printing. + */ + private void print(String s, boolean doFlush) { + if (this.console != null) { + this.console.printf(s); + if (doFlush) { + this.console.flush(); + } + } else { + writer.print(s); + if (doFlush) { + writer.flush(); + } + } + } + + /** + * Estimate the duration of a single step, in millis + * + *

Given the step number, current, and the corresponding time, now, this function returns an + * estimate for how long a single step takes. If this is called before one step has been completed + * (i.e.{@code current == 0}) then zero is given as an estimate. The duration estimate ignores the + * duration of the (assumed to be non-representative) first step for estimates when more steps are + * available (i.e. {@code current 1}). + * + * @param current Index of current step + * @param now the current time + * @return the estimate of the duration of a single step. + */ + private double estimateStepDuration(Integer current, long now) { + if (current != null) { + // there are a few special scenarios here: + // 1) somebody is calling the progress bar without ever supplying step 1 + // 2) somebody is calling the progress bar and supplies step one mulitple + // times, e.g. as part of a finalizing call + // in these cases, we just fall back to the simple calculation + double timePerUnit; + if (current == 0) { + timePerUnit = (now - start); + } else if (timeAfterFirstStep != null && current > 2) { + timePerUnit = (double) (now - timeAfterFirstStep) / (double) (current - 1); + } else { + timePerUnit = (double) (now - start) / (double) current; + } + if (current == 1) { + timeAfterFirstStep = now; + } + return timePerUnit; + + } else { + return 0; + } + } + + /** + * Repeats the string s, count times. + * + * @param s the string to repeat + * @param count the number of times to repeat + * @return the repeated string + */ + private String repeat(String s, int count) { + return new String(new char[count]).replace("\0", s); + } + + /** updates the progress bar by one unit */ + public void increment() { + add(1); + } + + /** + * updates the progress bar by one unit + * + * @param logValues map of log values to apply + */ + public void increment(Map logValues) { + add(1, logValues); + } + + /** + * update the progress bar + * + * @param n the number of units to add to the current number + */ + public void add(int n) { + add(n, null); + } + + /** + * update the progress bar + * + * @param n the number of units to add to the current number + * @param logValues map of log values to apply + */ + public void add(int n, Map logValues) { + this.update(this.seenSoFar + n, logValues); + } + + /** @return the target */ + public Integer getTarget() { + return target; + } + + /** @param target the target to set */ + public void setTarget(Integer target) { + this.target = target; + } + + public enum CountMode { + /** the progress bar should count steps () */ + STEPS, + /** the progress bar should count samples */ + SAMPLES + } + + /** Verbosity mode */ + public enum VerboseMode { + /** Do not log output */ + SILENT, + /** verbose, try to use {@link Console}, if available */ + VERBOSE, + /** Semi verbose, Use {@link System#out} */ + SEMI_VERBOSE + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java new file mode 100644 index 00000000000..63988fe4166 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java @@ -0,0 +1,113 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVRecord; +import org.junit.jupiter.api.Test; +import org.tensorflow.types.TFloat64; + +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.io.Reader; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.fail; + +class CSVLoggerTest { + + @Test + public void testStandAlone() { + try { + int epoch = 0; + double[] values = {0.95, 0.90, 0.85, 0.90, 0.99, Double.NaN}; + File tmpFile = File.createTempFile("tf-test", ".csv"); + Map logs = new HashMap<>(); + try (CSVLogger csvLogger = new CSVLogger<>(tmpFile)) { + csvLogger.onTrainBegin(null); + for (int i = 0; i < values.length; i++) { + logs.put("accuracy", values[epoch]); + csvLogger.onEpochEnd(epoch++, logs); + } + + try (Reader reader = new FileReader(tmpFile)) { + Iterable records = CSVFormat.EXCEL.withFirstRecordAsHeader().parse(reader); + int iv = 0; + for (CSVRecord record : records) { + String epochStr = record.get("epoch"); + assertNotNull(epochStr); + String valueStr = record.get("accuracy"); + assertNotNull(valueStr); + assertEquals(iv, Integer.valueOf(epochStr)); + double v = Double.valueOf(valueStr); + assertEquals(values[iv++], v, 0e-6); + } + } finally { + tmpFile.delete(); + } + } + + } catch (IOException ex) { + fail(ex); + } + } + + @Test + public void testStandAlone2Vals() { + try { + int epoch = 0; + double[] valuesAcc = {0.95, 0.90, 0.85, 0.90, 0.99, Double.NaN}; + double[] valuesErr = {1e-1, 1e-2, 1e-3, 1e-4, 1e-5, Double.NaN}; + File tmpFile = File.createTempFile("tf-test", ".csv"); + Map logs = new HashMap<>(); + try (CSVLogger csvLogger = new CSVLogger<>(tmpFile)) { + csvLogger.onTrainBegin(null); + for (int i = 0; i < valuesAcc.length; i++) { + logs.put("accuracy", valuesAcc[epoch]); + logs.put("error", valuesErr[epoch]); + csvLogger.onEpochEnd(epoch++, logs); + } + + try (Reader reader = new FileReader(tmpFile)) { + Iterable records = CSVFormat.EXCEL.withFirstRecordAsHeader().parse(reader); + int iv = 0; + for (CSVRecord record : records) { + String epochStr = record.get("epoch"); + assertNotNull(epochStr); + String valueStr = record.get("accuracy"); + assertNotNull(valueStr); + String errorStr = record.get("error"); + assertNotNull(errorStr); + assertEquals(iv, Integer.valueOf(epochStr)); + double v = Double.valueOf(valueStr); + assertEquals(valuesAcc[iv], v, 0e-6); + double e = Double.valueOf(errorStr); + assertEquals(valuesErr[iv], e, 0e-8); + iv++; + } + } finally { + tmpFile.delete(); + } + } + + } catch (IOException ex) { + fail(ex); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CallbackListTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CallbackListTest.java new file mode 100644 index 00000000000..0e4ef138a26 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CallbackListTest.java @@ -0,0 +1,52 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +class CallbackListTest { + + @Test + public void testUpdates() { + Map logs = new HashMap<>(); + logs.put("acc", 0.98); + LambdaCallback lambdaCB = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + + lambdaCB.setOnEpochBegin( + (epoch, log) -> { + called.set(true); + }); + + CallbackList instance = new CallbackList(true); + History history = instance.getHistory(); + instance.addCallback(lambdaCB); + + instance.onTrainBegin(null); + instance.onEpochBegin(0, logs); + instance.onEpochEnd(0, logs); + instance.onTrainEnd(null); + + assertTrue(history.getHistory().containsKey("acc")); + assert (history.getHistory().get("acc").size() == 1); + assertTrue(called.get()); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/HistoryTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/HistoryTest.java new file mode 100644 index 00000000000..b11e3887f0c --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/HistoryTest.java @@ -0,0 +1,59 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class HistoryTest { + + @Test + void testOnTrainBegin() { + History instance = new History(); + + instance.onTrainBegin(null); + Map logs = new HashMap<>(); + logs.put("acc", 0.99); + logs.put("err", 0.012345); + int totalEpochs = 100; + for (int epoch = 0; epoch < totalEpochs; epoch++) { + instance.onEpochEnd(epoch, logs); + } + assertEquals(totalEpochs, instance.getEpoch().size()); + + Map> results = instance.getHistory(); + assertEquals(2, results.size()); + assertEquals(results.get("acc").size(), totalEpochs); + assertEquals(results.get("err").size(), totalEpochs); + + instance.onTrainBegin(null); + assertEquals(0, instance.getEpoch().size()); + for (int epoch = 0; epoch < totalEpochs; epoch++) { + instance.onEpochEnd(epoch, logs); + } + + assertEquals(totalEpochs, instance.getEpoch().size()); + + results = instance.getHistory(); + assertEquals(2, results.size()); + assertEquals(results.get("acc").size(), totalEpochs * 2); + assertEquals(results.get("err").size(), totalEpochs * 2); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java new file mode 100644 index 00000000000..36e6397d62c --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java @@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class LambdaCallbackTest { + + @Test + void onEpochBegin() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + int expectedEpoch = 101; + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnEpochBegin( + (epoch, log) -> { + assertEquals(expectedEpoch, epoch); + assertEquals(exceptedLog, log); + called.set(true); + }); + + Map epochLog = new HashMap<>(); + epochLog.put("acc", 0.98); + instance.onEpochBegin(101, epochLog); + + assertTrue(called.get()); + } + + @Test + void onEpochEnd() {} + + @Test + void onTrainBatchBegin() {} + + @Test + void onTrainBatchEnd() {} + + @Test + void onTrainBegin() {} + + @Test + void onTrainEnd() {} +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java new file mode 100644 index 00000000000..408aab0fa81 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java @@ -0,0 +1,176 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.junit.jupiter.api.Test; +import org.tensorflow.framework.callbacks.util.ProgressBar; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.fail; + +class ProgbarLoggerTest { + + @Test + void testNoTarget() { + File tmpFile = null; + try { + tmpFile = File.createTempFile("tf-test-progbar", ".txt"); + System.out.println(tmpFile); + try (PrintWriter writer = new PrintWriter(new FileWriter(tmpFile))) { + int numEpochs = 1; + int numSteps = 10; + Map params = new HashMap<>(); + + params.put("verbose", ProgressBar.VerboseMode.VERBOSE); + params.put("size", numSteps); + params.put("num_steps", numSteps); + params.put("writer", writer); + ProgbarLogger instance = + new ProgbarLogger(params, null, ProgressBar.CountMode.STEPS, Arrays.asList("acc")); + + Map logs = new HashMap<>(); + logs.put("acc", 0.95); + instance.onTrainBegin(null); + for (int epoch = 0; epoch < numEpochs; epoch++) { + instance.onEpochBegin(epoch, null); + for (int step = 0; step < numSteps; step++) { + instance.onTrainBatchBegin(step, logs); + try { + Thread.sleep(100); + } catch (InterruptedException ignore) { + } + instance.onTrainBatchEnd(step, logs); + } + // instance.onEpochEnd(epoch, logs); + } + instance.onTrainEnd(null); + } catch (IOException ex) { + fail(ex); + } + List results = readResults(tmpFile); + // 1/Unknown - 0s 105ms/steps - acc: 0.9500 + // 10/Unknown - 1s 104ms/steps - acc: 0.9500 + Pattern p1 = + Pattern.compile(" [1 ][0-9]/Unknown - [0-9]s [1-9][0-9][0-9]ms/steps - acc: 0.9500"); + + results.forEach( + line -> { + if (!line.trim().isEmpty()) { + Matcher m = p1.matcher(line); + if (!m.matches()) { + fail("unexpected output \"" + line + "\""); + } + } + }); + } catch (IOException ex) { + fail(ex); + } finally { + if (tmpFile != null) { + // tmpFile.delete(); + } + } + } + + @Test + void testTarget() { + + File tmpFile = null; + try { + tmpFile = File.createTempFile("tf-test-progbar", ".txt"); + try (PrintWriter writer = new PrintWriter(new FileWriter(tmpFile))) { + int numEpochs = 10; + int numSteps = 10; + Map params = new HashMap<>(); + + params.put("verbose", ProgressBar.VerboseMode.VERBOSE); + params.put("size", numSteps); + params.put("num_steps", numSteps); + params.put("steps", numSteps); + params.put("writer", writer); + ProgbarLogger instance = + new ProgbarLogger(params, null, ProgressBar.CountMode.STEPS, Arrays.asList("acc")); + + Map logs = new HashMap<>(); + logs.put("acc", 0.88); + + instance.onTrainBegin(null); + for (int epoch = 0; epoch < numEpochs; epoch++) { + instance.onEpochBegin(epoch, null); + for (int step = 0; step < numSteps; step++) { + instance.onTrainBatchBegin(step, logs); + try { + Thread.sleep(10); + } catch (InterruptedException ignore) { + } + instance.onTrainBatchEnd(step, logs); + } + + instance.onEpochEnd(epoch, logs); + } + instance.onTrainEnd(null); + } + + List results = readResults(tmpFile); + // 1/10 [==>...........................] - 0s - ETA: 0s - acc: 0.8800 + Pattern p1 = Pattern.compile(" [1-9]/10 \\[==*>\\.*\\] - 0s - ETA: 0s - acc: 0.8800"); + // 10/10 [==============================] - 0s 12ms/steps - acc: 0.8800 + Pattern p2 = Pattern.compile("10/10 \\[==*\\] - 0s [1-9][0-9]*ms/steps - acc: 0.8800"); + String finalLine = "10/10 [==============================] - 0s - ETA: 0s - acc: 0.8800"; + results.forEach( + line -> { + if (!line.trim().isEmpty()) { + Matcher m = p1.matcher(line); + if (!m.matches()) { + m = p2.matcher(line); + if (!m.matches()) { + if (!line.equals(finalLine)) { + fail("unexpected output \"" + line + "\""); + } + } + } + } + }); + + } catch (IOException ex) { + fail(ex); + } finally { + if (tmpFile != null) { + tmpFile.delete(); + } + } + } + + private List readResults(File file) { + try (BufferedReader reader = new BufferedReader(new FileReader(file))) { + return reader.lines().collect(Collectors.toList()); + } catch (IOException ex) { + fail("cannot read tmp file", ex); + } + return null; // should not happen + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormatTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormatTest.java new file mode 100644 index 00000000000..decc93e50f5 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormatTest.java @@ -0,0 +1,48 @@ +package org.tensorflow.framework.callbacks.util; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class PathPlaceholderStringFormatTest { + + @Test + public void testPlaceholder() { + String filePath = "weights.{epoch:02d}-{val_loss:.2f}.hdf5"; + + Map logs = new HashMap<>(); + logs.put("val_loss", 0.71); + + // test with val_loss and 1 digit epoch + String result = PathPlaceholderStringFormat.convertFilePath(filePath, 1, logs); + String expect = "weights.01-0.71.hdf5"; + assertEquals(expect, result); + + // test with val_loss and 2 digit epoch + result = PathPlaceholderStringFormat.convertFilePath(filePath, 12, logs); + expect = "weights.12-0.71.hdf5"; + assertEquals(expect, result); + + // test with val_loss and 2 digit epoch and an added log variable + logs.put("acc", 0.21); + logs.put("val_loss", 0.99); + result = PathPlaceholderStringFormat.convertFilePath(filePath, 12, logs); + expect = "weights.12-0.99.hdf5"; + assertEquals(expect, result); + + // test with empty logs variable + logs.clear(); + result = PathPlaceholderStringFormat.convertFilePath(filePath, 123, logs); + expect = "weights.123-0.00.hdf5"; + assertEquals(expect, result); + + // test with no formatting + filePath = "weights.hdf5"; + result = PathPlaceholderStringFormat.convertFilePath(filePath, 0, logs); + expect = "weights.hdf5"; + assertEquals(expect, result); + } +} From 9dcddcd43c5b17e9181ec575873b7c7cfac057f0 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 25 Apr 2021 19:27:37 -0400 Subject: [PATCH 2/6] Initial checkin --- tensorflow-framework/pom.xml | 6 + .../framework/callbacks/CSVLogger.java | 269 ++++++++++ .../framework/callbacks/Callback.java | 262 ++++++++++ .../framework/callbacks/CallbackList.java | 230 ++++++++ .../framework/callbacks/History.java | 91 ++++ .../framework/callbacks/LambdaCallback.java | 233 +++++++++ .../tensorflow/framework/callbacks/Mode.java | 22 + .../framework/callbacks/ProgbarLogger.java | 353 +++++++++++++ .../framework/callbacks/UpdateFreq.java | 23 + .../framework/callbacks/VerboseMode.java | 16 + .../util/PathPlaceholderStringFormat.java | 92 ++++ .../framework/callbacks/util/ProgressBar.java | 489 ++++++++++++++++++ .../framework/callbacks/CSVLoggerTest.java | 113 ++++ .../framework/callbacks/CallbackListTest.java | 52 ++ .../framework/callbacks/HistoryTest.java | 59 +++ .../callbacks/LambdaCallbackTest.java | 63 +++ .../callbacks/ProgbarLoggerTest.java | 176 +++++++ .../util/PathPlaceholderStringFormatTest.java | 48 ++ 18 files changed, 2597 insertions(+) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Mode.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/UpdateFreq.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormat.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CallbackListTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/HistoryTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormatTest.java diff --git a/tensorflow-framework/pom.xml b/tensorflow-framework/pom.xml index af7f47815d5..0ce06282d29 100644 --- a/tensorflow-framework/pom.xml +++ b/tensorflow-framework/pom.xml @@ -43,6 +43,12 @@ org.tensorflow tensorflow-core-api ${project.version} + + + + org.apache.commons + commons-csv + 1.8 org.junit.jupiter diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java new file mode 100644 index 00000000000..f009dfadfc0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java @@ -0,0 +1,269 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVPrinter; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.family.TNumber; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.StringJoiner; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +/** + * Callback that streams epoch results to a CSV file. public + * + *

Supports all values that can be represented as a string + * + * @param the data type for the weights in the model + */ +class CSVLogger extends Callback implements AutoCloseable { + + public static final char DEFAULT_SEPARATOR = ','; + public static final boolean DEFAULT_APPEND = false; + + private final File file; + private final char separator; + private final boolean append; + private List keys; + private boolean appendHeader = true; + + private CSVPrinter writer; + + /** + * Creates a CSVLogger callback using {@link #DEFAULT_SEPARATOR} to separate elements in the csv + * file, and {@link #DEFAULT_APPEND} for the append value. + * + * @param file the csv file + */ + public CSVLogger(File file) { + this(file, DEFAULT_SEPARATOR, DEFAULT_APPEND); + } + + /** + * Creates a CSVLogger callback using {@link #DEFAULT_SEPARATOR} to separate elements in the csv + * file, and {@link #DEFAULT_APPEND} for the append value. + * + * @param filename filename of the csv file + */ + public CSVLogger(String filename) { + this(new File(filename), DEFAULT_SEPARATOR, DEFAULT_APPEND); + } + + /** + * Creates a CSVLogger callback using {@link #DEFAULT_APPEND} for the append value. + * + * @param file the csv file + * @param separator string used to separate elements in the csv file. + */ + public CSVLogger(File file, char separator) { + this(file, separator, false); + } + + /** + * Creates a CSVLogger callback using {@link #DEFAULT_APPEND} for the append value. + * + * @param filename filename of the csv file + * @param separator string used to separate elements in the csv file. + */ + public CSVLogger(String filename, char separator) { + this(new File(filename), separator, false); + } + + /** + * Creates a CSVLogger callback. + * + * @param filename filename of the csv file + * @param separator the character used to separate elements in the csv file. + * @param append if true, append if file exists (useful for continuing training). if false, + * overwrite existing file. + */ + public CSVLogger(String filename, char separator, boolean append) { + this(new File(filename), separator, append); + } + + /** + * Creates a CSVLogger callback. + * + * @param file the csv file + * @param separator the character used to separate elements in the csv file. + * @param append if true, append if file exists (useful for continuing training). if false, + * overwrite existing file. + */ + public CSVLogger(File file, char separator, boolean append) { + this.file = file; + this.separator = separator; + this.append = append; + } + + /** {@inheritDoc} */ + @Override + public void onTrainBegin(Map logs) { + appendHeader = !append || !file.exists(); + } + + // TODO Should we handle Java arrays?? + @SuppressWarnings("unchecked") + private String handleValue(Object val) { + + if (val instanceof String) { + return val.toString(); + } else if (val instanceof NdArray) { // todo + boolean isScalar = ((NdArray) val).rank() == 0; + if (isScalar) { + return ((NdArray) val).getObject().toString(); + } else { + NdArray array = (NdArray) val; + return toString(array); + } + } else if (val instanceof Collection) { + return "[" + + ((Collection) val).stream().map(Object::toString).collect(Collectors.joining(",")) + + "]"; + } else { + return val.toString(); + } + } + + /** + * coverts an NdArray to a printable string + * + * @param ndArray the NdArray + * @return the printable string + */ + private String toString(NdArray ndArray) { + Iterator> iterator = ndArray.scalars().iterator(); + Shape shape = ndArray.shape(); + if (shape.numDimensions() == 0) { + if (!iterator.hasNext()) { + return ""; + } + return valToString(iterator.next().getObject()); + } + return toString(iterator, shape, 0); + } + + private String toString(Iterator> iterator, Shape shape, int dimension) { + if (dimension < shape.numDimensions() - 1) { + StringJoiner joiner = new StringJoiner("", "[", "]"); + for (long i = 0, size = shape.size(dimension); i < size; ++i) { + String element = toString(iterator, shape, dimension + 1); + joiner.add(element); + } + return joiner.toString(); + } else { + StringJoiner joiner = new StringJoiner(", ", "[", "]"); + for (long i = 0, size = shape.size(dimension); i < size; ++i) { + Object element = iterator.next().getObject(); + joiner.add(valToString(element)); + } + return joiner.toString(); + } + } + + /** + * Converts a value to a printable string + * + * @param val the value + * @return tje printable string + */ + private String valToString(Object val) { + if (val instanceof Number) { + Number nVal = (Number) val; + if (nVal instanceof Float || nVal instanceof Double) { + return String.format("%e", nVal.doubleValue()); + } else if (nVal instanceof Byte) { + return String.format("0x%2x", nVal.byteValue()); + } else { + return String.format("%d", nVal.longValue()); + } + } else { + return val.toString(); + } + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public void onEpochEnd(int epoch, Map logs) { + logs = logs == null ? Collections.EMPTY_MAP : logs; + + if (keys == null) { + keys = new ArrayList<>(logs.keySet()); + Collections.sort(this.keys); + } + + if (writer == null) { + try { + List fieldNames = new ArrayList<>(); + fieldNames.add("epoch"); + fieldNames.addAll(this.keys); + CSVFormat csvFormat = + appendHeader + ? CSVFormat.EXCEL + .withHeader(fieldNames.toArray(new String[0])) + .withDelimiter(separator) + : CSVFormat.EXCEL.withDelimiter(separator); + writer = new CSVPrinter(new FileWriter(file, append), csvFormat); + } catch (IOException ex) { + Logger.getLogger(CSVLogger.class.getName()).log(Level.SEVERE, null, ex); + return; + } + } + + /* TODO include when integrated with Model + if (getModel().isStopTraining()) { + final Map flogs = logs; + keys.forEach( + key -> { + if (!flogs.containsKey(key)) { + flogs.put(key, Double.NaN); + } + }); + } + */ + try { + final List values = new ArrayList<>(); + final Map logsFinal = logs; + values.add(String.valueOf(epoch)); + keys.forEach(key -> values.add(handleValue(logsFinal.get(key)))); + writer.printRecord(values); + writer.flush(); + } catch (IOException ex) { + Logger.getLogger(CSVLogger.class.getName()).log(Level.SEVERE, null, ex); + } + } + + /** {@inheritDoc} */ + @Override + public void close() throws IOException { + if (writer != null) { + writer.close(); + writer = null; + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java new file mode 100644 index 00000000000..0ed72e6592c --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java @@ -0,0 +1,262 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import java.util.Collections; +import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Abstract base class used to build new callbacks. + * + *

The logs map that callback methods take as argument will contain keys for quantities relevant + * to the current batch or epoch (see method-specific docstrings). + */ +public abstract class Callback { + protected Map params; + // TODO protected Model model; + + /** Creates a Callback */ + protected Callback() { + this(null); + } + + /** + * Creates a Callback + * + * @param params Training parameters + */ + protected Callback(Map params) { + this.params = params; + } + + /** + * Creates a Callback + * + * @param params Training parameters + * @param model the Model + */ + /* TODO + protected Callback(Map params, Model model) {= + this.params = params; + this.model = model; + } + */ + + /** + * Performs custom processing at the the start of an epoch. This method should only be Performs + * custom processing during TRAIN mode. This method is empty. Extend this class to handle this + * event. + * + * @param epoch index of epoch. + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onEpochBegin(int epoch, Map logs) {} + + /** + * Performs custom processing at the end of an epoch.This method should only be Performs custom + * processing during TRAIN mode. This method is empty. Extend this class to handle this event. + * + * @param epoch index of epoch. + * @param logs metric results for this training epoch, and for the validation epoch if validation + * is performed. Validation result keys are prefixed with `val_`. + */ + @SuppressWarnings("unused") + public void onEpochEnd(int epoch, Map logs) {} + + /** + * Performs custom processing at the beginning of a training batch in `fit` methods. This method + * is empty. Extend this class to handle this event. + * + * @param batch the batch index + * @param logs Has keys `batch` and `size` representing the current batch number and the size of + * the batch. + */ + @SuppressWarnings("unused") + public void onTrainBatchBegin(int batch, Map logs) {} + + /** + * Performs custom processing at the end of a training batch in `fit` methods. This method is + * empty. Extend this class to handle this event. + * + * @param batch index of batch within the current epoch. + * @param logs Metric results for this batch. + */ + @SuppressWarnings("unused") + public void onTrainBatchEnd(int batch, Map logs) {} + + /** + * Performs custom processing at the beginning of training. This method is empty. Extend this + * class to handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onTrainBegin(Map logs) {} + + /** + * Performs custom processing at the end of training. This method is empty. Extend this class to + * handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onTrainEnd(Map logs) {} + + /** + * Performs custom processing at the beginning of a batch in `evaluate` methods. Also Performs + * custom processing at the beginning of a validation batch in the `fit` methods, if validation + * data is provided. This method is empty. Extend this class to handle this event. + * + * @param batch the batch number + * @param logs Has keys `batch` and `size` representing the current batch number and the size of + * the batch. + */ + @SuppressWarnings("unused") + public void onTestBatchBegin(int batch, Map logs) {} + + /** + * Performs custom processing at the end of a batch in `evaluate` methods. Also Performs custom + * processing at the end of a validation batch in the `fit` methods, if validation data is + * provided. + * + *

This method is empty. Extend this class to handle this event. + * + * @param batch the batch number + * @param logs Metric results for this batch. + */ + @SuppressWarnings("unused") + public void onTestBatchEnd(int batch, Map logs) {} + + /** + * Performs custom processing at the beginning of evaluation or validation. This method is empty. + * Extend this class to handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onTestBegin(Map logs) {} + + /** + * Performs custom processing at the end of evaluation or validation. This method is empty. Extend + * this class to handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onTestEnd(Map logs) {} + + /** + * Performs custom processing at the beginning of a batch in `predict` methods. This method is + * empty. Extend this class to handle this event. + * + * @param batch index of batch within the current epoch. + * @param logs Has keys `batch` and `size` representing the current batch number and the size of + * the batch. + */ + @SuppressWarnings("unused") + public void onPredictBatchBegin(int batch, Map logs) {} + + /** + * Performs custom processing at the end of a batch in `predict` methods. This method is empty. + * Extend this class to handle this event. + * + * @param batch index of batch within the current epoch. + * @param logs Metric results for this batch. + */ + @SuppressWarnings("unused") + public void onPredictBatchEnd(int batch, Map logs) {} + + /** + * Performs custom processing at the beginning of prediction. This method is empty. Extend this + * class to handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onPredictBegin(Map logs) {} + + /** + * Performs custom processing at the end of prediction. This method is empty. Extend this class to + * handle this event. + * + * @param logs metric results + */ + @SuppressWarnings("unused") + public void onPredictEnd(Map logs) {} + + /** + * Gets a monitor value from the value logs + * + * @param logs the value logs + * @param monitor the monitor to fetch + * @return the monitor value, returns null if the monitor value is not in logs. + */ + @SuppressWarnings("unchecked") + protected Number getMonitorValue(Map logs, String monitor) { + logs = logs == null ? Collections.EMPTY_MAP : logs; + Number monitorValue = logs.get(monitor); + if (monitorValue != null) { + Logger.getLogger(getClass().getName()) + .log( + Level.WARNING, + String.format( + "Early stopping conditioned on metric `%s` which is not available. Available metrics are: %s", + monitor, String.join(",", logs.keySet()))); + } + return monitorValue; + } + + /** + * Gets the params + * + * @return the params + */ + public Map getParams() { + return params; + } + + /** + * Sets the params + * + * @param params the params to set + */ + public void setParams(Map params) { + this.params = params; + } + + /** + * Gets the model + * + * @return the model + */ + /* TODO + public Model getModel() { + return model; + TODO */ + + /** + * Sets the model + * + * @param model the model + */ + /* TODO + public void setModel(Model model) { + this.model = model; + } + TODO */ +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java new file mode 100644 index 00000000000..a1124f1b266 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java @@ -0,0 +1,230 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +// TODO import org.tensorflow.framework.model.Model; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Container for {@link Callback} instances. + * + *

This object wraps a list of {@link Callback} instances, making it possible to call them all at + * once via a single endpoint (e.g. {@code callbackList.onEpochEnd(...)}). + */ +public class CallbackList extends Callback { + private final List callbacks = new ArrayList<>(); + + // TODO private final Model model; + private History history; + + /** Creates a CallbackList */ + public CallbackList() { + super(); + } + + /** + * Creates a CallbackList + * + * @param addHistory Whether a {@link History} callback should be added, if one does not already + * exist in the {@code callbacks} list. + */ + public CallbackList(boolean addHistory) { + addDefaultCallbacks(addHistory); + } + + /** + * Creates a CallbackList + * + * @param callbacks List of {@link Callback} instances. + * @param addHistory Whether a {@link History} callback should be added, if one does not already + * exist * in the {@code callbacks} list. + */ + public CallbackList(List callbacks, boolean addHistory) { + this(addHistory); + this.callbacks.addAll(callbacks); + } + + /* TODO add when integrated with Model + // /** + * Creates a CallbackList + * + * @param model the model these callbacks are used with. + * @param addHistory Whether a {@link History} callback should be added, if one does not already + * exist in the {@code callbacks} list. + // * + + public CallbackList(Model model, boolean addHistory) { + this.model = model; + addDefaultCallbacks(addHistory); + } + TODO */ + + /* TODO add when integrated with Model + // /** + * Creates a CallbackList + * + * @param model the model these callbacks are used with. + * @param callbacks List of {@link Callback} instances. + * @param addHistory Whether a {@link History} callback should be added, if one does not already + * exist in the {@code callbacks} list. + // * + + public CallbackList(Model model, List callbacks, boolean addHistory) { + this(model, addHistory); + this.callbacks.addAll(callbacks); + } + TODO */ + + /** + * Adds Callback's that are always present. + * + * @param addHistory Whether a {@link History} callback should be added, if one does not already + * exist in the {@code callbacks} list. + */ + private void addDefaultCallbacks(boolean addHistory) { + callbacks.forEach( + c -> { + if (c instanceof History) { + history = (History) c; + } + }); + if (history == null && addHistory) { + history = new History(); + addCallback(history); + } + } + + /** + * Adds a callback + * + * @param callback the callback + */ + public void addCallback(Callback callback) { + callbacks.add(callback); + } + + /** + * Adds callbacks + * + * @param callbacks the callbacks + */ + public void addCallbacks(List callbacks) { + this.callbacks.addAll(callbacks); + } + + /** {@inheritDoc } */ + @Override + public void onTrainBegin(Map logs) { + callbacks.forEach(c -> c.onTrainBegin(logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTrainEnd(Map logs) { + callbacks.forEach(c -> c.onTrainEnd(logs)); + } + /** {@inheritDoc } */ + @Override + public void onEpochBegin(int epoch, Map logs) { + callbacks.forEach(c -> c.onEpochBegin(epoch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onEpochEnd(int epoch, Map logs) { + callbacks.forEach(c -> c.onEpochEnd(epoch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTrainBatchBegin(int batch, Map logs) { + callbacks.forEach(c -> c.onTrainBatchBegin(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTrainBatchEnd(int batch, Map logs) { + callbacks.forEach(c -> c.onTrainBatchEnd(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTestBatchBegin(int batch, Map logs) { + callbacks.forEach(c -> c.onTestBatchBegin(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTestBatchEnd(int batch, Map logs) { + callbacks.forEach(c -> c.onTestBatchEnd(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTestBegin(Map logs) { + callbacks.forEach(c -> c.onTestBegin(logs)); + } + + /** {@inheritDoc } */ + @Override + public void onTestEnd(Map logs) { + callbacks.forEach(c -> c.onTestEnd(logs)); + } + + /** {@inheritDoc } */ + @Override + public void onPredictBatchBegin(int batch, Map logs) { + callbacks.forEach(c -> c.onPredictBatchBegin(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onPredictBatchEnd(int batch, Map logs) { + callbacks.forEach(c -> c.onPredictBatchEnd(batch, logs)); + } + + /** {@inheritDoc } */ + @Override + public void onPredictBegin(Map logs) { + callbacks.forEach(c -> c.onPredictBegin(logs)); + } + + /** {@inheritDoc } */ + @Override + public void onPredictEnd(Map logs) { + callbacks.forEach(c -> c.onPredictEnd(logs)); + } + + /** + * Gets the callbacks + * + * @return the callbacks + */ + public List getCallbacks() { + return callbacks; + } + + /** + * Gets the history callback + * + * @return the history callback + */ + public History getHistory() { + return history; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java new file mode 100644 index 00000000000..53ce9373d8f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java @@ -0,0 +1,91 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Callback that records events into a History object. + * + *

This callback is automatically applied to every model. The History object gets returned by the + * fit method of models. + */ +public class History extends Callback { + private final Map> history = new HashMap<>(); + private final List epoch = new ArrayList<>(); + + /** Creates a History Callback */ + public History() { + super(); + } + + /* TODO + * Creates a History Callback + * + * @param params Training parameters + * @param model the Model + * + + public History( Model model) {= + super(null, model); + } + TODO */ + + /** {@inheritDoc} */ + @Override + public void onTrainBegin(Map logs) { + epoch.clear(); + } + + /** {@inheritDoc} */ + @Override + public void onEpochEnd(int epoch, Map logs) { + Map localLogs = logs == null ? Collections.emptyMap() : logs; + this.epoch.add(epoch); + + logs.entrySet() + .forEach( + e -> { + List item = history.get(e.getKey()); + if (item == null) { + item = new ArrayList<>(); + history.put(e.getKey(), item); + } + item.add(e.getValue()); + }); + } + + /** + * Gets the History map for each log value + * + * @return the History map for each log value + */ + public Map> getHistory() { + return history; + } + + /** + * Gets the history of epochs + * + * @return the history of epochs + */ + public List getEpoch() { + return epoch; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java new file mode 100644 index 00000000000..af976eee2ff --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java @@ -0,0 +1,233 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +/** + * Callback for creating simple, custom callbacks on-the-fly. + * + *

Example: + * + *

{@code
+ * LambdaCallbacks batchPrintCallback = new LambdaCallbacks();
+ * batchPrintCallback.setOnTrainBatchBegin((batch, logs)->
+ *     System.out.println("Batch: " + batch + " started");
+ * }
+ * + *

This callback is constructed with anonymous functions that will be called at the appropriate + * time. Note that the callbacks expects positional arguments, as: + * + *

    + *
  • onEpochBegin and onEpochEnd expect two positional arguments: + * epoch, logs + *
  • onBatchBegin and onBatchEnd expect two positional arguments: + * batch, logs + *
  • onTrainBegin and onTrainEnd expect one positional argument: + * logs + *
+ */ +public class LambdaCallback extends Callback { + + /** Called at the beginning of every epoch. expect two positional arguments: `epoch`, `logs` */ + private BiConsumer> onEpochBegin; + + /** Called at the end of every epoch. expect two positional arguments: `epoch`, `logs` */ + private BiConsumer> onEpochEnd; + + /** Called at the beginning of every batch. expect two positional arguments: `batch`, `logs` */ + private BiConsumer> onTrainBatchBegin; + + /** called at the end of every batch. expect two positional arguments: `batch`, `logs` */ + private BiConsumer> onTrainBatchEnd; + + /** called at the beginning of model training. expect one positional argument: `logs` */ + private Consumer> onTrainBegin; + + /** called at the end of model training. expect one positional argument: `logs` */ + private Consumer> onTrainEnd; + + /** Creates a LambdaCallbacks callback */ + public LambdaCallback() { + super(); + } + + /** + * Creates a LambdaCallbacks callback + * + * @param params Training parameters + */ + public LambdaCallback(Map params) { + super(params); + } + + /** {@inheritDoc} */ + @Override + public void onEpochBegin(int epoch, Map logs) { + if (this.onEpochBegin != null) { + this.onEpochBegin.accept(epoch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onEpochEnd(int epoch, Map logs) { + if (this.onEpochEnd != null) { + this.onEpochEnd.accept(epoch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTrainBatchBegin(int batch, Map logs) { + if (this.onTrainBatchBegin != null) { + this.onTrainBatchBegin.accept(batch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTrainBatchEnd(int batch, Map logs) { + if (this.onTrainBatchEnd != null) { + this.onTrainBatchEnd.accept(batch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTrainBegin(Map logs) { + if (this.onTrainBegin != null) { + this.onTrainBegin.accept(logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTrainEnd(Map logs) { + if (this.onTrainEnd != null) { + this.onTrainEnd.accept(logs); + } + } + + /** + * Gets the onEpochBegin lambda function + * + * @return the onEpochBegin lambda function + */ + public BiConsumer> getOnEpochBegin() { + return onEpochBegin; + } + + /** + * Sets the onEpochBegin lambda function + * + * @param onEpochBegin lambda function to set + */ + public void setOnEpochBegin(BiConsumer> onEpochBegin) { + this.onEpochBegin = onEpochBegin; + } + + /** + * Gets the onEpochEnd lambda function + * + * @return the onEpochEnd lambda function + */ + public BiConsumer> getOnEpochEnd() { + return onEpochEnd; + } + + /** + * Sets the onEpochEnd lambda function + * + * @param onEpochEnd the lambda function + */ + public void setOnEpochEnd(BiConsumer> onEpochEnd) { + this.onEpochEnd = onEpochEnd; + } + + /** + * Gets the onTrainBatchBegin lambda function + * + * @return the onTrainBatchBegin lambda function + */ + public BiConsumer> getOnTrainBatchBegin() { + return onTrainBatchBegin; + } + + /** + * Sets the onTrainBatchBegin lambda function + * + * @param onTrainBatchBegin the lambda function + */ + public void setOnTrainBatchBegin(BiConsumer> onTrainBatchBegin) { + this.onTrainBatchBegin = onTrainBatchBegin; + } + + /** + * Gets the onTrainBatchEnd lambda function + * + * @return the onTrainBatchEnd lambda function + */ + public BiConsumer> getOnTrainBatchEnd() { + return onTrainBatchEnd; + } + + /** + * Sets the onTrainBatchEnd lambda function + * + * @param onTrainBatchEnd the onTrainBatchEnd lambda function + */ + public void setOnTrainBatchEnd(BiConsumer> onTrainBatchEnd) { + this.onTrainBatchEnd = onTrainBatchEnd; + } + + /** + * Gets the onTrainBegin lambda function + * + * @return the onTrainBegin lambda function + */ + public Consumer> getOnTrainBegin() { + return onTrainBegin; + } + + /** + * Sets the onTrainBegin lambda function + * + * @param onTrainBegin the onTrainBegin lambda function + */ + public void setOnTrainBegin(Consumer> onTrainBegin) { + this.onTrainBegin = onTrainBegin; + } + + /** + * Gets the onTrainEnd lambda function + * + * @return the onTrainEnd lambda function + */ + public Consumer> getOnTrainEnd() { + return onTrainEnd; + } + + /** + * Sets the onTrainEnd lambda function + * + * @param onTrainEnd the onTrainEnd lambda function + */ + public void setOnTrainEnd(Consumer> onTrainEnd) { + this.onTrainEnd = onTrainEnd; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Mode.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Mode.java new file mode 100644 index 00000000000..50035c0731b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Mode.java @@ -0,0 +1,22 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +/** The mode on when a Callback takes action. */ +public enum Mode { + AUTO, + MIN, + MAX +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java new file mode 100644 index 00000000000..112fa8f3cd1 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java @@ -0,0 +1,353 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.tensorflow.framework.callbacks.util.ProgressBar; + +import java.io.PrintWriter; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.IntSupplier; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** Callback that prints metrics to console. */ +public class ProgbarLogger extends Callback { + + private final ProgressBar.CountMode unit; + private Set statefulMetrics; + private int seen = 0; + private ProgressBar progbar = null; + private Integer target = null; + private ProgressBar.VerboseMode verbose = ProgressBar.VerboseMode.VERBOSE; + private int epochs = 1; + private boolean calledInFit = false; + + // TODO wire these up to Model + private final IntSupplier getTrainCounter = null; + private final IntSupplier getTestCounter = null; + private final IntSupplier getPredictCounter = null; + + private PrintWriter writer = null; + + /** Create a ProgbarLogger */ + public ProgbarLogger() { + this(null, null, ProgressBar.CountMode.SAMPLES, (List) null); + } + + /** + * Create a ProgbarLogger + * + * @param mode Whether the progress bar should count SAMPLES seen or STEPS (batches) seen. + */ + public ProgbarLogger(ProgressBar.CountMode mode) { + this(null, null, mode, (List) null); + } + + /** + * Create a ProgbarLogger + * + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger(List statefulMetrics) { + this(null, null, ProgressBar.CountMode.SAMPLES, statefulMetrics); + } + + /** + * Create a ProgbarLogger + * + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger(String... statefulMetrics) { + this(null, null, ProgressBar.CountMode.SAMPLES, Arrays.asList(statefulMetrics)); + } + + /** + * Create a ProgbarLogger + * + * @param mode Whether the progress bar should count SAMPLES seen or STEPS (batches) seen. + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger(ProgressBar.CountMode mode, List statefulMetrics) { + this(null, null, mode, statefulMetrics); + } + + /** + * Create a ProgbarLogger + * + * @param mode Whether the progress bar should count SAMPLES seen or STEPS (batches) seen. + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger(ProgressBar.CountMode mode, String... statefulMetrics) { + this(null, null, mode, Arrays.asList(statefulMetrics)); + } + + /** + * Create a ProgbarLogger + * + * @param params Training parameters + */ + public ProgbarLogger(Map params) { + this(params, null, ProgressBar.CountMode.SAMPLES, (List) null); + } + + /** + * Create a ProgbarLogger + * + * @param params Training parameters + * @param model Reference of the model being trained. + */ + public ProgbarLogger(Map params, Object model) { + this(params, model, ProgressBar.CountMode.SAMPLES, (List) null); + } + + /** + * Create a ProgbarLogger + * + * @param params Training parameters + * @param model Reference of the model being trained. + * @param mode Whether the progress bar should count SAMPLES seen or STEPS (batches) seen. + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger( + Map params, + Object model, + ProgressBar.CountMode mode, + String... statefulMetrics) { + this(params, model, mode, Arrays.asList(statefulMetrics)); + } + + /** + * Create a ProgbarLogger + * + * @param params Training parameters + * @param model Reference of the model being trained. + * @param unit Whether the progress bar should count SAMPLES seen or STEPS (batches) seen. + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgbarLogger( + Map params, + Object model, + ProgressBar.CountMode unit, + List statefulMetrics) { + // TODO super(params, model); + + this.unit = unit; + this.statefulMetrics = + statefulMetrics != null ? new HashSet<>(statefulMetrics) : new HashSet<>(); + setParams(params); + } + + /** {@inheritDoc} */ + @Override + public final void setParams(Map params) { + if (params == null) { + return; + } + super.setParams(params); + verbose = + ((ProgressBar.VerboseMode) params.getOrDefault("verbose", ProgressBar.VerboseMode.VERBOSE)); + epochs = (Integer) params.getOrDefault("epochs", 1); + target = + unit == ProgressBar.CountMode.STEPS + ? (Integer) params.get("steps") + : (Integer) params.get("samples"); + writer = (PrintWriter) params.get("writer"); + + if (target == null) { + /* TODO wire into Model + getTrainCounter = () -> model.getTrainCounter(); + getTestCounter = () -> model.getTestCounter(); + getPredictCounter = () -> model.getPredictCounter(); + + */ + } + } + + /** {@inheritDoc} */ + @Override + public void onTrainBegin(Map logs) { + // When this logger is called inside fit, validation is silent. + calledInFit = true; + } + + /** {@inheritDoc} */ + @Override + public void onTestBegin(Map logs) { + if (!calledInFit) { + resetProgBar(); + maybeInitProgbar(); + } + } + + /** {@inheritDoc} */ + @Override + public void onPredictBegin(Map logs) { + resetProgBar(); + maybeInitProgbar(); + } + + /** {@inheritDoc} */ + @Override + public void onEpochBegin(int epoch, Map logs) { + resetProgBar(); + maybeInitProgbar(); + + if (verbose != ProgressBar.VerboseMode.SILENT && epochs > 1) { + Logger.getLogger(ProgbarLogger.class.getName()) + .log(Level.INFO, String.format("Epoch %d/%d", (epoch + 1), epochs)); + } + } + + @Override + public void onTrainBatchEnd(int batch, Map logs) { + batchUpdateProgbar(batch, logs); + } + + @Override + public void onTestBatchEnd(int batch, Map logs) { + if (!calledInFit) { + batchUpdateProgbar(batch, logs); + } + } + + @Override + public void onPredictBatchEnd(int batch, Map logs) { + // Don't pass prediction results. + super.onPredictBatchEnd(batch, null); + } + + /** {@inheritDoc} */ + @Override + public void onEpochEnd(int epoch, Map logs) { + finalizeProgbar(logs, getTrainCounter); + } + + /** {@inheritDoc} */ + @Override + public void onTestEnd(Map logs) { + if (!calledInFit) { + finalizeProgbar(logs, getTestCounter); + } + } + + /** {@inheritDoc} */ + @Override + public void onPredictEnd(Map logs) { + finalizeProgbar(logs, getPredictCounter); + } + + /** + * Updates the {@link ProgressBar} + * + * @param batch the batch number + * @param logs loss or metric results + */ + private void batchUpdateProgbar(int batch, Map logs) { + Map llogs = logs == null ? Collections.emptyMap() : logs; + maybeInitProgbar(); + + if (unit == ProgressBar.CountMode.STEPS) { + seen = batch + 1; + } else { + // make shallow copy + llogs = new HashMap<>(llogs); + Number batchSize = llogs.getOrDefault("size", 0); + Number numSteps = llogs.getOrDefault("num_steps", 1); + llogs.remove("batch"); + int addSeen = numSteps.intValue() * batchSize.intValue(); + seen += addSeen; + } + + if (verbose != ProgressBar.VerboseMode.SILENT) { + progbar.update(seen, llogs, false); + } + } + + /** + * Finalizes the Progess Bar + * + * @param logs results to apply + * @param getCounter gets the counter from the model + */ + private void finalizeProgbar(Map logs, IntSupplier getCounter) { + if (progbar != null) { + Integer counter = null; + if (target == null) { + if (getCounter != null) { + int counterValue = getCounter.getAsInt(); + if (unit == ProgressBar.CountMode.SAMPLES) { + Number size = logs.getOrDefault("", 1); + counterValue *= size.intValue(); + } + counter = counterValue; + } + target = counter == null ? seen : counter; + progbar.setTarget(target); + } + progbar.update(seen, logs, true); + } + } + + private void resetProgBar() { + seen = 0; + progbar = null; + } + + private void maybeInitProgbar() { + if (statefulMetrics == null) { + /* TODO - Model + if(model != null) { + statefulMetrics = new ArrayList<>(); + model.getMetrics().forEach(m -> statefulMetrics.add(m.getName)); + }else { + */ + statefulMetrics = new HashSet<>(); + // TODO - Model } + } + if (progbar == null) { + if (writer == null) { + progbar = new ProgressBar(target, verbose, unit, statefulMetrics); + } else { + progbar = + new ProgressBar( + writer, + target, + ProgressBar.DEFAULT_WIDTH, + verbose, + ProgressBar.DEFAULT_INTERVAL, + unit, + statefulMetrics); + } + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/UpdateFreq.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/UpdateFreq.java new file mode 100644 index 00000000000..61713630e9f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/UpdateFreq.java @@ -0,0 +1,23 @@ +/* Copyright 202 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +/** Enum that defines the update frequency */ +public enum UpdateFreq { + /** Update on every epoch */ + EPOCH, + /** Update on every batch */ + BATCH +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java new file mode 100644 index 00000000000..705603b05f1 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java @@ -0,0 +1,16 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormat.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormat.java new file mode 100644 index 00000000000..e2340c4ed73 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormat.java @@ -0,0 +1,92 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks.util; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * These are utilities for generating a file path from an original file path containing named + * formatting options. + */ +public class PathPlaceholderStringFormat { + private static final Pattern PYTHON_MATCH = Pattern.compile("\\{(\\w+):([\\w.]+)}"); + + /** + * Converts a filepath containing named formatting options, which will be filled with the value of + * epoch and keys in logs (passed in onEpochEnd). + * + *

For example: + * + *

if filepath is * weights.{epoch:02d}-{val_loss:.2f}.hdf5, then the + * model checkpoints will be saved with the epoch number and the validation loss in the filename + * (e.g. "weights.561-0.71.hdf5"). + * + * @param filename the filename containing the formatting options + * @param epoch the epoch + * @param logs the logs map that contain the values + * @return the converted file path name + */ + public static String convertFilePath(String filename, int epoch, Map logs) { + List vars = new ArrayList<>(); + String format = getFilePath(filename, vars); + List values = new ArrayList<>(); + vars.forEach( + key -> { + if (key.equals("epoch")) values.add(epoch); + else if (logs.containsKey(key)) values.add(logs.get(key).doubleValue()); + else values.add(0.0); + }); + return String.format(format, values.toArray()); + } + + /** + * Creates a {@link String#format} string for formatting the filepath for including the log values + * identified by the original filepath placeholder names + * + * @param filename the filename with the placeholders. + * @param vars the list is populated with the log names from the placeholder names found in the + * original file path string that will be included in resulting name + * @return the String format for formatting the values identified from the placeholder names. + */ + private static String getFilePath(String filename, List vars) { + Matcher m = PYTHON_MATCH.matcher(filename); + StringBuilder sb = new StringBuilder(); + int beginIndex = 0; + Map indexMap = new HashMap<>(); + int lastIndex = 1; + while (m.find()) { + int start = m.start(0); + int end = m.end(0); + String variable = m.group(1); + vars.add(variable); + String format = m.group(2); + Integer index = indexMap.get(variable); + if (index == null) { + indexMap.put(variable, lastIndex); + index = lastIndex++; + } + sb.append(filename, beginIndex, start); + sb.append('%').append(index).append('$').append(format); + beginIndex = end; + } + sb.append(filename.substring(beginIndex)); + return sb.toString(); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java new file mode 100644 index 00000000000..cf645157c13 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java @@ -0,0 +1,489 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks.util; + +import java.io.Console; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Displays a progress bar. + * + *

Output may be sent to {@link Console} if one is present, otherwise it is sent to {@link + * System#out}. Whether a Java virtual machine has a console is dependent upon the underlying + * platform. + */ +public class ProgressBar { + public static final int DEFAULT_WIDTH = 30; + public static final VerboseMode DEFAULT_VERBOSE = VerboseMode.VERBOSE; + + // NOTE: Console may be null + public static final double DEFAULT_INTERVAL = 50; // msecs + public static final CountMode DEFAULT_COUNT_MODE = CountMode.STEPS; + private static final double MICRO_SECONDS = 1E-3; + private static final long MILLI_SECOND = 1L; + private static final long SECOND = 1000L; + private static final long MINUTE = 60L * SECOND; + private static final long HOUR = 60L * MINUTE; + private final int width; + private final VerboseMode verbose; + private final long interval; + private final Set statefulMetrics; + private final CountMode unit; + private final Map> values = new HashMap<>(); + private final List valuesOrdered = new ArrayList<>(); + private final long start = System.currentTimeMillis(); + // will be null if java system console is not present + private final Console console; + // defaults to stdout if console is null; + private final PrintWriter writer; + private final boolean dynamicDisplay; + private Integer target; + private int totalWidth; + private int seenSoFar; + private long lastUpdate; + private Long timeAfterFirstStep; + /** Create a ProgressBar */ + public ProgressBar() { + this(null, DEFAULT_WIDTH, DEFAULT_VERBOSE, DEFAULT_INTERVAL, DEFAULT_COUNT_MODE, null); + } + /** + * Create a ProgressBar + * + * @param target Total number of steps expected, null if unknown. + */ + public ProgressBar(Integer target) { + this(target, DEFAULT_WIDTH, DEFAULT_VERBOSE, DEFAULT_INTERVAL, DEFAULT_COUNT_MODE, null); + } + + /** + * Create a ProgressBar + * + * @param target Total number of steps expected, null if unknown. + * @param verbose Verbosity mode + * @param unit Display name for step counts, "step" or "sample". + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgressBar( + Integer target, VerboseMode verbose, CountMode unit, Set statefulMetrics) { + this(target, DEFAULT_WIDTH, verbose, DEFAULT_INTERVAL, unit, statefulMetrics); + } + + /** + * Create a ProgressBar + * + * @param target Total number of steps expected, null if unknown. + * @param width Progress bar width on screen. + * @param verbose Verbosity mode, false is silent + * @param interval Minimum visual progress update interval (in milliseconds). + * @param unit Display name for step counts, "step" or "sample". + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgressBar( + Integer target, + int width, + VerboseMode verbose, + double interval, + CountMode unit, + Set statefulMetrics) { + this.target = target; + this.width = width; + this.verbose = verbose; + this.interval = (long) interval; + this.unit = unit; + this.statefulMetrics = statefulMetrics == null ? Collections.emptySet() : statefulMetrics; + switch (verbose) { + case VERBOSE: + console = System.console(); + writer = console != null ? null : new PrintWriter(System.out); + break; + case SEMI_VERBOSE: + writer = new PrintWriter(System.out); + console = null; + break; + default: + writer = null; + console = null; + break; + } + dynamicDisplay = console != null; + } + + /** + * Create a ProgressBar + * + * @param writer the writer to use rather than the defaults of {@link Console} or {@link + * System#out} + * @param target Total number of steps expected, null if unknown. + * @param width Progress bar width on screen. + * @param verbose Verbosity mode, false is silent + * @param interval Minimum visual progress update interval (in milliseconds). + * @param unit Display name for step counts, "step" or "sample". + * @param statefulMetrics names of metrics that should not be averaged over an epoch. Metrics in + * this list will be logged as-is. All others will be averaged over time (e.g. loss, etc). If + * not provided, defaults to the Model's metrics. + */ + public ProgressBar( + PrintWriter writer, + Integer target, + int width, + VerboseMode verbose, + double interval, + CountMode unit, + Set statefulMetrics) { + this.target = target; + this.width = width; + this.verbose = verbose; + this.interval = (long) interval; + this.unit = unit; + this.statefulMetrics = statefulMetrics == null ? Collections.emptySet() : statefulMetrics; + switch (verbose) { + case VERBOSE: + case SEMI_VERBOSE: + console = null; + this.writer = writer; + break; + default: + this.writer = null; + console = null; + break; + } + dynamicDisplay = false; + } + + /** + * Updates the progress bar. + * + * @param seen Index of current step. + * @param logValues map of log values to apply + */ + public void update(int seen, Map logValues) { + update(seen, logValues, null); + } + + /** + * Updates the progress bar. + * + * @param current Index of current step, if null the current step is calculated. + * @param logValues Map of log values + * @param finalize Whether this is the last update for the progress bar. If null, defaults to + * {@code current >= self.target}. + */ + public void update(Integer current, Map logValues, Boolean finalize) { + boolean shouldFinalize; + + int iCurrent = current == null ? 0 : current; + + if (finalize == null) { + if (target == null) { + shouldFinalize = false; + } else { + shouldFinalize = iCurrent >= target; + } + } else { + shouldFinalize = finalize; + } + + Map lValues = logValues == null ? Collections.emptyMap() : logValues; + lValues.forEach( + (key, value) -> { + if (!this.valuesOrdered.contains(key)) { + this.valuesOrdered.add(key); + } + if (!this.statefulMetrics.contains(key)) { + // In the case that progress bar doesn't have a target value in the first + // epoch, both onTrainBatchEnd and onEpochEnd will be called, which will + // cause 'current' and 'seenSoFar' to have the same value. Force + // the minimal value to 1 here, otherwise stateful_metric will be 0s. + int valueBase = Math.max(iCurrent - seenSoFar, 1); + // stores the pair, the value and its base, (10, 10) == 1, 10,100 = .1 + if (!values.containsKey(key)) { + values.put(key, Arrays.asList(value.doubleValue() * valueBase, valueBase)); + } else { + List currentValues = values.get(key); + double v1 = currentValues.get(0).doubleValue() + value.doubleValue() * valueBase; + int b1 = currentValues.get(1).intValue() + valueBase; + values.put(key, Arrays.asList(v1, b1)); + } + } else { + values.put(key, Arrays.asList(value, 1)); + } + }); + this.seenSoFar = iCurrent; + long now = System.currentTimeMillis(); + // convert time to seconds + double timeSinceStart = (double) (now - this.start) / SECOND; + StringBuilder info = new StringBuilder(String.format(" - %.0fs", timeSinceStart)); + + if (this.verbose == VerboseMode.VERBOSE) { + if (now - lastUpdate < this.interval && !shouldFinalize) { + return; + } + + int prevTotalWidth = this.totalWidth; + if (dynamicDisplay) { + // backspace to beginning of line + this.console.printf("%s\r", repeat("\b", prevTotalWidth)); + } else { + writer.print('\n'); + } + StringBuilder bar = new StringBuilder(); + if (target != null) { + int numDigits = target > 0 ? (int) Math.log10(target) + 1 : 1; + String formatStr = String.format("%%%dd/%%d [", numDigits); + bar.append(String.format(formatStr, iCurrent, target)); + double prog = (double) iCurrent / target; + int progWidth = (int) (this.width * prog); + if (progWidth > 0) { + bar.append(repeat("=", progWidth - 1)); + if (iCurrent < target) { + bar.append('>'); + } else { + bar.append('='); + } + bar.append(repeat(".", width - progWidth)); + bar.append(']'); + } + } else { + bar.append(String.format("%7d/Unknown", iCurrent)); + } + this.totalWidth = bar.length(); + print(bar.toString()); + + // NOTE: in millis, Python is in seconds + double timePerUnit = estimateStepDuration(current, now); + if (target == null || shouldFinalize) { + if (timePerUnit >= SECOND || timePerUnit == 0.0) { // seconds + info.append( + String.format(" %.0fs/%s", timePerUnit / SECOND, unit.toString().toLowerCase())); + } else if (timePerUnit >= 1) { // milliseconds + info.append(String.format(" %.0fms/%s", timePerUnit, unit.toString().toLowerCase())); + } else { // microseconds + info.append( + String.format(" %.0fus/%s", timePerUnit * SECOND, unit.toString().toLowerCase())); + } + } else { + double eta = timePerUnit * (target - iCurrent); + String etaFormat; + if (eta > HOUR) { // greater than an hour + + etaFormat = + String.format( + "%d:%02d:%02d", // hh:mm:ss + (int) (eta / HOUR), (int) ((eta % HOUR) / MINUTE), (int) (eta % MINUTE) / SECOND); + } else if (eta > MINUTE) { + etaFormat = + String.format( + "%d:%02d", // mm:ss + (int) (eta / MINUTE), (int) (eta % MINUTE) / SECOND); + } else { + etaFormat = String.format("%ds", (int) (eta / SECOND)); // seconds + } + info.append(" - ETA: ").append(etaFormat); + } + + this.valuesOrdered.forEach( + key -> { + info.append(String.format(" - %s:", key)); + List vals = values.get(key); + double avg = vals.get(0).doubleValue() / Math.max(1.0, vals.get(1).doubleValue()); + if (Math.abs(avg) > 1e-3) { // Normal number + info.append(String.format(" %.4f", avg)); + } else { // Floating point notation. + info.append(String.format(" %.4e", avg)); + } + }); + totalWidth += info.length(); + if (prevTotalWidth > totalWidth) { + info.append(repeat(" ", prevTotalWidth - totalWidth)); + } + if (shouldFinalize) { + info.append('\n'); + } + print(info.toString(), true); + + } else if (verbose == VerboseMode.SEMI_VERBOSE) { + if (shouldFinalize) { + int numDigits = target > 0 ? (int) Math.log10(target) + 1 : 1; + String formatStr = String.format("%%%dd/%%d [", numDigits); + final StringBuilder tmpInfo = + new StringBuilder(String.format(formatStr, iCurrent, target)).append(info); + valuesOrdered.forEach( + k -> { + tmpInfo.append(String.format(" - %s:", k)); + List valEntry = values.get(k); + // TODO average + double avg = + valEntry.get(0).doubleValue() / Math.max(1.0, valEntry.get(1).doubleValue()); + if (avg > 1e-3) { + tmpInfo.append(String.format(" %.4f", avg)); + } else { + tmpInfo.append(String.format(" %.4ef", avg)); + } + }); + print(tmpInfo.toString(), true); + } + } + this.lastUpdate = now; + } + + /** + * Print the string to the output stream without flushing + * + * @param s the string + */ + private void print(String s) { + print(s, false); + } + + /** + * Print the string to the output stream + * + * @param s the string + * @param doFlush whether to flush the output after printing. + */ + private void print(String s, boolean doFlush) { + if (this.console != null) { + this.console.printf(s); + if (doFlush) { + this.console.flush(); + } + } else { + writer.print(s); + if (doFlush) { + writer.flush(); + } + } + } + + /** + * Estimate the duration of a single step, in millis + * + *

Given the step number, current, and the corresponding time, now, this function returns an + * estimate for how long a single step takes. If this is called before one step has been completed + * (i.e.{@code current == 0}) then zero is given as an estimate. The duration estimate ignores the + * duration of the (assumed to be non-representative) first step for estimates when more steps are + * available (i.e. {@code current 1}). + * + * @param current Index of current step + * @param now the current time + * @return the estimate of the duration of a single step. + */ + private double estimateStepDuration(Integer current, long now) { + if (current != null) { + // there are a few special scenarios here: + // 1) somebody is calling the progress bar without ever supplying step 1 + // 2) somebody is calling the progress bar and supplies step one mulitple + // times, e.g. as part of a finalizing call + // in these cases, we just fall back to the simple calculation + double timePerUnit; + if (current == 0) { + timePerUnit = (now - start); + } else if (timeAfterFirstStep != null && current > 2) { + timePerUnit = (double) (now - timeAfterFirstStep) / (double) (current - 1); + } else { + timePerUnit = (double) (now - start) / (double) current; + } + if (current == 1) { + timeAfterFirstStep = now; + } + return timePerUnit; + + } else { + return 0; + } + } + + /** + * Repeats the string s, count times. + * + * @param s the string to repeat + * @param count the number of times to repeat + * @return the repeated string + */ + private String repeat(String s, int count) { + return new String(new char[count]).replace("\0", s); + } + + /** updates the progress bar by one unit */ + public void increment() { + add(1); + } + + /** + * updates the progress bar by one unit + * + * @param logValues map of log values to apply + */ + public void increment(Map logValues) { + add(1, logValues); + } + + /** + * update the progress bar + * + * @param n the number of units to add to the current number + */ + public void add(int n) { + add(n, null); + } + + /** + * update the progress bar + * + * @param n the number of units to add to the current number + * @param logValues map of log values to apply + */ + public void add(int n, Map logValues) { + this.update(this.seenSoFar + n, logValues); + } + + /** @return the target */ + public Integer getTarget() { + return target; + } + + /** @param target the target to set */ + public void setTarget(Integer target) { + this.target = target; + } + + public enum CountMode { + /** the progress bar should count steps () */ + STEPS, + /** the progress bar should count samples */ + SAMPLES + } + + /** Verbosity mode */ + public enum VerboseMode { + /** Do not log output */ + SILENT, + /** verbose, try to use {@link Console}, if available */ + VERBOSE, + /** Semi verbose, Use {@link System#out} */ + SEMI_VERBOSE + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java new file mode 100644 index 00000000000..63988fe4166 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java @@ -0,0 +1,113 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVRecord; +import org.junit.jupiter.api.Test; +import org.tensorflow.types.TFloat64; + +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.io.Reader; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.fail; + +class CSVLoggerTest { + + @Test + public void testStandAlone() { + try { + int epoch = 0; + double[] values = {0.95, 0.90, 0.85, 0.90, 0.99, Double.NaN}; + File tmpFile = File.createTempFile("tf-test", ".csv"); + Map logs = new HashMap<>(); + try (CSVLogger csvLogger = new CSVLogger<>(tmpFile)) { + csvLogger.onTrainBegin(null); + for (int i = 0; i < values.length; i++) { + logs.put("accuracy", values[epoch]); + csvLogger.onEpochEnd(epoch++, logs); + } + + try (Reader reader = new FileReader(tmpFile)) { + Iterable records = CSVFormat.EXCEL.withFirstRecordAsHeader().parse(reader); + int iv = 0; + for (CSVRecord record : records) { + String epochStr = record.get("epoch"); + assertNotNull(epochStr); + String valueStr = record.get("accuracy"); + assertNotNull(valueStr); + assertEquals(iv, Integer.valueOf(epochStr)); + double v = Double.valueOf(valueStr); + assertEquals(values[iv++], v, 0e-6); + } + } finally { + tmpFile.delete(); + } + } + + } catch (IOException ex) { + fail(ex); + } + } + + @Test + public void testStandAlone2Vals() { + try { + int epoch = 0; + double[] valuesAcc = {0.95, 0.90, 0.85, 0.90, 0.99, Double.NaN}; + double[] valuesErr = {1e-1, 1e-2, 1e-3, 1e-4, 1e-5, Double.NaN}; + File tmpFile = File.createTempFile("tf-test", ".csv"); + Map logs = new HashMap<>(); + try (CSVLogger csvLogger = new CSVLogger<>(tmpFile)) { + csvLogger.onTrainBegin(null); + for (int i = 0; i < valuesAcc.length; i++) { + logs.put("accuracy", valuesAcc[epoch]); + logs.put("error", valuesErr[epoch]); + csvLogger.onEpochEnd(epoch++, logs); + } + + try (Reader reader = new FileReader(tmpFile)) { + Iterable records = CSVFormat.EXCEL.withFirstRecordAsHeader().parse(reader); + int iv = 0; + for (CSVRecord record : records) { + String epochStr = record.get("epoch"); + assertNotNull(epochStr); + String valueStr = record.get("accuracy"); + assertNotNull(valueStr); + String errorStr = record.get("error"); + assertNotNull(errorStr); + assertEquals(iv, Integer.valueOf(epochStr)); + double v = Double.valueOf(valueStr); + assertEquals(valuesAcc[iv], v, 0e-6); + double e = Double.valueOf(errorStr); + assertEquals(valuesErr[iv], e, 0e-8); + iv++; + } + } finally { + tmpFile.delete(); + } + } + + } catch (IOException ex) { + fail(ex); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CallbackListTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CallbackListTest.java new file mode 100644 index 00000000000..0e4ef138a26 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CallbackListTest.java @@ -0,0 +1,52 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +class CallbackListTest { + + @Test + public void testUpdates() { + Map logs = new HashMap<>(); + logs.put("acc", 0.98); + LambdaCallback lambdaCB = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + + lambdaCB.setOnEpochBegin( + (epoch, log) -> { + called.set(true); + }); + + CallbackList instance = new CallbackList(true); + History history = instance.getHistory(); + instance.addCallback(lambdaCB); + + instance.onTrainBegin(null); + instance.onEpochBegin(0, logs); + instance.onEpochEnd(0, logs); + instance.onTrainEnd(null); + + assertTrue(history.getHistory().containsKey("acc")); + assert (history.getHistory().get("acc").size() == 1); + assertTrue(called.get()); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/HistoryTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/HistoryTest.java new file mode 100644 index 00000000000..b11e3887f0c --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/HistoryTest.java @@ -0,0 +1,59 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class HistoryTest { + + @Test + void testOnTrainBegin() { + History instance = new History(); + + instance.onTrainBegin(null); + Map logs = new HashMap<>(); + logs.put("acc", 0.99); + logs.put("err", 0.012345); + int totalEpochs = 100; + for (int epoch = 0; epoch < totalEpochs; epoch++) { + instance.onEpochEnd(epoch, logs); + } + assertEquals(totalEpochs, instance.getEpoch().size()); + + Map> results = instance.getHistory(); + assertEquals(2, results.size()); + assertEquals(results.get("acc").size(), totalEpochs); + assertEquals(results.get("err").size(), totalEpochs); + + instance.onTrainBegin(null); + assertEquals(0, instance.getEpoch().size()); + for (int epoch = 0; epoch < totalEpochs; epoch++) { + instance.onEpochEnd(epoch, logs); + } + + assertEquals(totalEpochs, instance.getEpoch().size()); + + results = instance.getHistory(); + assertEquals(2, results.size()); + assertEquals(results.get("acc").size(), totalEpochs * 2); + assertEquals(results.get("err").size(), totalEpochs * 2); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java new file mode 100644 index 00000000000..36e6397d62c --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java @@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class LambdaCallbackTest { + + @Test + void onEpochBegin() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + int expectedEpoch = 101; + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnEpochBegin( + (epoch, log) -> { + assertEquals(expectedEpoch, epoch); + assertEquals(exceptedLog, log); + called.set(true); + }); + + Map epochLog = new HashMap<>(); + epochLog.put("acc", 0.98); + instance.onEpochBegin(101, epochLog); + + assertTrue(called.get()); + } + + @Test + void onEpochEnd() {} + + @Test + void onTrainBatchBegin() {} + + @Test + void onTrainBatchEnd() {} + + @Test + void onTrainBegin() {} + + @Test + void onTrainEnd() {} +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java new file mode 100644 index 00000000000..408aab0fa81 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java @@ -0,0 +1,176 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.callbacks; + +import org.junit.jupiter.api.Test; +import org.tensorflow.framework.callbacks.util.ProgressBar; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.fail; + +class ProgbarLoggerTest { + + @Test + void testNoTarget() { + File tmpFile = null; + try { + tmpFile = File.createTempFile("tf-test-progbar", ".txt"); + System.out.println(tmpFile); + try (PrintWriter writer = new PrintWriter(new FileWriter(tmpFile))) { + int numEpochs = 1; + int numSteps = 10; + Map params = new HashMap<>(); + + params.put("verbose", ProgressBar.VerboseMode.VERBOSE); + params.put("size", numSteps); + params.put("num_steps", numSteps); + params.put("writer", writer); + ProgbarLogger instance = + new ProgbarLogger(params, null, ProgressBar.CountMode.STEPS, Arrays.asList("acc")); + + Map logs = new HashMap<>(); + logs.put("acc", 0.95); + instance.onTrainBegin(null); + for (int epoch = 0; epoch < numEpochs; epoch++) { + instance.onEpochBegin(epoch, null); + for (int step = 0; step < numSteps; step++) { + instance.onTrainBatchBegin(step, logs); + try { + Thread.sleep(100); + } catch (InterruptedException ignore) { + } + instance.onTrainBatchEnd(step, logs); + } + // instance.onEpochEnd(epoch, logs); + } + instance.onTrainEnd(null); + } catch (IOException ex) { + fail(ex); + } + List results = readResults(tmpFile); + // 1/Unknown - 0s 105ms/steps - acc: 0.9500 + // 10/Unknown - 1s 104ms/steps - acc: 0.9500 + Pattern p1 = + Pattern.compile(" [1 ][0-9]/Unknown - [0-9]s [1-9][0-9][0-9]ms/steps - acc: 0.9500"); + + results.forEach( + line -> { + if (!line.trim().isEmpty()) { + Matcher m = p1.matcher(line); + if (!m.matches()) { + fail("unexpected output \"" + line + "\""); + } + } + }); + } catch (IOException ex) { + fail(ex); + } finally { + if (tmpFile != null) { + // tmpFile.delete(); + } + } + } + + @Test + void testTarget() { + + File tmpFile = null; + try { + tmpFile = File.createTempFile("tf-test-progbar", ".txt"); + try (PrintWriter writer = new PrintWriter(new FileWriter(tmpFile))) { + int numEpochs = 10; + int numSteps = 10; + Map params = new HashMap<>(); + + params.put("verbose", ProgressBar.VerboseMode.VERBOSE); + params.put("size", numSteps); + params.put("num_steps", numSteps); + params.put("steps", numSteps); + params.put("writer", writer); + ProgbarLogger instance = + new ProgbarLogger(params, null, ProgressBar.CountMode.STEPS, Arrays.asList("acc")); + + Map logs = new HashMap<>(); + logs.put("acc", 0.88); + + instance.onTrainBegin(null); + for (int epoch = 0; epoch < numEpochs; epoch++) { + instance.onEpochBegin(epoch, null); + for (int step = 0; step < numSteps; step++) { + instance.onTrainBatchBegin(step, logs); + try { + Thread.sleep(10); + } catch (InterruptedException ignore) { + } + instance.onTrainBatchEnd(step, logs); + } + + instance.onEpochEnd(epoch, logs); + } + instance.onTrainEnd(null); + } + + List results = readResults(tmpFile); + // 1/10 [==>...........................] - 0s - ETA: 0s - acc: 0.8800 + Pattern p1 = Pattern.compile(" [1-9]/10 \\[==*>\\.*\\] - 0s - ETA: 0s - acc: 0.8800"); + // 10/10 [==============================] - 0s 12ms/steps - acc: 0.8800 + Pattern p2 = Pattern.compile("10/10 \\[==*\\] - 0s [1-9][0-9]*ms/steps - acc: 0.8800"); + String finalLine = "10/10 [==============================] - 0s - ETA: 0s - acc: 0.8800"; + results.forEach( + line -> { + if (!line.trim().isEmpty()) { + Matcher m = p1.matcher(line); + if (!m.matches()) { + m = p2.matcher(line); + if (!m.matches()) { + if (!line.equals(finalLine)) { + fail("unexpected output \"" + line + "\""); + } + } + } + } + }); + + } catch (IOException ex) { + fail(ex); + } finally { + if (tmpFile != null) { + tmpFile.delete(); + } + } + } + + private List readResults(File file) { + try (BufferedReader reader = new BufferedReader(new FileReader(file))) { + return reader.lines().collect(Collectors.toList()); + } catch (IOException ex) { + fail("cannot read tmp file", ex); + } + return null; // should not happen + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormatTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormatTest.java new file mode 100644 index 00000000000..decc93e50f5 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/util/PathPlaceholderStringFormatTest.java @@ -0,0 +1,48 @@ +package org.tensorflow.framework.callbacks.util; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class PathPlaceholderStringFormatTest { + + @Test + public void testPlaceholder() { + String filePath = "weights.{epoch:02d}-{val_loss:.2f}.hdf5"; + + Map logs = new HashMap<>(); + logs.put("val_loss", 0.71); + + // test with val_loss and 1 digit epoch + String result = PathPlaceholderStringFormat.convertFilePath(filePath, 1, logs); + String expect = "weights.01-0.71.hdf5"; + assertEquals(expect, result); + + // test with val_loss and 2 digit epoch + result = PathPlaceholderStringFormat.convertFilePath(filePath, 12, logs); + expect = "weights.12-0.71.hdf5"; + assertEquals(expect, result); + + // test with val_loss and 2 digit epoch and an added log variable + logs.put("acc", 0.21); + logs.put("val_loss", 0.99); + result = PathPlaceholderStringFormat.convertFilePath(filePath, 12, logs); + expect = "weights.12-0.99.hdf5"; + assertEquals(expect, result); + + // test with empty logs variable + logs.clear(); + result = PathPlaceholderStringFormat.convertFilePath(filePath, 123, logs); + expect = "weights.123-0.00.hdf5"; + assertEquals(expect, result); + + // test with no formatting + filePath = "weights.hdf5"; + result = PathPlaceholderStringFormat.convertFilePath(filePath, 0, logs); + expect = "weights.hdf5"; + assertEquals(expect, result); + } +} From 9efff83a02343e813d87a3b2bfa122b8d56d3a5c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 6 May 2021 17:26:41 -0400 Subject: [PATCH 3/6] Added missing methods in Lambda Callback --- .../framework/callbacks/LambdaCallback.java | 237 ++++++++++++++++ .../callbacks/LambdaCallbackTest.java | 253 +++++++++++++++++- 2 files changed, 485 insertions(+), 5 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java index af976eee2ff..6facdf73e4e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java @@ -61,6 +61,30 @@ public class LambdaCallback extends Callback { /** called at the end of model training. expect one positional argument: `logs` */ private Consumer> onTrainEnd; + /** Called at the beginning of every batch. expect two positional arguments: `batch`, `logs` */ + private BiConsumer> onTestBatchBegin; + + /** called at the end of every batch. expect two positional arguments: `batch`, `logs` */ + private BiConsumer> onTestBatchEnd; + + /** called at the beginning of model training. expect one positional argument: `logs` */ + private Consumer> onTestBegin; + + /** called at the end of model training. expect one positional argument: `logs` */ + private Consumer> onTestEnd; + + /** Called at the beginning of every batch. expect two positional arguments: `batch`, `logs` */ + private BiConsumer> onPredictBatchBegin; + + /** called at the end of every batch. expect two positional arguments: `batch`, `logs` */ + private BiConsumer> onPredictBatchEnd; + + /** called at the beginning of model training. expect one positional argument: `logs` */ + private Consumer> onPredictBegin; + + /** called at the end of model training. expect one positional argument: `logs` */ + private Consumer> onPredictEnd; + /** Creates a LambdaCallbacks callback */ public LambdaCallback() { super(); @@ -123,6 +147,72 @@ public void onTrainEnd(Map logs) { } } + /** {@inheritDoc} */ + @Override + public void onTestBatchBegin(int batch, Map logs) { + if (this.onTestBatchBegin != null) { + this.onTestBatchBegin.accept(batch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTestBatchEnd(int batch, Map logs) { + if (this.onTestBatchEnd != null) { + this.onTestBatchEnd.accept(batch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTestBegin(Map logs) { + if (this.onTestBegin != null) { + this.onTestBegin.accept(logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onTestEnd(Map logs) { + if (this.onTestEnd != null) { + this.onTestEnd.accept(logs); + } + } + + + /** {@inheritDoc} */ + @Override + public void onPredictBatchBegin(int batch, Map logs) { + if (this.onPredictBatchBegin != null) { + this.onPredictBatchBegin.accept(batch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onPredictBatchEnd(int batch, Map logs) { + if (this.onPredictBatchEnd != null) { + this.onPredictBatchEnd.accept(batch, logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onPredictBegin(Map logs) { + if (this.onPredictBegin != null) { + this.onPredictBegin.accept(logs); + } + } + + /** {@inheritDoc} */ + @Override + public void onPredictEnd(Map logs) { + if (this.onPredictEnd != null) { + this.onPredictEnd.accept(logs); + } + } + + /** * Gets the onEpochBegin lambda function * @@ -230,4 +320,151 @@ public Consumer> getOnTrainEnd() { public void setOnTrainEnd(Consumer> onTrainEnd) { this.onTrainEnd = onTrainEnd; } + + /** + * Gets the onTestBatchBegin lambda function + * + * @return the onTestBatchBegin lambda function + */ + public BiConsumer> getOnTestBatchBegin() { + return onTestBatchBegin; + } + + /** + * Sets the onTestBatchBegin lambda function + * + * @param onTestBatchBegin the lambda function + */ + public void setOnTestBatchBegin(BiConsumer> onTestBatchBegin) { + this.onTestBatchBegin = onTestBatchBegin; + } + + /** + * Gets the onTestBatchEnd lambda function + * + * @return the onTestBatchEnd lambda function + */ + public BiConsumer> getOnTestBatchEnd() { + return onTestBatchEnd; + } + + /** + * Sets the onTestBatchEnd lambda function + * + * @param onTestBatchEnd the onTestBatchEnd lambda function + */ + public void setOnTestBatchEnd(BiConsumer> onTestBatchEnd) { + this.onTestBatchEnd = onTestBatchEnd; + } + + /** + * Gets the onTestBegin lambda function + * + * @return the onTestBegin lambda function + */ + public Consumer> getOnTestBegin() { + return onTestBegin; + } + + /** + * Sets the onTestBegin lambda function + * + * @param onTestBegin the onTestBegin lambda function + */ + public void setOnTestBegin(Consumer> onTestBegin) { + this.onTestBegin = onTestBegin; + } + + /** + * Gets the onTestBegin lambda function + * + * @return the onTestEnd lambda function + */ + public Consumer> onTestEnd() { + return onTestEnd; + } + + /** + * Sets the onTestEnd lambda function + * + * @param onTestEnd the onTestBegin lambda function + */ + public void setOnTestEnd(Consumer> onTestEnd) { + this.onTestEnd = onTestEnd; + } + + + /** + * Gets the onPredictBatchBegin lambda function + * + * @return the onPredictBatchBegin lambda function + */ + public BiConsumer> getOnPredictBatchBegin() { + return onPredictBatchBegin; + } + + /** + * Sets the onPredictBatchBegin lambda function + * + * @param onPredictBatchBegin the lambda function + */ + public void setOnPredictBatchBegin(BiConsumer> onPredictBatchBegin) { + this.onPredictBatchBegin = onPredictBatchBegin; + } + + /** + * Gets the onPredictBatchEnd lambda function + * + * @return the onPredictBatchEnd lambda function + */ + public BiConsumer> getOnPredictBatchEnd() { + return onPredictBatchEnd; + } + + /** + * Sets the onPredictBatchEnd lambda function + * + * @param onPredictBatchEnd the onPredictBatchEnd lambda function + */ + public void setOnPredictBatchEnd(BiConsumer> onPredictBatchEnd) { + this.onPredictBatchEnd = onPredictBatchEnd; + } + + /** + * Gets the onPredictBegin lambda function + * + * @return the onPredictBegin lambda function + */ + public Consumer> getOnPredictBegin() { + return onPredictBegin; + } + + /** + * Sets the onPredictBegin lambda function + * + * @param onPredictBegin the onPredictBegin lambda function + */ + public void setOnPredictBegin(Consumer> onPredictBegin) { + this.onPredictBegin = onPredictBegin; + } + + /** + * Gets the onPredictEnd lambda function + * + * @return the onPredictEnd lambda function + */ + public Consumer> getOnPredictEnd() { + return onPredictEnd; + } + + /** + * Sets the onPredictEnd lambda function + * + * @param onPredictEnd the onPredictEnd lambda function + */ + public void setOnPredictEnd(Consumer> onPredictEnd) { + this.onPredictEnd = onPredictEnd; + } + + } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java index 36e6397d62c..1520b5f376b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/LambdaCallbackTest.java @@ -47,17 +47,260 @@ void onEpochBegin() { } @Test - void onEpochEnd() {} + void onEpochEnd() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + int expectedEpoch = 101; + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnEpochEnd( + (epoch, log) -> { + assertEquals(expectedEpoch, epoch); + assertEquals(exceptedLog, log); + called.set(true); + }); + + Map epochLog = new HashMap<>(); + epochLog.put("acc", 0.98); + instance.onEpochEnd(101, epochLog); + + assertTrue(called.get()); + } @Test - void onTrainBatchBegin() {} + void onTrainBatchBegin() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + int expectedBatch = 101; + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnTrainBatchBegin( + (batch, log) -> { + assertEquals(expectedBatch, batch); + assertEquals(exceptedLog, log); + called.set(true); + }); + + Map epochLog = new HashMap<>(); + epochLog.put("acc", 0.98); + instance.onTrainBatchBegin(101, epochLog); + + assertTrue(called.get()); + } @Test - void onTrainBatchEnd() {} + void onTrainBatchEnd() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + int expectedBatch = 101; + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnTrainBatchEnd( + (batch, log) -> { + assertEquals(expectedBatch, batch); + assertEquals(exceptedLog, log); + called.set(true); + }); + + Map epochLog = new HashMap<>(); + epochLog.put("acc", 0.98); + instance.onTrainBatchEnd(101, epochLog); + + assertTrue(called.get()); + } + + @Test + void onTrainBegin() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnTrainBegin( + logs -> { + assertEquals(exceptedLog, logs); + called.set(true); + }); + + Map log = new HashMap<>(); + log.put("acc", 0.98); + instance.onTrainBegin(log); + + assertTrue(called.get()); + } + + @Test + void onTrainEnd() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + Map expectedLog = new HashMap<>(); + expectedLog.put("acc", 0.98); + instance.setOnTrainEnd( + logs -> { + assertEquals(expectedLog, logs); + called.set(true); + }); + Map log = new HashMap<>(); + log.put("acc", 0.98); + instance.onTrainEnd(log); + + assertTrue(called.get()); + } @Test - void onTrainBegin() {} + void onTestBatchBegin() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + int expectedBatch = 101; + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnTestBatchBegin( + (batch, log) -> { + assertEquals(expectedBatch, batch); + assertEquals(exceptedLog, log); + called.set(true); + }); + + Map epochLog = new HashMap<>(); + epochLog.put("acc", 0.98); + instance.onTestBatchBegin(101, epochLog); + + assertTrue(called.get()); + } @Test - void onTrainEnd() {} + void onTestBatchEnd() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + int expectedBatch = 101; + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnTestBatchEnd( + (batch, log) -> { + assertEquals(expectedBatch, batch); + assertEquals(exceptedLog, log); + called.set(true); + }); + + Map epochLog = new HashMap<>(); + epochLog.put("acc", 0.98); + instance.onTestBatchEnd(101, epochLog); + + assertTrue(called.get()); + } + + @Test + void onTestBegin() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnTestBegin( + logs -> { + assertEquals(exceptedLog, logs); + called.set(true); + }); + + Map log = new HashMap<>(); + log.put("acc", 0.98); + instance.onTestBegin(log); + + assertTrue(called.get()); + } + + @Test + void onTestEnd() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + Map expectedLog = new HashMap<>(); + expectedLog.put("acc", 0.98); + instance.setOnTestEnd( + logs -> { + assertEquals(expectedLog, logs); + called.set(true); + }); + Map log = new HashMap<>(); + log.put("acc", 0.98); + instance.onTestEnd(log); + + assertTrue(called.get()); + } + + @Test + void onPredictBatchBegin() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + int expectedBatch = 101; + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnPredictBatchBegin( + (batch, log) -> { + assertEquals(expectedBatch, batch); + assertEquals(exceptedLog, log); + called.set(true); + }); + + Map epochLog = new HashMap<>(); + epochLog.put("acc", 0.98); + instance.onPredictBatchBegin(101, epochLog); + + assertTrue(called.get()); + } + + @Test + void onPredictBatchEnd() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + int expectedBatch = 101; + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnPredictBatchEnd( + (batch, log) -> { + assertEquals(expectedBatch, batch); + assertEquals(exceptedLog, log); + called.set(true); + }); + + Map epochLog = new HashMap<>(); + epochLog.put("acc", 0.98); + instance.onPredictBatchEnd(101, epochLog); + + assertTrue(called.get()); + } + + @Test + void onPredictBegin() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + Map exceptedLog = new HashMap<>(); + exceptedLog.put("acc", 0.98); + instance.setOnPredictBegin( + logs -> { + assertEquals(exceptedLog, logs); + called.set(true); + }); + + Map log = new HashMap<>(); + log.put("acc", 0.98); + instance.onPredictBegin(log); + + assertTrue(called.get()); + } + + @Test + void onPredictEnd() { + LambdaCallback instance = new LambdaCallback(); + final AtomicBoolean called = new AtomicBoolean(false); + Map expectedLog = new HashMap<>(); + expectedLog.put("acc", 0.98); + instance.setOnPredictEnd( + logs -> { + assertEquals(expectedLog, logs); + called.set(true); + }); + Map log = new HashMap<>(); + log.put("acc", 0.98); + instance.onPredictEnd(log); + + assertTrue(called.get()); + } } From 6aa84eb7ceca5fae980127dca4335c0f3cde08a5 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 6 May 2021 17:27:11 -0400 Subject: [PATCH 4/6] Remove unused class, this is part of ProgressBar --- .../framework/callbacks/VerboseMode.java | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java deleted file mode 100644 index 705603b05f1..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/VerboseMode.java +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.callbacks; - From 4b3bb7cadb5425f10c3e0c0e6a4fa32b18ddbe68 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 6 May 2021 17:27:43 -0400 Subject: [PATCH 5/6] Fixes from PR Comments --- .../framework/callbacks/CSVLogger.java | 24 +++-- .../framework/callbacks/Callback.java | 89 ++++++++----------- .../framework/callbacks/ProgbarLogger.java | 16 ++-- .../framework/callbacks/util/ProgressBar.java | 3 +- .../framework/callbacks/CSVLoggerTest.java | 6 +- 5 files changed, 65 insertions(+), 73 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java index f009dfadfc0..92379fb6f67 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java @@ -35,13 +35,13 @@ import java.util.stream.Collectors; /** - * Callback that streams epoch results to a CSV file. public + * Callback that streams epoch results to a CSV file. * *

Supports all values that can be represented as a string * * @param the data type for the weights in the model */ -class CSVLogger extends Callback implements AutoCloseable { +public class CSVLogger extends Callback implements AutoCloseable { public static final char DEFAULT_SEPARATOR = ','; public static final boolean DEFAULT_APPEND = false; @@ -138,7 +138,7 @@ private String handleValue(Object val) { return ((NdArray) val).getObject().toString(); } else { NdArray array = (NdArray) val; - return toString(array); + return ndArrayToString(array); } } else if (val instanceof Collection) { return "[" @@ -155,7 +155,7 @@ private String handleValue(Object val) { * @param ndArray the NdArray * @return the printable string */ - private String toString(NdArray ndArray) { + private String ndArrayToString(NdArray ndArray) { Iterator> iterator = ndArray.scalars().iterator(); Shape shape = ndArray.shape(); if (shape.numDimensions() == 0) { @@ -164,14 +164,22 @@ private String toString(NdArray ndArray) { } return valToString(iterator.next().getObject()); } - return toString(iterator, shape, 0); + return ndArrayToString(iterator, shape, 0); } - private String toString(Iterator> iterator, Shape shape, int dimension) { + /** + * coverts an NdArray iterator to a printable string + * + * @param iterator the NdArray iterator + * @param shape the shape of the NdArray item + * @param dimension the dimension within the overall NDArray tree + * @return the printable string + */ + private String ndArrayToString(Iterator> iterator, Shape shape, int dimension) { if (dimension < shape.numDimensions() - 1) { StringJoiner joiner = new StringJoiner("", "[", "]"); for (long i = 0, size = shape.size(dimension); i < size; ++i) { - String element = toString(iterator, shape, dimension + 1); + String element = ndArrayToString(iterator, shape, dimension + 1); joiner.add(element); } return joiner.toString(); @@ -189,7 +197,7 @@ private String toString(Iterator> iterator, Shape shape, in * Converts a value to a printable string * * @param val the value - * @return tje printable string + * @return the printable string */ private String valToString(Object val) { if (val instanceof Number) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java index 0ed72e6592c..bbdaf1b7995 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java @@ -24,14 +24,19 @@ * *

The logs map that callback methods take as argument will contain keys for quantities relevant * to the current batch or epoch (see method-specific docstrings). + * + *

This class has empty implementations for {@code onTrainBatchBegin/End}, {@code + * onTrainBegin/End}, {@code onTestBatchBegin/End}, {@code onTestBegin/End}, {@code + * onPredictBatchBegin/End}, and {@code onPredictBegin/End}. Subclasses should override these + * methods for specific processing. */ public abstract class Callback { - protected Map params; + protected final Map params; // TODO protected Model model; /** Creates a Callback */ protected Callback() { - this(null); + this(Collections.emptyMap()); } /** @@ -43,13 +48,13 @@ protected Callback(Map params) { this.params = params; } - /** + /* TODO with Model * Creates a Callback * * @param params Training parameters * @param model the Model */ - /* TODO + /* TODO with Model protected Callback(Map params, Model model) {= this.params = params; this.model = model; @@ -57,9 +62,8 @@ protected Callback(Map params, Model model) {= */ /** - * Performs custom processing at the the start of an epoch. This method should only be Performs - * custom processing during TRAIN mode. This method is empty. Extend this class to handle this - * event. + * Performs custom processing at the the start of an epoch. This method should only be called + * during TRAIN mode. * * @param epoch index of epoch. * @param logs metric results @@ -68,30 +72,28 @@ protected Callback(Map params, Model model) {= public void onEpochBegin(int epoch, Map logs) {} /** - * Performs custom processing at the end of an epoch.This method should only be Performs custom - * processing during TRAIN mode. This method is empty. Extend this class to handle this event. + * Performs custom processing at the end of an epoch. This method should only be called during + * TRAIN mode. * * @param epoch index of epoch. * @param logs metric results for this training epoch, and for the validation epoch if validation - * is performed. Validation result keys are prefixed with `val_`. + * is performed. Validation result keys are prefixed with {@code val_}. */ @SuppressWarnings("unused") public void onEpochEnd(int epoch, Map logs) {} /** - * Performs custom processing at the beginning of a training batch in `fit` methods. This method - * is empty. Extend this class to handle this event. + * Performs custom processing at the beginning of a training batch in {@code model.fit} methods. * * @param batch the batch index - * @param logs Has keys `batch` and `size` representing the current batch number and the size of - * the batch. + * @param logs Has keys {@code batch} and {@code size} representing the current batch number and + * the size of the batch. */ @SuppressWarnings("unused") public void onTrainBatchBegin(int batch, Map logs) {} /** - * Performs custom processing at the end of a training batch in `fit` methods. This method is - * empty. Extend this class to handle this event. + * Performs custom processing at the end of a training batch in {@code model.fit} methods. * * @param batch index of batch within the current epoch. * @param logs Metric results for this batch. @@ -100,8 +102,7 @@ public void onTrainBatchBegin(int batch, Map logs) {} public void onTrainBatchEnd(int batch, Map logs) {} /** - * Performs custom processing at the beginning of training. This method is empty. Extend this - * class to handle this event. + * Performs custom processing at the beginning of training. * * @param logs metric results */ @@ -109,8 +110,7 @@ public void onTrainBatchEnd(int batch, Map logs) {} public void onTrainBegin(Map logs) {} /** - * Performs custom processing at the end of training. This method is empty. Extend this class to - * handle this event. + * Performs custom processing at the end of training. * * @param logs metric results */ @@ -118,23 +118,21 @@ public void onTrainBegin(Map logs) {} public void onTrainEnd(Map logs) {} /** - * Performs custom processing at the beginning of a batch in `evaluate` methods. Also Performs - * custom processing at the beginning of a validation batch in the `fit` methods, if validation - * data is provided. This method is empty. Extend this class to handle this event. + * Performs custom processing at the beginning of a batch in {@code model.evaluate} methods. Also + * Performs custom processing at the beginning of a validation batch in the {@code fit} methods, + * if validation data is provided. * * @param batch the batch number - * @param logs Has keys `batch` and `size` representing the current batch number and the size of - * the batch. + * @param logs Has keys {@code batch} and {@code size} representing the current batch number and + * the size of the batch. */ @SuppressWarnings("unused") public void onTestBatchBegin(int batch, Map logs) {} /** - * Performs custom processing at the end of a batch in `evaluate` methods. Also Performs custom - * processing at the end of a validation batch in the `fit` methods, if validation data is - * provided. - * - *

This method is empty. Extend this class to handle this event. + * Performs custom processing at the end of a batch in {@code model.evaluate} methods. Also Performs + * custom processing at the end of a validation batch in the {@code fit} methods, if validation + * data is provided. * * @param batch the batch number * @param logs Metric results for this batch. @@ -143,8 +141,7 @@ public void onTestBatchBegin(int batch, Map logs) {} public void onTestBatchEnd(int batch, Map logs) {} /** - * Performs custom processing at the beginning of evaluation or validation. This method is empty. - * Extend this class to handle this event. + * Performs custom processing at the beginning of evaluation or validation. * * @param logs metric results */ @@ -152,8 +149,7 @@ public void onTestBatchEnd(int batch, Map logs) {} public void onTestBegin(Map logs) {} /** - * Performs custom processing at the end of evaluation or validation. This method is empty. Extend - * this class to handle this event. + * Performs custom processing at the end of evaluation or validation. * * @param logs metric results */ @@ -161,19 +157,17 @@ public void onTestBegin(Map logs) {} public void onTestEnd(Map logs) {} /** - * Performs custom processing at the beginning of a batch in `predict` methods. This method is - * empty. Extend this class to handle this event. + * Performs custom processing at the beginning of a batch in {@code model.predict} methods. * * @param batch index of batch within the current epoch. - * @param logs Has keys `batch` and `size` representing the current batch number and the size of - * the batch. + * @param logs Has keys {@code batch} and {@code size} representing the current batch number and + * the size of the batch. */ @SuppressWarnings("unused") public void onPredictBatchBegin(int batch, Map logs) {} /** - * Performs custom processing at the end of a batch in `predict` methods. This method is empty. - * Extend this class to handle this event. + * Performs custom processing at the end of a batch in {@code model.predict} methods. * * @param batch index of batch within the current epoch. * @param logs Metric results for this batch. @@ -182,8 +176,7 @@ public void onPredictBatchBegin(int batch, Map logs) {} public void onPredictBatchEnd(int batch, Map logs) {} /** - * Performs custom processing at the beginning of prediction. This method is empty. Extend this - * class to handle this event. + * Performs custom processing at the beginning of prediction. * * @param logs metric results */ @@ -191,8 +184,7 @@ public void onPredictBatchEnd(int batch, Map logs) {} public void onPredictBegin(Map logs) {} /** - * Performs custom processing at the end of prediction. This method is empty. Extend this class to - * handle this event. + * Performs custom processing at the end of prediction. * * @param logs metric results */ @@ -230,15 +222,6 @@ public Map getParams() { return params; } - /** - * Sets the params - * - * @param params the params to set - */ - public void setParams(Map params) { - this.params = params; - } - /** * Gets the model * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java index 112fa8f3cd1..b96c42c3ae6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/ProgbarLogger.java @@ -160,20 +160,20 @@ public ProgbarLogger( ProgressBar.CountMode unit, List statefulMetrics) { // TODO super(params, model); + super(params); this.unit = unit; this.statefulMetrics = statefulMetrics != null ? new HashSet<>(statefulMetrics) : new HashSet<>(); - setParams(params); + init(); } - /** {@inheritDoc} */ - @Override - public final void setParams(Map params) { - if (params == null) { - return; - } - super.setParams(params); + + /** + * Initializes the ProgbarLogger + */ + private final void init() { + Map params = getParams(); verbose = ((ProgressBar.VerboseMode) params.getOrDefault("verbose", ProgressBar.VerboseMode.VERBOSE)); epochs = (Integer) params.getOrDefault("epochs", 1); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java index cf645157c13..0d5a4a4b678 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/util/ProgressBar.java @@ -424,6 +424,7 @@ private double estimateStepDuration(Integer current, long now) { * @return the repeated string */ private String repeat(String s, int count) { + // TODO JDK 11 update with s.repeat(count) return new String(new char[count]).replace("\0", s); } @@ -471,7 +472,7 @@ public void setTarget(Integer target) { } public enum CountMode { - /** the progress bar should count steps () */ + /** the progress bar should count steps */ STEPS, /** the progress bar should count samples */ SAMPLES diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java index 63988fe4166..5eedb881158 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java @@ -30,7 +30,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.fail; -class CSVLoggerTest { +public class CSVLoggerTest { @Test public void testStandAlone() { @@ -59,7 +59,7 @@ public void testStandAlone() { assertEquals(values[iv++], v, 0e-6); } } finally { - tmpFile.delete(); + tmpFile.deleteOnExit(); } } @@ -102,7 +102,7 @@ public void testStandAlone2Vals() { iv++; } } finally { - tmpFile.delete(); + tmpFile.deleteOnExit(); } } From 66a4bdd65d6292cf7cc363d5128fd42b7242139a Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 7 May 2021 10:33:25 -0400 Subject: [PATCH 6/6] Change tmoFile to deleteOnExit, remove finally block --- .../tensorflow/framework/callbacks/CSVLoggerTest.java | 6 ++---- .../framework/callbacks/ProgbarLoggerTest.java | 11 ++--------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java index 5eedb881158..7467779caf0 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/CSVLoggerTest.java @@ -38,6 +38,7 @@ public void testStandAlone() { int epoch = 0; double[] values = {0.95, 0.90, 0.85, 0.90, 0.99, Double.NaN}; File tmpFile = File.createTempFile("tf-test", ".csv"); + tmpFile.deleteOnExit(); Map logs = new HashMap<>(); try (CSVLogger csvLogger = new CSVLogger<>(tmpFile)) { csvLogger.onTrainBegin(null); @@ -58,8 +59,6 @@ public void testStandAlone() { double v = Double.valueOf(valueStr); assertEquals(values[iv++], v, 0e-6); } - } finally { - tmpFile.deleteOnExit(); } } @@ -75,6 +74,7 @@ public void testStandAlone2Vals() { double[] valuesAcc = {0.95, 0.90, 0.85, 0.90, 0.99, Double.NaN}; double[] valuesErr = {1e-1, 1e-2, 1e-3, 1e-4, 1e-5, Double.NaN}; File tmpFile = File.createTempFile("tf-test", ".csv"); + tmpFile.deleteOnExit(); Map logs = new HashMap<>(); try (CSVLogger csvLogger = new CSVLogger<>(tmpFile)) { csvLogger.onTrainBegin(null); @@ -101,8 +101,6 @@ public void testStandAlone2Vals() { assertEquals(valuesErr[iv], e, 0e-8); iv++; } - } finally { - tmpFile.deleteOnExit(); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java index 408aab0fa81..b85f5737a84 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/callbacks/ProgbarLoggerTest.java @@ -40,7 +40,7 @@ void testNoTarget() { File tmpFile = null; try { tmpFile = File.createTempFile("tf-test-progbar", ".txt"); - System.out.println(tmpFile); + tmpFile.deleteOnExit(); try (PrintWriter writer = new PrintWriter(new FileWriter(tmpFile))) { int numEpochs = 1; int numSteps = 10; @@ -89,10 +89,6 @@ void testNoTarget() { }); } catch (IOException ex) { fail(ex); - } finally { - if (tmpFile != null) { - // tmpFile.delete(); - } } } @@ -102,6 +98,7 @@ void testTarget() { File tmpFile = null; try { tmpFile = File.createTempFile("tf-test-progbar", ".txt"); + tmpFile.deleteOnExit(); try (PrintWriter writer = new PrintWriter(new FileWriter(tmpFile))) { int numEpochs = 10; int numSteps = 10; @@ -158,10 +155,6 @@ void testTarget() { } catch (IOException ex) { fail(ex); - } finally { - if (tmpFile != null) { - tmpFile.delete(); - } } }