Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 95 additions & 6 deletions src/TorchSharp/Tensor/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System;
using System.Linq;
using System.Collections.Generic;
using System.ComponentModel;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -1531,7 +1532,7 @@ public Tensor reshape(params long[] shape)
static extern IntPtr THSTensor_flatten(IntPtr tensor, long start, long end);

/// <summary>
/// Flattens input by reshaping it into a one-dimensional tensor.
/// Flattens input by reshaping it into a one-dimensional tensor.
/// </summary>
/// <param name="start_dim">The first dim to flatten</param>
/// <param name="end_dim">The last dim to flatten.</param>
Expand Down Expand Up @@ -5431,10 +5432,32 @@ public static implicit operator Tensor(Scalar scalar)
// Specifically added to make F# look good.
public static Tensor op_MinusMinusGreater(Tensor t, torch.nn.Module m) => m.forward(t);

/// <summary>
/// The style used by the parameterless ToString() override.
/// NOTE(review): this is a mutable public static field -- changing it affects every tensor
/// in the process; consider whether a property would be a safer surface.
/// </summary>
public static TensorStringStyle DefaultOutputStyle = TensorStringStyle.Metadata;

/// <summary>
/// Get a string representation of the tensor, formatted according to DefaultOutputStyle.
/// </summary>
public override string ToString() => ToString(DefaultOutputStyle);

/// <summary>
/// Tensor-specific ToString(), supporting several formatting styles.
/// </summary>
/// <param name="style">
/// The style to use -- either 'metadata,' 'julia,' or 'numpy'
/// </param>
/// <param name="fltFormat">The floating point format to use for each individual number.</param>
/// <param name="width">The line width to enforce</param>
/// <param name="cultureInfo">The culture, which affects how numbers are formatted.</param>
/// <returns>A string rendering of the tensor in the requested style.</returns>
/// <exception cref="InvalidEnumArgumentException">Thrown when 'style' is not a defined TensorStringStyle value.</exception>
public string ToString(TensorStringStyle style, string fltFormat = "g5", int width = 100,
    CultureInfo? cultureInfo = null) => style switch {
        TensorStringStyle.Metadata => ToMetadataString(),
        TensorStringStyle.Julia => ToJuliaString(fltFormat, width, cultureInfo),
        TensorStringStyle.Numpy => ToNumpyString(this, ndim, true, fltFormat, cultureInfo),
        // Use the dedicated ctor so the exception reports the argument name and the bad value.
        _ => throw new InvalidEnumArgumentException(nameof(style), (int)style, typeof(TensorStringStyle))
    };

/// <summary>
/// Get a string representation of the tensor.
/// </summary>
public override string ToString()
private string ToMetadataString()
{
if (Handle == IntPtr.Zero) return "";

Expand All @@ -5457,20 +5480,86 @@ public override string ToString()
return sb.ToString();
}

/// <summary>
/// Renders a tensor in a numpy-style bracketed text format. Dimensions longer than
/// six elements are abbreviated with an ellipsis (first three, '...', last three).
/// </summary>
/// <param name="t">The tensor (or sub-tensor, during recursion) to render.</param>
/// <param name="mdim">The dimension count of the outermost tensor; used to left-pad nested rows.</param>
/// <param name="isFCreate">True when this sub-tensor should be preceded by alignment padding.</param>
/// <param name="fltFormat">The floating point format to use for each individual number.</param>
/// <param name="cultureInfo">The culture used to format numbers; null means CultureInfo.CurrentCulture.</param>
/// <returns>The formatted string, or "" for a 0-d (scalar) tensor.</returns>
private static string ToNumpyString(Tensor t, long mdim, bool isFCreate, string fltFormat, CultureInfo? cultureInfo)
{
    var actualCultureInfo = cultureInfo ?? CultureInfo.CurrentCulture;

    var dim = t.dim();
    if (t.size().Length == 0) return "";

    // Left-pad nested rows so their opening brackets line up, numpy-style.
    var sb = new StringBuilder(isFCreate ? new string(' ', (int) (mdim - dim)) : "");
    sb.Append('[');

    var currentSize = t.size()[0];
    if (currentSize > 0) {            // An empty leading dimension renders as '[]'.
        if (dim == 1) {
            // Innermost dimension: print the scalar values, space-separated.
            if (currentSize <= 6) {
                for (var i = 0; i < currentSize - 1; i++) {
                    PrintValue(sb, t.dtype, t[i].ToScalar(), fltFormat, actualCultureInfo);
                    sb.Append(' ');
                }

                PrintValue(sb, t.dtype, t[currentSize - 1].ToScalar(), fltFormat, actualCultureInfo);
            } else {
                for (var i = 0; i < 3; i++) {
                    PrintValue(sb, t.dtype, t[i].ToScalar(), fltFormat, actualCultureInfo);
                    sb.Append(' ');
                }

                sb.Append("... ");

                for (var i = currentSize - 3; i < currentSize - 1; i++) {
                    PrintValue(sb, t.dtype, t[i].ToScalar(), fltFormat, actualCultureInfo);
                    sb.Append(' ');
                }

                PrintValue(sb, t.dtype, t[currentSize - 1].ToScalar(), fltFormat, actualCultureInfo);
            }
        } else {
            // Outer dimensions: render each slice recursively, separated by dim-1
            // newlines (higher-rank slices get more blank lines between them).
            var newline = string.Concat(Enumerable.Repeat(Environment.NewLine, (int) dim - 1));
            if (currentSize <= 6) {
                // The first slice skips padding (it follows '[' on the same line);
                // iterating from 1 also avoids printing a single slice twice.
                sb.Append(ToNumpyString(t[0], mdim, false, fltFormat, cultureInfo));
                for (var i = 1; i < currentSize; i++) {
                    sb.Append(newline);
                    sb.Append(ToNumpyString(t[i], mdim, true, fltFormat, cultureInfo));
                }
            } else {
                sb.Append(ToNumpyString(t[0], mdim, false, fltFormat, cultureInfo));
                sb.Append(newline);
                for (var i = 1; i < 3; i++) {
                    sb.Append(ToNumpyString(t[i], mdim, true, fltFormat, cultureInfo));
                    sb.Append(newline);
                }

                sb.Append(new string(' ', (int) (mdim - dim)));
                sb.Append(" ...");
                sb.Append(newline);

                for (var i = currentSize - 3; i < currentSize - 1; i++) {
                    sb.Append(ToNumpyString(t[i], mdim, true, fltFormat, cultureInfo));
                    sb.Append(newline);
                }

                sb.Append(ToNumpyString(t[currentSize - 1], mdim, true, fltFormat, cultureInfo));
            }
        }
    }

    sb.Append(']');
    return sb.ToString();
}

/// <summary>
/// Get a verbose string representation of a tensor.
/// </summary>
/// <param name="withData">Boolean, used to discriminate.</param>
/// <param name="fltFormat">The format string to use for floating point values.</param>
/// <param name="width">The width of each line of the output string.</param>
/// <param name="cultureInfo">The CultureInfo to use when formatting the text.</param>

public string ToString(bool withData, string fltFormat = "g5", int width = 100, CultureInfo? cultureInfo = null)
private string ToJuliaString(string fltFormat = "g5", int width = 100, CultureInfo? cultureInfo = null)
{
var actualCulturInfo = cultureInfo ?? CultureInfo.CurrentCulture;
if (!withData) return this.ToString();

var builder = new StringBuilder(this.ToString());
var builder = new StringBuilder(this.ToMetadataString());

if (Dimensions == 0) {

Expand Down
27 changes: 21 additions & 6 deletions src/TorchSharp/Tensor/TensorExtensionMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ namespace TorchSharp
{
using static torch;

/// <summary>
/// The formatting style to use when converting a tensor to a string.
/// Dispatched by Tensor.ToString(TensorStringStyle, ...).
/// </summary>
public enum TensorStringStyle
{
    Metadata,  // Rendered via Tensor.ToMetadataString() -- presumably shape/dtype only; confirm against Tensor.cs.
    Julia,     // Rendered via Tensor.ToJuliaString() -- metadata header followed by the data.
    Numpy      // Rendered via Tensor.ToNumpyString() -- numpy-style bracketed data.
}

/// <summary>
/// A few extensions to the Tensor type.
/// </summary>
Expand All @@ -31,30 +38,38 @@ public static Modules.Parameter AsParameter(this Tensor tensor)
/// Get a string representation of the tensor.
/// </summary>
/// <param name="tensor">The input tensor.</param>
/// <param name="style">
/// The style to use -- either 'metadata,' 'julia,' or 'numpy'
/// </param>
/// <param name="fltFormat">The format string to use for floating point values.</param>
/// <param name="width">The width of each line of the output string.</param>
/// <returns></returns>
/// <remarks>
/// This method does exactly the same as ToString(bool, string, int), but is shorter,
/// looks more like Python 'str' and doesn't require a boolean argument 'true' in order
/// looks more like Python 'str' and doesn't require a style argument in order
/// to discriminate.
///
/// Primarily intended for use in interactive notebooks.
/// </remarks>
public static string str(this Tensor tensor, string fltFormat = "g5", int width = 100)
public static string str(this Tensor tensor, TensorStringStyle style = TensorStringStyle.Julia, string fltFormat = "g5", int width = 100)
{
return tensor.ToString(true, fltFormat, width);
return tensor.ToString(style, fltFormat, width);
}

/// <summary>
/// Uses Console.WriteLine to print a tensor expression on stdout. This is intended for
/// .NET Interactive notebook use, primarily.
/// interactive notebook use, primarily.
/// </summary>
/// <param name="t">The input tensor.</param>
/// <param name="style">
/// The style to use -- either 'metadata,' 'julia,' or 'numpy'
/// </param>
/// <param name="fltFormat">The format string to use for floating point values.</param>
/// <param name="width">The width of each line of the output string.</param>
/// <returns></returns>
public static Tensor print(this Tensor t, string fltFormat = "g5", int width = 100)
public static Tensor print(this Tensor t, TensorStringStyle style = TensorStringStyle.Julia, string fltFormat = "g5", int width = 100)
{
Console.WriteLine(t.str());
Console.WriteLine(t.str(style, fltFormat, width));
return t;
}

Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/TorchVision/Functional.cs
Original file line number Diff line number Diff line change
Expand Up @@ -898,8 +898,8 @@ private static IList<float> GetPerspectiveCoefficients(IList<IList<int>> startpo

using var b_matrix = torch.tensor(startpoints.SelectMany(sp => sp).ToArray(), dtype: torch.float32).view(8);

var a_str = a_matrix.ToString(true);
var b_str = b_matrix.ToString(true);
var a_str = a_matrix.ToString(TensorStringStyle.Julia);
var b_str = b_matrix.ToString(TensorStringStyle.Julia);

var t0 = torch.linalg.lstsq(a_matrix, b_matrix);

Expand Down
7 changes: 6 additions & 1 deletion src/TorchSharp/Utils/Decompress.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;

//NOTE: This code was inspired by code found int the SciSharpStack-Examples repository.
// This code was inspired by code found in the SciSharpStack-Examples repository located at:
//
// https://github.com/SciSharp/SciSharp-Stack-Examples
//
// Original License: https://github.com/SciSharp/SciSharp-Stack-Examples/blob/master/LICENSE
//
// Original copyright information was not found at the above location.

namespace TorchSharp.Utils
{
Expand Down
2 changes: 1 addition & 1 deletion test/TorchSharpTest/TestDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void DataLoaderTest1()
Assert.True(iterator.MoveNext());
Assert.Equal(iterator.Current["data"], torch.tensor(rawArray: new[]{1L, 1L}, dimensions: new[]{2L}, dtype: torch.ScalarType.Int32));
Assert.Equal(iterator.Current["label"], torch.tensor(rawArray: new[]{13L, 13L}, dimensions: new[]{2L}, dtype: torch.ScalarType.Int32));
Assert.Equal(iterator.Current["index"].ToString(true), torch.tensor(rawArray: new[]{0L, 1L}, dimensions: new[]{2L}, dtype: torch.ScalarType.Int64).ToString(true));
Assert.Equal(iterator.Current["index"].ToString(TensorStringStyle.Julia), torch.tensor(rawArray: new[]{0L, 1L}, dimensions: new[]{2L}, dtype: torch.ScalarType.Int64).ToString(TensorStringStyle.Julia));
iterator.Dispose();
}

Expand Down
Loading