Ability to convert Tensor to String representation #268

Open
cowwoc opened this issue Apr 1, 2021 · 3 comments · May be fixed by #272

@cowwoc

cowwoc commented Apr 1, 2021

Per our discussion on Gitter, here is a possible implementation for converting Tensors to a String representation. It is still missing some important features, like collapsing long arrays using ellipses, but this can serve as a stepping stone. The functionality is meant to ease troubleshooting/debugging so performance should not be an issue.

import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.ndarray.buffer.IntDataBuffer;
import org.tensorflow.ndarray.buffer.LongDataBuffer;
import org.tensorflow.ndarray.buffer.ShortDataBuffer;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;

import java.util.StringJoiner;

public final class Tensors
{
	private final Session session;

	/**
	 * @param session the session used by all operations
	 */
	public Tensors(Session session)
	{
		this.session = session;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TFloat64 tensor)
	{
		Shape shape = tensor.shape();
		DoubleDataBuffer doubles = tensor.asRawTensor().data().asDoubles();
		return toString(doubles, shape, 0, 0, tensor.rank()).text;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TFloat32 tensor)
	{
		Shape shape = tensor.shape();
		FloatDataBuffer floats = tensor.asRawTensor().data().asFloats();
		return toString(floats, shape, 0, 0, tensor.rank()).text;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TFloat16 tensor)
	{
		Shape shape = tensor.shape();
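		// TFloat16 stores 16-bit values, so reinterpreting the raw bytes with asFloats() would misread them.
		// read() copies each element into the destination buffer as a 32-bit float instead.
		FloatDataBuffer floats = org.tensorflow.ndarray.buffer.DataBuffers.ofFloats(shape.size());
		tensor.read(floats);
		return toString(floats, shape, 0, 0, tensor.rank()).text;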
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TInt64 tensor)
	{
		Shape shape = tensor.shape();
		LongDataBuffer longs = tensor.asRawTensor().data().asLongs();
		return toString(longs, shape, 0, 0, tensor.rank()).text;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TInt32 tensor)
	{
		Shape shape = tensor.shape();
		IntDataBuffer ints = tensor.asRawTensor().data().asInts();
		return toString(ints, shape, 0, 0, tensor.rank()).text;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TUint8 tensor)
	{
		Shape shape = tensor.shape();
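		// TUint8 stores one byte per element, so asShorts() would pair bytes up incorrectly.
		// Widen each byte to its unsigned value (0-255) instead.
		var bytes = tensor.asRawTensor().data();
		ShortDataBuffer shorts = org.tensorflow.ndarray.buffer.DataBuffers.ofShorts(bytes.size());
		for (long i = 0; i < bytes.size(); ++i)
			shorts.setShort((short) (bytes.getByte(i) & 0xFF), i);
		return toString(shorts, shape, 0, 0, tensor.rank()).text;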
	}

	/**
	 * @param data      the data
	 * @param shape     the shape of the tensor
	 * @param index     the index of the tensor element to start at
	 * @param dimension the current dimension
	 * @param rank      the number of dimensions in the tensor
	 * @return the String representation of the {@code dimension}
	 */
	private ToStringResponse toString(DataBuffer<?> data, Shape shape, int index, int dimension, int rank)
	{
		// A scalar (rank 0) has no dimension to iterate over, so print its single element.
		if (rank == 0)
			return new ToStringResponse(String.valueOf(data.getObject(index)), 1);
		int numElements = 0;
		StringJoiner joiner;
		if (dimension < rank - 1)
		{
			// Outer dimension: recurse once per entry along the current dimension.
			joiner = new StringJoiner(",\n", "\t".repeat(dimension) + "[\n", "\n" + "\t".repeat(dimension) + "]");
			for (long i = 0, size = shape.size(dimension); i < size; ++i)
			{
				ToStringResponse response = toString(data, shape, index, dimension + 1, rank);
				joiner.add(response.text);
				numElements += response.numElements;
				index += response.numElements;
			}
		}
		else
		{
			// Innermost dimension: print its elements on a single line.
			joiner = new StringJoiner(",", "\t".repeat(dimension) + "[", "]");
			for (long i = 0, size = shape.size(dimension); i < size; ++i)
			{
				joiner.add(String.valueOf(data.getObject(index)));
				++numElements;
				++index;
			}
		}
		return new ToStringResponse(joiner.toString(), numElements);
	}

	/**
	 * @param text        the string representation of a tensor dimension
	 * @param numElements the number of elements contained in {@code text}
	 */
	private record ToStringResponse(String text, int numElements)
	{
	}
}
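
The ellipsis collapsing mentioned above as a missing feature could look roughly like the sketch below. It is a standalone illustration on a plain `double[]`, not part of the proposed class: `EllipsisFormatter`, `format`, and the `edgeItems` parameter are hypothetical names, and the `", "` separator is an arbitrary choice.

```java
import java.util.StringJoiner;

/** Hypothetical helper for illustration only; not part of TensorFlow Java. */
final class EllipsisFormatter
{
	private EllipsisFormatter()
	{
	}

	/**
	 * Formats the values as a bracketed list. When the array holds more than
	 * {@code 2 * edgeItems} values, only the first and last {@code edgeItems}
	 * values are printed and the middle is replaced by "...".
	 */
	static String format(double[] values, int edgeItems)
	{
		StringJoiner joiner = new StringJoiner(", ", "[", "]");
		if (values.length <= 2 * edgeItems)
		{
			// Short enough to print in full.
			for (double value : values)
				joiner.add(String.valueOf(value));
			return joiner.toString();
		}
		// Leading values.
		for (int i = 0; i < edgeItems; ++i)
			joiner.add(String.valueOf(values[i]));
		// Placeholder for the collapsed middle section.
		joiner.add("...");
		// Trailing values.
		for (int i = values.length - edgeItems; i < values.length; ++i)
			joiner.add(String.valueOf(values[i]));
		return joiner.toString();
	}
}
```

The same check could be applied per dimension inside the recursive method, so that long rows and long runs of rows are both collapsed.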
@rnett
Contributor

rnett commented Apr 1, 2021

Looks good, are you planning on making a PR for this?

If so, some initial comments:

  • It looks like the session is left over from a previous version?

  • I think a single toString(Tensor t) (or TType, though Tensor is better) that resolves the implementation from the runtime type would be useful too, since the concrete type isn't always known. Perhaps make it the only method.

  • It would also be good to have this as a default method on Operand that only does something in an eager session (like asTensor()). It should probably return an N/A message rather than throw, since it's for debugging.
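
A rough sketch of the runtime-type resolution and the "N/A instead of throwing" behaviour suggested above. To stay self-contained it uses hypothetical stand-in classes (`FakeTensor`, `FakeFloat32`, `FakeInt32`, `FakeString`) rather than the real TensorFlow Java types, so it only illustrates the dispatch pattern, not the actual API.

```java
import java.util.Arrays;

/** Hypothetical stand-in for a tensor type; the real TensorFlow Java types differ. */
interface FakeTensor
{
}

final class FakeFloat32 implements FakeTensor
{
	final float[] values;

	FakeFloat32(float[] values)
	{
		this.values = values;
	}
}

final class FakeInt32 implements FakeTensor
{
	final int[] values;

	FakeInt32(int[] values)
	{
		this.values = values;
	}
}

/** A type the formatter does not know how to print. */
final class FakeString implements FakeTensor
{
}

final class TensorText
{
	private TensorText()
	{
	}

	/**
	 * Single entry point: picks the formatting based on the runtime type and
	 * returns an "N/A" message for anything it cannot print, since the method
	 * is meant for debugging and should not throw.
	 */
	static String toString(FakeTensor tensor)
	{
		if (tensor == null)
			return "N/A (null tensor)";
		if (tensor instanceof FakeFloat32)
			return Arrays.toString(((FakeFloat32) tensor).values);
		if (tensor instanceof FakeInt32)
			return Arrays.toString(((FakeInt32) tensor).values);
		return "N/A (unsupported tensor type: " + tensor.getClass().getSimpleName() + ")";
	}
}
```

With a single entry point like this, callers do not need to know the concrete type up front, and the typed overloads could become private helpers.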

@cowwoc
Author

cowwoc commented Apr 2, 2021

@rnett Good suggestions. I'll try to formulate a PR.

Question though, since this method is meant strictly for debugging, couldn't we implement it for non-eager sessions as well? We could spin up a graph runner, evaluate the operand, and return the String representation.

@rnett
Contributor

rnett commented Apr 2, 2021

You could, but if that tensor depends on anything non-constant (i.e. placeholders or variables), you won't be able to get it, since the session has no way of knowing about those inputs. And I would think most things you want to debug would have dependencies like that. Plus I'm not sure sessions support adding things to the graph after the session is created, and you'd have to re-run the whole graph each time you called asString.

Once we finish functions and eager gradients, most debugging should be done in eager mode anyway. You'd almost always use functions instead of graphs, and there would be a global "execute functions in eager mode" option, like in Python. It's not necessarily impossible to have asString in graph mode, but it's not easy and won't ever fit very well, and since we have this coming I don't think it's worth it. Feel free to come up with an implementation and make a PR, but maybe PR just the eager version first.

cowwoc linked a pull request Apr 3, 2021 that will close this issue