Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callbacks phase 1 #299

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
c57a2e7
Merge pull request #3 from tensorflow/master
JimClarke5 Oct 8, 2020
09fc07e
Merge pull request #4 from tensorflow/master
JimClarke5 Oct 27, 2020
a99dcb4
Merge pull request #5 from tensorflow/master
JimClarke5 Nov 17, 2020
ba294ea
Merge pull request #6 from tensorflow/master
JimClarke5 Nov 19, 2020
04f419a
Merge pull request #7 from tensorflow/master
JimClarke5 Dec 30, 2020
02e7ebf
Merge pull request #8 from tensorflow/master
JimClarke5 Jan 29, 2021
e0c9ed8
Merge pull request #9 from tensorflow/master
JimClarke5 Feb 1, 2021
5b0374b
Merge pull request #10 from tensorflow/master
JimClarke5 Feb 11, 2021
e038bbd
Merge pull request #11 from tensorflow/master
JimClarke5 Feb 23, 2021
def3051
Merge pull request #13 from tensorflow/master
JimClarke5 Mar 3, 2021
11748ae
Merge pull request #15 from tensorflow/master
JimClarke5 Mar 21, 2021
a9412ea
Merge pull request #16 from tensorflow/master
JimClarke5 Apr 9, 2021
2ff8dfe
Merge pull request #17 from tensorflow/master
JimClarke5 Apr 22, 2021
df56f1d
Initial checkin
JimClarke5 Apr 25, 2021
ee5e38a
Merge pull request #18 from tensorflow/master
JimClarke5 May 1, 2021
26394d6
Merge pull request #19 from tensorflow/master
JimClarke5 May 2, 2021
9dcddcd
Initial checkin
JimClarke5 Apr 25, 2021
ab2e304
Merge remote-tracking branch 'origin/Callbacks_Phase_1' into Callback…
JimClarke5 May 2, 2021
9efff83
Added missing methods in Lambda Callback
JimClarke5 May 6, 2021
6aa84eb
Remove unused class, this is part of ProgressBar
JimClarke5 May 6, 2021
4b3bb7c
Fixes from PR Comments
JimClarke5 May 6, 2021
66a4bdd
Change tmoFile to deleteOnExit, remove finally block
JimClarke5 May 7, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tensorflow-framework/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-api</artifactId>
<version>${project.version}</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-csv -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
<version>1.8</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.callbacks;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.family.TNumber;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
* Callback that streams epoch results to a CSV file.
*
* <p>Supports all values that can be represented as a string
*
* @param <T> the data type for the weights in the model
*/
public class CSVLogger<T extends TNumber> extends Callback implements AutoCloseable {

public static final char DEFAULT_SEPARATOR = ',';
public static final boolean DEFAULT_APPEND = false;

private final File file;
private final char separator;
private final boolean append;
private List<String> keys;
private boolean appendHeader = true;

private CSVPrinter writer;

/**
* Creates a CSVLogger callback using {@link #DEFAULT_SEPARATOR} to separate elements in the csv
* file, and {@link #DEFAULT_APPEND} for the append value.
*
* @param file the csv file
*/
public CSVLogger(File file) {
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
this(file, DEFAULT_SEPARATOR, DEFAULT_APPEND);
}

/**
* Creates a CSVLogger callback using {@link #DEFAULT_SEPARATOR} to separate elements in the csv
* file, and {@link #DEFAULT_APPEND} for the append value.
*
* @param filename filename of the csv file
*/
public CSVLogger(String filename) {
this(new File(filename), DEFAULT_SEPARATOR, DEFAULT_APPEND);
}

/**
* Creates a CSVLogger callback using {@link #DEFAULT_APPEND} for the append value.
*
* @param file the csv file
* @param separator string used to separate elements in the csv file.
*/
public CSVLogger(File file, char separator) {
this(file, separator, false);
}

/**
* Creates a CSVLogger callback using {@link #DEFAULT_APPEND} for the append value.
*
* @param filename filename of the csv file
* @param separator string used to separate elements in the csv file.
*/
public CSVLogger(String filename, char separator) {
this(new File(filename), separator, false);
}

/**
* Creates a CSVLogger callback.
*
* @param filename filename of the csv file
* @param separator the character used to separate elements in the csv file.
* @param append if true, append if file exists (useful for continuing training). if false,
* overwrite existing file.
*/
public CSVLogger(String filename, char separator, boolean append) {
this(new File(filename), separator, append);
}

/**
* Creates a CSVLogger callback.
*
* @param file the csv file
* @param separator the character used to separate elements in the csv file.
* @param append if true, append if file exists (useful for continuing training). if false,
* overwrite existing file.
*/
public CSVLogger(File file, char separator, boolean append) {
this.file = file;
this.separator = separator;
this.append = append;
}

/** {@inheritDoc} */
@Override
public void onTrainBegin(Map<String, Number> logs) {
appendHeader = !append || !file.exists();
}

// TODO Should we handle Java arrays??
@SuppressWarnings("unchecked")
private String handleValue(Object val) {

if (val instanceof String) {
return val.toString();
} else if (val instanceof NdArray) { // todo
boolean isScalar = ((NdArray<?>) val).rank() == 0;
if (isScalar) {
return ((NdArray<?>) val).getObject().toString();
} else {
NdArray<?> array = (NdArray<T>) val;
return ndArrayToString(array);
}
} else if (val instanceof Collection) {
return "["
+ ((Collection<T>) val).stream().map(Object::toString).collect(Collectors.joining(","))
+ "]";
} else {
return val.toString();
}
}

/**
* coverts an NdArray to a printable string
*
* @param ndArray the NdArray
* @return the printable string
*/
private String ndArrayToString(NdArray<?> ndArray) {
Iterator<? extends NdArray<?>> iterator = ndArray.scalars().iterator();
Shape shape = ndArray.shape();
if (shape.numDimensions() == 0) {
if (!iterator.hasNext()) {
return "";
}
return valToString(iterator.next().getObject());
}
return ndArrayToString(iterator, shape, 0);
}

/**
* coverts an NdArray iterator to a printable string
*
* @param iterator the NdArray iterator
* @param shape the shape of the NdArray item
* @param dimension the dimension within the overall NDArray tree
* @return the printable string
*/
private String ndArrayToString(Iterator<? extends NdArray<?>> iterator, Shape shape, int dimension) {
if (dimension < shape.numDimensions() - 1) {
StringJoiner joiner = new StringJoiner("", "[", "]");
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
String element = ndArrayToString(iterator, shape, dimension + 1);
joiner.add(element);
}
return joiner.toString();
} else {
StringJoiner joiner = new StringJoiner(", ", "[", "]");
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
Object element = iterator.next().getObject();
joiner.add(valToString(element));
}
return joiner.toString();
}
}

/**
* Converts a value to a printable string
*
* @param val the value
* @return the printable string
*/
private String valToString(Object val) {
if (val instanceof Number) {
Number nVal = (Number) val;
if (nVal instanceof Float || nVal instanceof Double) {
return String.format("%e", nVal.doubleValue());
} else if (nVal instanceof Byte) {
return String.format("0x%2x", nVal.byteValue());
} else {
return String.format("%d", nVal.longValue());
}
} else {
return val.toString();
}
}

/** {@inheritDoc} */
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
@Override
@SuppressWarnings("unchecked")
public void onEpochEnd(int epoch, Map<String, Number> logs) {
logs = logs == null ? Collections.EMPTY_MAP : logs;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Collections.emptyMap() rather than Collections.EMPTY_MAP. I think that might let you remove the warning suppression?


if (keys == null) {
keys = new ArrayList<>(logs.keySet());
Collections.sort(this.keys);
}

if (writer == null) {
try {
List<String> fieldNames = new ArrayList<>();
fieldNames.add("epoch");
fieldNames.addAll(this.keys);
CSVFormat csvFormat =
appendHeader
? CSVFormat.EXCEL
.withHeader(fieldNames.toArray(new String[0]))
.withDelimiter(separator)
: CSVFormat.EXCEL.withDelimiter(separator);
writer = new CSVPrinter(new FileWriter(file, append), csvFormat);
} catch (IOException ex) {
Logger.getLogger(CSVLogger.class.getName()).log(Level.SEVERE, null, ex);
return;
}
}

/* TODO include when integrated with Model
if (getModel().isStopTraining()) {
final Map<String, Number> flogs = logs;
keys.forEach(
key -> {
if (!flogs.containsKey(key)) {
flogs.put(key, Double.NaN);
}
});
}
*/
try {
final List<String> values = new ArrayList<>();
final Map<String, Number> logsFinal = logs;
values.add(String.valueOf(epoch));
keys.forEach(key -> values.add(handleValue(logsFinal.get(key))));
writer.printRecord(values);
writer.flush();
} catch (IOException ex) {
Logger.getLogger(CSVLogger.class.getName()).log(Level.SEVERE, null, ex);
}
}

/** {@inheritDoc} */
@Override
public void close() throws IOException {
if (writer != null) {
writer.close();
writer = null;
}
}
}
Loading