org.junit.jupiter
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java
new file mode 100644
index 00000000000..92379fb6f67
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CSVLogger.java
@@ -0,0 +1,277 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.callbacks;
+
+import org.apache.commons.csv.CSVFormat;
+import org.apache.commons.csv.CSVPrinter;
+import org.tensorflow.ndarray.NdArray;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.types.family.TNumber;
+
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.StringJoiner;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+/**
+ * Callback that streams epoch results to a CSV file.
+ *
+ * Supports all values that can be represented as a string
+ *
+ * @param the data type for the weights in the model
+ */
+public class CSVLogger extends Callback implements AutoCloseable {
+
+ public static final char DEFAULT_SEPARATOR = ',';
+ public static final boolean DEFAULT_APPEND = false;
+
+ private final File file;
+ private final char separator;
+ private final boolean append;
+ private List keys;
+ private boolean appendHeader = true;
+
+ private CSVPrinter writer;
+
+ /**
+ * Creates a CSVLogger callback using {@link #DEFAULT_SEPARATOR} to separate elements in the csv
+ * file, and {@link #DEFAULT_APPEND} for the append value.
+ *
+ * @param file the csv file
+ */
+ public CSVLogger(File file) {
+ this(file, DEFAULT_SEPARATOR, DEFAULT_APPEND);
+ }
+
+ /**
+ * Creates a CSVLogger callback using {@link #DEFAULT_SEPARATOR} to separate elements in the csv
+ * file, and {@link #DEFAULT_APPEND} for the append value.
+ *
+ * @param filename filename of the csv file
+ */
+ public CSVLogger(String filename) {
+ this(new File(filename), DEFAULT_SEPARATOR, DEFAULT_APPEND);
+ }
+
+ /**
+ * Creates a CSVLogger callback using {@link #DEFAULT_APPEND} for the append value.
+ *
+ * @param file the csv file
+ * @param separator string used to separate elements in the csv file.
+ */
+ public CSVLogger(File file, char separator) {
+ this(file, separator, false);
+ }
+
+ /**
+ * Creates a CSVLogger callback using {@link #DEFAULT_APPEND} for the append value.
+ *
+ * @param filename filename of the csv file
+ * @param separator string used to separate elements in the csv file.
+ */
+ public CSVLogger(String filename, char separator) {
+ this(new File(filename), separator, false);
+ }
+
+ /**
+ * Creates a CSVLogger callback.
+ *
+ * @param filename filename of the csv file
+ * @param separator the character used to separate elements in the csv file.
+ * @param append if true, append if file exists (useful for continuing training). if false,
+ * overwrite existing file.
+ */
+ public CSVLogger(String filename, char separator, boolean append) {
+ this(new File(filename), separator, append);
+ }
+
+ /**
+ * Creates a CSVLogger callback.
+ *
+ * @param file the csv file
+ * @param separator the character used to separate elements in the csv file.
+ * @param append if true, append if file exists (useful for continuing training). if false,
+ * overwrite existing file.
+ */
+ public CSVLogger(File file, char separator, boolean append) {
+ this.file = file;
+ this.separator = separator;
+ this.append = append;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainBegin(Map logs) {
+ appendHeader = !append || !file.exists();
+ }
+
+ // TODO Should we handle Java arrays??
+ @SuppressWarnings("unchecked")
+ private String handleValue(Object val) {
+
+ if (val instanceof String) {
+ return val.toString();
+ } else if (val instanceof NdArray) { // todo
+ boolean isScalar = ((NdArray>) val).rank() == 0;
+ if (isScalar) {
+ return ((NdArray>) val).getObject().toString();
+ } else {
+ NdArray> array = (NdArray) val;
+ return ndArrayToString(array);
+ }
+ } else if (val instanceof Collection) {
+ return "["
+ + ((Collection) val).stream().map(Object::toString).collect(Collectors.joining(","))
+ + "]";
+ } else {
+ return val.toString();
+ }
+ }
+
+ /**
+ * coverts an NdArray to a printable string
+ *
+ * @param ndArray the NdArray
+ * @return the printable string
+ */
+ private String ndArrayToString(NdArray> ndArray) {
+ Iterator extends NdArray>> iterator = ndArray.scalars().iterator();
+ Shape shape = ndArray.shape();
+ if (shape.numDimensions() == 0) {
+ if (!iterator.hasNext()) {
+ return "";
+ }
+ return valToString(iterator.next().getObject());
+ }
+ return ndArrayToString(iterator, shape, 0);
+ }
+
+ /**
+ * coverts an NdArray iterator to a printable string
+ *
+ * @param iterator the NdArray iterator
+ * @param shape the shape of the NdArray item
+ * @param dimension the dimension within the overall NDArray tree
+ * @return the printable string
+ */
+ private String ndArrayToString(Iterator extends NdArray>> iterator, Shape shape, int dimension) {
+ if (dimension < shape.numDimensions() - 1) {
+ StringJoiner joiner = new StringJoiner("", "[", "]");
+ for (long i = 0, size = shape.size(dimension); i < size; ++i) {
+ String element = ndArrayToString(iterator, shape, dimension + 1);
+ joiner.add(element);
+ }
+ return joiner.toString();
+ } else {
+ StringJoiner joiner = new StringJoiner(", ", "[", "]");
+ for (long i = 0, size = shape.size(dimension); i < size; ++i) {
+ Object element = iterator.next().getObject();
+ joiner.add(valToString(element));
+ }
+ return joiner.toString();
+ }
+ }
+
+ /**
+ * Converts a value to a printable string
+ *
+ * @param val the value
+ * @return the printable string
+ */
+ private String valToString(Object val) {
+ if (val instanceof Number) {
+ Number nVal = (Number) val;
+ if (nVal instanceof Float || nVal instanceof Double) {
+ return String.format("%e", nVal.doubleValue());
+ } else if (nVal instanceof Byte) {
+ return String.format("0x%2x", nVal.byteValue());
+ } else {
+ return String.format("%d", nVal.longValue());
+ }
+ } else {
+ return val.toString();
+ }
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ @SuppressWarnings("unchecked")
+ public void onEpochEnd(int epoch, Map logs) {
+ logs = logs == null ? Collections.EMPTY_MAP : logs;
+
+ if (keys == null) {
+ keys = new ArrayList<>(logs.keySet());
+ Collections.sort(this.keys);
+ }
+
+ if (writer == null) {
+ try {
+ List fieldNames = new ArrayList<>();
+ fieldNames.add("epoch");
+ fieldNames.addAll(this.keys);
+ CSVFormat csvFormat =
+ appendHeader
+ ? CSVFormat.EXCEL
+ .withHeader(fieldNames.toArray(new String[0]))
+ .withDelimiter(separator)
+ : CSVFormat.EXCEL.withDelimiter(separator);
+ writer = new CSVPrinter(new FileWriter(file, append), csvFormat);
+ } catch (IOException ex) {
+ Logger.getLogger(CSVLogger.class.getName()).log(Level.SEVERE, null, ex);
+ return;
+ }
+ }
+
+ /* TODO include when integrated with Model
+ if (getModel().isStopTraining()) {
+ final Map flogs = logs;
+ keys.forEach(
+ key -> {
+ if (!flogs.containsKey(key)) {
+ flogs.put(key, Double.NaN);
+ }
+ });
+ }
+ */
+ try {
+ final List values = new ArrayList<>();
+ final Map logsFinal = logs;
+ values.add(String.valueOf(epoch));
+ keys.forEach(key -> values.add(handleValue(logsFinal.get(key))));
+ writer.printRecord(values);
+ writer.flush();
+ } catch (IOException ex) {
+ Logger.getLogger(CSVLogger.class.getName()).log(Level.SEVERE, null, ex);
+ }
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void close() throws IOException {
+ if (writer != null) {
+ writer.close();
+ writer = null;
+ }
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java
new file mode 100644
index 00000000000..bbdaf1b7995
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/Callback.java
@@ -0,0 +1,245 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.callbacks;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Abstract base class used to build new callbacks.
+ *
+ * The logs map that callback methods take as argument will contain keys for quantities relevant
+ * to the current batch or epoch (see method-specific docstrings).
+ *
+ *
This class has empty implementations for {@code onTrainBatchBegin/End}, {@code
+ * onTrainBegin/End}, {@code onTestBatchBegin/End}, {@code onTestBegin/End}, {@code
+ * onPredictBatchBegin/End}, and {@code onPredictBegin/End}. Subclasses should override these
+ * methods for specific processing.
+ */
+public abstract class Callback {
+ protected final Map params;
+ // TODO protected Model model;
+
+ /** Creates a Callback */
+ protected Callback() {
+ this(Collections.emptyMap());
+ }
+
+ /**
+ * Creates a Callback
+ *
+ * @param params Training parameters
+ */
+ protected Callback(Map params) {
+ this.params = params;
+ }
+
+ /* TODO with Model
+ * Creates a Callback
+ *
+ * @param params Training parameters
+ * @param model the Model
+ */
+ /* TODO with Model
+ protected Callback(Map params, Model model) {=
+ this.params = params;
+ this.model = model;
+ }
+ */
+
+ /**
+ * Performs custom processing at the the start of an epoch. This method should only be called
+ * during TRAIN mode.
+ *
+ * @param epoch index of epoch.
+ * @param logs metric results
+ */
+ @SuppressWarnings("unused")
+ public void onEpochBegin(int epoch, Map logs) {}
+
+ /**
+ * Performs custom processing at the end of an epoch. This method should only be called during
+ * TRAIN mode.
+ *
+ * @param epoch index of epoch.
+ * @param logs metric results for this training epoch, and for the validation epoch if validation
+ * is performed. Validation result keys are prefixed with {@code val_}.
+ */
+ @SuppressWarnings("unused")
+ public void onEpochEnd(int epoch, Map logs) {}
+
+ /**
+ * Performs custom processing at the beginning of a training batch in {@code model.fit} methods.
+ *
+ * @param batch the batch index
+ * @param logs Has keys {@code batch} and {@code size} representing the current batch number and
+ * the size of the batch.
+ */
+ @SuppressWarnings("unused")
+ public void onTrainBatchBegin(int batch, Map logs) {}
+
+ /**
+ * Performs custom processing at the end of a training batch in {@code model.fit} methods.
+ *
+ * @param batch index of batch within the current epoch.
+ * @param logs Metric results for this batch.
+ */
+ @SuppressWarnings("unused")
+ public void onTrainBatchEnd(int batch, Map logs) {}
+
+ /**
+ * Performs custom processing at the beginning of training.
+ *
+ * @param logs metric results
+ */
+ @SuppressWarnings("unused")
+ public void onTrainBegin(Map logs) {}
+
+ /**
+ * Performs custom processing at the end of training.
+ *
+ * @param logs metric results
+ */
+ @SuppressWarnings("unused")
+ public void onTrainEnd(Map logs) {}
+
+ /**
+ * Performs custom processing at the beginning of a batch in {@code model.evaluate} methods. Also
+ * Performs custom processing at the beginning of a validation batch in the {@code fit} methods,
+ * if validation data is provided.
+ *
+ * @param batch the batch number
+ * @param logs Has keys {@code batch} and {@code size} representing the current batch number and
+ * the size of the batch.
+ */
+ @SuppressWarnings("unused")
+ public void onTestBatchBegin(int batch, Map logs) {}
+
+ /**
+ * Performs custom processing at the end of a batch in {@code model.evaluate} methods. Also Performs
+ * custom processing at the end of a validation batch in the {@code fit} methods, if validation
+ * data is provided.
+ *
+ * @param batch the batch number
+ * @param logs Metric results for this batch.
+ */
+ @SuppressWarnings("unused")
+ public void onTestBatchEnd(int batch, Map logs) {}
+
+ /**
+ * Performs custom processing at the beginning of evaluation or validation.
+ *
+ * @param logs metric results
+ */
+ @SuppressWarnings("unused")
+ public void onTestBegin(Map logs) {}
+
+ /**
+ * Performs custom processing at the end of evaluation or validation.
+ *
+ * @param logs metric results
+ */
+ @SuppressWarnings("unused")
+ public void onTestEnd(Map logs) {}
+
+ /**
+ * Performs custom processing at the beginning of a batch in {@code model.predict} methods.
+ *
+ * @param batch index of batch within the current epoch.
+ * @param logs Has keys {@code batch} and {@code size} representing the current batch number and
+ * the size of the batch.
+ */
+ @SuppressWarnings("unused")
+ public void onPredictBatchBegin(int batch, Map logs) {}
+
+ /**
+ * Performs custom processing at the end of a batch in {@code model.predict} methods.
+ *
+ * @param batch index of batch within the current epoch.
+ * @param logs Metric results for this batch.
+ */
+ @SuppressWarnings("unused")
+ public void onPredictBatchEnd(int batch, Map logs) {}
+
+ /**
+ * Performs custom processing at the beginning of prediction.
+ *
+ * @param logs metric results
+ */
+ @SuppressWarnings("unused")
+ public void onPredictBegin(Map logs) {}
+
+ /**
+ * Performs custom processing at the end of prediction.
+ *
+ * @param logs metric results
+ */
+ @SuppressWarnings("unused")
+ public void onPredictEnd(Map logs) {}
+
+ /**
+ * Gets a monitor value from the value logs
+ *
+ * @param logs the value logs
+ * @param monitor the monitor to fetch
+ * @return the monitor value, returns null if the monitor value is not in logs.
+ */
+ @SuppressWarnings("unchecked")
+ protected Number getMonitorValue(Map logs, String monitor) {
+ logs = logs == null ? Collections.EMPTY_MAP : logs;
+ Number monitorValue = logs.get(monitor);
+ if (monitorValue != null) {
+ Logger.getLogger(getClass().getName())
+ .log(
+ Level.WARNING,
+ String.format(
+ "Early stopping conditioned on metric `%s` which is not available. Available metrics are: %s",
+ monitor, String.join(",", logs.keySet())));
+ }
+ return monitorValue;
+ }
+
+ /**
+ * Gets the params
+ *
+ * @return the params
+ */
+ public Map getParams() {
+ return params;
+ }
+
+ /**
+ * Gets the model
+ *
+ * @return the model
+ */
+ /* TODO
+ public Model getModel() {
+ return model;
+ TODO */
+
+ /**
+ * Sets the model
+ *
+ * @param model the model
+ */
+ /* TODO
+ public void setModel(Model model) {
+ this.model = model;
+ }
+ TODO */
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java
new file mode 100644
index 00000000000..a1124f1b266
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/CallbackList.java
@@ -0,0 +1,230 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.callbacks;
+
+// TODO import org.tensorflow.framework.model.Model;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Container for {@link Callback} instances.
+ *
+ * This object wraps a list of {@link Callback} instances, making it possible to call them all at
+ * once via a single endpoint (e.g. {@code callbackList.onEpochEnd(...)}).
+ */
+public class CallbackList extends Callback {
+ private final List callbacks = new ArrayList<>();
+
+ // TODO private final Model extends TFloating> model;
+ private History history;
+
+ /** Creates a CallbackList */
+ public CallbackList() {
+ super();
+ }
+
+ /**
+ * Creates a CallbackList
+ *
+ * @param addHistory Whether a {@link History} callback should be added, if one does not already
+ * exist in the {@code callbacks} list.
+ */
+ public CallbackList(boolean addHistory) {
+ addDefaultCallbacks(addHistory);
+ }
+
+ /**
+ * Creates a CallbackList
+ *
+ * @param callbacks List of {@link Callback} instances.
+ * @param addHistory Whether a {@link History} callback should be added, if one does not already
+ * exist * in the {@code callbacks} list.
+ */
+ public CallbackList(List callbacks, boolean addHistory) {
+ this(addHistory);
+ this.callbacks.addAll(callbacks);
+ }
+
+ /* TODO add when integrated with Model
+ // /**
+ * Creates a CallbackList
+ *
+ * @param model the model these callbacks are used with.
+ * @param addHistory Whether a {@link History} callback should be added, if one does not already
+ * exist in the {@code callbacks} list.
+ // *
+
+ public CallbackList(Model extends TFloating> model, boolean addHistory) {
+ this.model = model;
+ addDefaultCallbacks(addHistory);
+ }
+ TODO */
+
+ /* TODO add when integrated with Model
+ // /**
+ * Creates a CallbackList
+ *
+ * @param model the model these callbacks are used with.
+ * @param callbacks List of {@link Callback} instances.
+ * @param addHistory Whether a {@link History} callback should be added, if one does not already
+ * exist in the {@code callbacks} list.
+ // *
+
+ public CallbackList(Model extends TFloating> model, List callbacks, boolean addHistory) {
+ this(model, addHistory);
+ this.callbacks.addAll(callbacks);
+ }
+ TODO */
+
+ /**
+ * Adds Callback's that are always present.
+ *
+ * @param addHistory Whether a {@link History} callback should be added, if one does not already
+ * exist in the {@code callbacks} list.
+ */
+ private void addDefaultCallbacks(boolean addHistory) {
+ callbacks.forEach(
+ c -> {
+ if (c instanceof History) {
+ history = (History) c;
+ }
+ });
+ if (history == null && addHistory) {
+ history = new History();
+ addCallback(history);
+ }
+ }
+
+ /**
+ * Adds a callback
+ *
+ * @param callback the callback
+ */
+ public void addCallback(Callback callback) {
+ callbacks.add(callback);
+ }
+
+ /**
+ * Adds callbacks
+ *
+ * @param callbacks the callbacks
+ */
+ public void addCallbacks(List callbacks) {
+ this.callbacks.addAll(callbacks);
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onTrainBegin(Map logs) {
+ callbacks.forEach(c -> c.onTrainBegin(logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onTrainEnd(Map logs) {
+ callbacks.forEach(c -> c.onTrainEnd(logs));
+ }
+ /** {@inheritDoc } */
+ @Override
+ public void onEpochBegin(int epoch, Map logs) {
+ callbacks.forEach(c -> c.onEpochBegin(epoch, logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onEpochEnd(int epoch, Map logs) {
+ callbacks.forEach(c -> c.onEpochEnd(epoch, logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onTrainBatchBegin(int batch, Map logs) {
+ callbacks.forEach(c -> c.onTrainBatchBegin(batch, logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onTrainBatchEnd(int batch, Map logs) {
+ callbacks.forEach(c -> c.onTrainBatchEnd(batch, logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onTestBatchBegin(int batch, Map logs) {
+ callbacks.forEach(c -> c.onTestBatchBegin(batch, logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onTestBatchEnd(int batch, Map logs) {
+ callbacks.forEach(c -> c.onTestBatchEnd(batch, logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onTestBegin(Map logs) {
+ callbacks.forEach(c -> c.onTestBegin(logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onTestEnd(Map logs) {
+ callbacks.forEach(c -> c.onTestEnd(logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onPredictBatchBegin(int batch, Map logs) {
+ callbacks.forEach(c -> c.onPredictBatchBegin(batch, logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onPredictBatchEnd(int batch, Map logs) {
+ callbacks.forEach(c -> c.onPredictBatchEnd(batch, logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onPredictBegin(Map logs) {
+ callbacks.forEach(c -> c.onPredictBegin(logs));
+ }
+
+ /** {@inheritDoc } */
+ @Override
+ public void onPredictEnd(Map logs) {
+ callbacks.forEach(c -> c.onPredictEnd(logs));
+ }
+
+ /**
+ * Gets the callbacks
+ *
+ * @return the callbacks
+ */
+ public List getCallbacks() {
+ return callbacks;
+ }
+
+ /**
+ * Gets the history callback
+ *
+ * @return the history callback
+ */
+ public History getHistory() {
+ return history;
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java
new file mode 100644
index 00000000000..53ce9373d8f
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/History.java
@@ -0,0 +1,91 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.callbacks;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Callback that records events into a History object.
+ *
+ * This callback is automatically applied to every model. The History object gets returned by the
+ * fit method of models.
+ */
+public class History extends Callback {
+ private final Map> history = new HashMap<>();
+ private final List epoch = new ArrayList<>();
+
+ /** Creates a History Callback */
+ public History() {
+ super();
+ }
+
+ /* TODO
+ * Creates a History Callback
+ *
+ * @param params Training parameters
+ * @param model the Model
+ *
+
+ public History( Model model) {=
+ super(null, model);
+ }
+ TODO */
+
+ /** {@inheritDoc} */
+ @Override
+ public void onTrainBegin(Map logs) {
+ epoch.clear();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void onEpochEnd(int epoch, Map logs) {
+ Map localLogs = logs == null ? Collections.emptyMap() : logs;
+ this.epoch.add(epoch);
+
+ logs.entrySet()
+ .forEach(
+ e -> {
+ List item = history.get(e.getKey());
+ if (item == null) {
+ item = new ArrayList<>();
+ history.put(e.getKey(), item);
+ }
+ item.add(e.getValue());
+ });
+ }
+
+ /**
+ * Gets the History map for each log value
+ *
+ * @return the History map for each log value
+ */
+ public Map> getHistory() {
+ return history;
+ }
+
+ /**
+ * Gets the history of epochs
+ *
+ * @return the history of epochs
+ */
+ public List getEpoch() {
+ return epoch;
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java
new file mode 100644
index 00000000000..6facdf73e4e
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/callbacks/LambdaCallback.java
@@ -0,0 +1,470 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.callbacks;
+
+import java.util.Map;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+
+/**
+ * Callback for creating simple, custom callbacks on-the-fly.
+ *
+ * Example:
+ *
+ *
{@code
+ * LambdaCallbacks batchPrintCallback = new LambdaCallbacks();
+ * batchPrintCallback.setOnTrainBatchBegin((batch, logs)->
+ * System.out.println("Batch: " + batch + " started");
+ * }
+ *
+ * This callback is constructed with anonymous functions that will be called at the appropriate
+ * time. Note that the callbacks expects positional arguments, as:
+ *
+ *
+ * onEpochBegin
and onEpochEnd
expect two positional arguments:
+ * epoch
, logs
+ * onBatchBegin
and onBatchEnd
expect two positional arguments:
+ * batch
, logs
+ * onTrainBegin
and onTrainEnd
expect one positional argument:
+ * logs
+ *
+ */
+public class LambdaCallback extends Callback {
+
+ /** Called at the beginning of every epoch. expect two positional arguments: `epoch`, `logs` */
+ private BiConsumer> onEpochBegin;
+
+ /** Called at the end of every epoch. expect two positional arguments: `epoch`, `logs` */
+ private BiConsumer> onEpochEnd;
+
+ /** Called at the beginning of every batch. expect two positional arguments: `batch`, `logs` */
+ private BiConsumer> onTrainBatchBegin;
+
+ /** called at the end of every batch. expect two positional arguments: `batch`, `logs` */
+ private BiConsumer> onTrainBatchEnd;
+
+ /** called at the beginning of model training. expect one positional argument: `logs` */
+ private Consumer