- 
        Couldn't load subscription status. 
- Fork 219
Open
Description
Per our discussion on Gitter, here is a possible implementation for converting Tensors to a String representation. It is still missing some important features, like collapsing long arrays using ellipses, but this can serve as a stepping stone. The functionality is meant to ease troubleshooting/debugging so performance should not be an issue.
import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.ndarray.buffer.IntDataBuffer;
import org.tensorflow.ndarray.buffer.LongDataBuffer;
import org.tensorflow.ndarray.buffer.ShortDataBuffer;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;
import java.util.StringJoiner;
/**
 * Debugging helper that renders tensors as nested, bracketed text.
 * <p>
 * Each dimension is wrapped in {@code [ ]}. The innermost dimension is printed on a single line with
 * comma-separated values; every outer dimension places its children on separate lines, indented by one
 * tab per nesting level. Elements are read from the tensor's raw data buffer in the order the buffer
 * stores them.
 */
public final class Tensors
{
	// Retained because the public constructor accepts it; no method in this class reads it.
	private final Session session;

	/**
	 * @param session the session used by all operations
	 */
	public Tensors(Session session)
	{
		this.session = session;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TFloat64 tensor)
	{
		Shape shape = tensor.shape();
		DoubleDataBuffer values = tensor.asRawTensor().data().asDoubles();
		return toString(values, shape, 0, 0, tensor.rank()).text;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TFloat32 tensor)
	{
		Shape shape = tensor.shape();
		FloatDataBuffer values = tensor.asRawTensor().data().asFloats();
		return toString(values, shape, 0, 0, tensor.rank()).text;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TFloat16 tensor)
	{
		Shape shape = tensor.shape();
		// NOTE(review): the raw bytes are viewed as floats here. TFloat16 elements are presumably 16 bits
		// wide, in which case this view misreads the data -- confirm, and if so decode through the
		// tensor's own element accessors instead.
		FloatDataBuffer values = tensor.asRawTensor().data().asFloats();
		return toString(values, shape, 0, 0, tensor.rank()).text;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TInt64 tensor)
	{
		Shape shape = tensor.shape();
		LongDataBuffer values = tensor.asRawTensor().data().asLongs();
		return toString(values, shape, 0, 0, tensor.rank()).text;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TInt32 tensor)
	{
		Shape shape = tensor.shape();
		IntDataBuffer values = tensor.asRawTensor().data().asInts();
		return toString(values, shape, 0, 0, tensor.rank()).text;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TUint8 tensor)
	{
		Shape shape = tensor.shape();
		// NOTE(review): the raw bytes are viewed as shorts here. TUint8 elements are presumably 8 bits
		// wide, in which case this view pairs up adjacent elements -- confirm, and if so read the bytes
		// individually and mask them to an unsigned value.
		ShortDataBuffer values = tensor.asRawTensor().data().asShorts();
		return toString(values, shape, 0, 0, tensor.rank()).text;
	}

	/**
	 * Renders one dimension of a tensor, recursing into inner dimensions.
	 *
	 * @param data      the data
	 * @param shape     the shape of the tensor
	 * @param index     the index of the first element of {@code data} that belongs to this dimension
	 * @param dimension the current dimension (0 is the outermost)
	 * @param rank      the number of dimensions in {@code shape}
	 * @return the String representation of the {@code dimension}, along with the number of elements it
	 * consumed from {@code data}
	 */
	private ToStringResponse toString(DataBuffer<?> data, Shape shape, long index, int dimension, int rank)
	{
		if (rank == 0)
		{
			// A scalar has no dimensions to iterate over: print its single value without brackets.
			return new ToStringResponse(String.valueOf(data.getObject(index)), 1);
		}

		String indent = "\t".repeat(dimension);
		// Each level iterates over the size of its OWN dimension, not the size of the last one.
		long size = shape.size(dimension);
		long numElements = 0;
		StringJoiner joiner;
		if (dimension < rank - 1)
		{
			// Outer dimension: every entry is itself a nested block, placed on its own line(s).
			joiner = new StringJoiner(",\n", indent + "[\n", "\n" + indent + "]");
			for (long i = 0; i < size; ++i)
			{
				// The next child starts right after the elements consumed by the previous children.
				ToStringResponse response = toString(data, shape, index + numElements, dimension + 1, rank);
				joiner.add(response.text);
				numElements += response.numElements;
			}
		}
		else
		{
			// Innermost dimension: print the values themselves on a single line.
			joiner = new StringJoiner(",", indent + "[", "]");
			for (long i = 0; i < size; ++i)
			{
				joiner.add(String.valueOf(data.getObject(index + i)));
			}
			numElements = size;
		}
		return new ToStringResponse(joiner.toString(), numElements);
	}

	/**
	 * @param text        the string representation of a tensor dimension
	 * @param numElements the number of elements contained in {@code text}
	 */
	private record ToStringResponse(String text, long numElements)
	{
	}
}
Metadata
Metadata
Assignees
Labels
No labels