Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/TorchSharp/Autograd.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ public static IList<Tensor> grad(IList<Tensor> outputs, IList<Tensor> inputs, IL
using var grads = new PinnedArray<IntPtr>();
using var results = new PinnedArray<IntPtr>();

IntPtr outsRef = outs.CreateArray(outputs.Select(p => p.Handle).ToArray());
IntPtr insRef = ins.CreateArray(inputs.Select(p => p.Handle).ToArray());
IntPtr gradsRef = grad_outputs == null ? IntPtr.Zero : grads.CreateArray(grad_outputs.Select(p => p.Handle).ToArray());
IntPtr outsRef = outs.CreateArray(outputs.ToHandleArray());
IntPtr insRef = ins.CreateArray(inputs.ToHandleArray());
IntPtr gradsRef = grad_outputs == null ? IntPtr.Zero : grads.CreateArray(grad_outputs.ToHandleArray());
long gradsLength = grad_outputs == null ? 0 : grads.Array.Length;

THSAutograd_grad(outsRef, outs.Array.Length, insRef, ins.Array.Length, gradsRef, gradsLength, retain_graph, create_graph, allow_unused, results.CreateArray);
Expand Down Expand Up @@ -178,9 +178,9 @@ public static void backward(IList<Tensor> tensors, IList<Tensor> grad_tensors =
using var ts = new PinnedArray<IntPtr>();
using var gts = new PinnedArray<IntPtr>();
using var ins = new PinnedArray<IntPtr>();
IntPtr tensRef = ts.CreateArray(tensors.Select(p => p.Handle).ToArray());
IntPtr gradsRef = grad_tensors == null ? IntPtr.Zero : gts.CreateArray(grad_tensors.Select(p => p.Handle).ToArray());
IntPtr insRef = inputs == null ? IntPtr.Zero : ins.CreateArray(inputs.Select(p => p.Handle).ToArray());
IntPtr tensRef = ts.CreateArray(tensors.ToHandleArray());
IntPtr gradsRef = grad_tensors == null ? IntPtr.Zero : gts.CreateArray(grad_tensors.ToHandleArray());
IntPtr insRef = inputs == null ? IntPtr.Zero : ins.CreateArray(inputs.ToHandleArray());
long insLength = inputs == null ? 0 : ins.Array.Length;
long gradsLength = grad_tensors == null ? 0 : gts.Array.Length;

Expand Down
10 changes: 5 additions & 5 deletions src/TorchSharp/AutogradFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ internal List<Tensor> ComputeVariableInput(object[] args)
internal void SetNextEdges(List<Tensor> inputVars, bool isExecutable)
{
using var l = new PinnedArray<IntPtr>();
THSAutograd_CSharpNode_setNextEdges(handle, l.CreateArrayWithSize(inputVars.Select(v => v.Handle).ToArray()), isExecutable);
THSAutograd_CSharpNode_setNextEdges(handle, l.CreateArrayWithSize(inputVars.ToHandleArray()), isExecutable);
CheckForErrors();
}

Expand All @@ -166,10 +166,10 @@ internal List<Tensor> WrapOutputs(List<Tensor> inputVars, List<Tensor> outputs,
using var outputArr = new PinnedArray<IntPtr>();
using var resultsArr = new PinnedArray<IntPtr>();

var varsPtr = varsArr.CreateArrayWithSize(inputVars.Select(v => v.Handle).ToArray());
var diffsPtr = diffArr.CreateArrayWithSize(_context.NonDifferentiableTensors.Select(v => v.Handle).ToArray());
var dirtyPtr = diffArr.CreateArrayWithSize(_context.DirtyTensors.Select(v => v.Handle).ToArray());
var outputPtr = outputArr.CreateArrayWithSize(outputs.Select(v => v.Handle).ToArray());
var varsPtr = varsArr.CreateArrayWithSize(inputVars.ToHandleArray());
var diffsPtr = diffArr.CreateArrayWithSize(_context.NonDifferentiableTensors.ToHandleArray());
var dirtyPtr = diffArr.CreateArrayWithSize(_context.DirtyTensors.ToHandleArray());
var outputPtr = outputArr.CreateArrayWithSize(outputs.ToHandleArray());

THSAutograd_Function_wrapOutputs(varsPtr, diffsPtr, dirtyPtr, outputPtr, isExecutable ? handle : new(), resultsArr.CreateArray);
CheckForErrors();
Expand Down
2 changes: 1 addition & 1 deletion src/TorchSharp/LinearAlgebra.cs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ public static Tensor multi_dot(IList<Tensor> tensors)
}

using (var parray = new PinnedArray<IntPtr>()) {
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());
var res = THSLinalg_multi_dot(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero)
torch.CheckForErrors();
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/NN/Utils/RNNUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public static (torch.Tensor, torch.Tensor) pad_packed_sequence(PackedSequence se
/// <returns>The padded tensor</returns>
public static torch.Tensor pad_sequence(IEnumerable<torch.Tensor> sequences, bool batch_first = false, double padding_value = 0.0)
{
var sequences_arg = sequences.Select(p => p.Handle).ToArray();
var sequences_arg = sequences.ToHandleArray();
var res = THSNN_pad_sequence(sequences_arg, sequences_arg.Length, batch_first, padding_value);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new torch.Tensor(res);
Expand All @@ -69,7 +69,7 @@ public static torch.Tensor pad_sequence(IEnumerable<torch.Tensor> sequences, boo
/// <returns>The packed batch of variable length sequences</returns>
public static PackedSequence pack_sequence(IEnumerable<torch.Tensor> sequences, bool enforce_sorted = true)
{
var sequences_arg = sequences.Select(p => p.Handle).ToArray();
var sequences_arg = sequences.ToHandleArray();
var res = THSNN_pack_sequence(sequences_arg, sequences_arg.Length, enforce_sorted);
if (res.IsInvalid) { torch.CheckForErrors(); }
return new PackedSequence(res);
Expand Down
2 changes: 1 addition & 1 deletion src/TorchSharp/Optimizers/LBFGS.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public static LBFGS LBFGS(IEnumerable<Parameter> parameters, double lr = 0.01, l
if (!max_eval.HasValue) max_eval = 5 * max_iter / 4;

using var parray = new PinnedArray<IntPtr>();
IntPtr paramsRef = parray.CreateArray(parameters.Select(p => p.Handle).ToArray());
IntPtr paramsRef = parray.CreateArray(parameters.ToHandleArray());

var res = THSNN_LBFGS_ctor(paramsRef, parray.Array.Length, lr, max_iter, max_eval.Value, tolerange_grad, tolerance_change, history_size);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
Expand Down
142 changes: 112 additions & 30 deletions src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,39 @@ public static Tensor cat(IList<Tensor> tensors, long dim = 0)
}

using var parray = new PinnedArray<IntPtr>();
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());

var res = THSTensor_cat(tensorsRef, parray.Array.Length, dim);
if (res == IntPtr.Zero) CheckForErrors();
return new Tensor(res);
}

// https://pytorch.org/docs/stable/generated/torch.cat
/// <summary>
/// Concatenates the given sequence of tensors in the given dimension.
/// </summary>
/// <param name="tensors">Tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.</param>
/// <param name="dim">The dimension over which the tensors are concatenated</param>
/// <remarks> All tensors must either have the same shape (except in the concatenating dimension) or be empty.</remarks>
public static Tensor cat(Tensor[] tensors, long dim = 0)
{
    // Forward to the span-based overload; the implicit array -> span conversion does not copy.
    return torch.cat((ReadOnlySpan<Tensor>)tensors, dim);
}

// https://pytorch.org/docs/stable/generated/torch.cat
/// <summary>
/// Concatenates the given sequence of tensors in the given dimension.
/// </summary>
/// <param name="tensors">A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.</param>
/// <param name="dim">The dimension over which the tensors are concatenated</param>
/// <remarks> All tensors must either have the same shape (except in the concatenating dimension) or be empty.</remarks>
public static Tensor cat(ReadOnlySpan<Tensor> tensors, long dim = 0)
{
switch (tensors.Length)
{
case <=0: throw new ArgumentException(nameof(tensors));
case 1: return tensors[0].alias();
}

using var parray = new PinnedArray<IntPtr>();
IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());

var res = THSTensor_cat(tensorsRef, parray.Array.Length, dim);
if (res == IntPtr.Zero) CheckForErrors();
Expand All @@ -60,6 +92,24 @@ public static Tensor cat(IList<Tensor> tensors, long dim = 0)
/// <remarks> All tensors must either have the same shape (except in the concatenating dimension) or be empty.</remarks>
public static Tensor concat(IList<Tensor> tensors, long dim = 0) => torch.cat(tensors, dim);

// https://pytorch.org/docs/stable/generated/torch.concat
/// <summary>
/// Alias of torch.cat()
/// </summary>
/// <param name="tensors">Tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.</param>
/// <param name="dim">The dimension over which the tensors are concatenated</param>
/// <remarks> All tensors must either have the same shape (except in the concatenating dimension) or be empty.</remarks>
public static Tensor concat(Tensor[] tensors, long dim = 0)
{
    // Pure alias: delegate straight to the array overload of torch.cat().
    return torch.cat(tensors, dim);
}

// https://pytorch.org/docs/stable/generated/torch.concat
/// <summary>
/// Alias of torch.cat()
/// </summary>
/// <param name="tensors">Tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.</param>
/// <param name="dim">The dimension over which the tensors are concatenated</param>
/// <remarks> All tensors must either have the same shape (except in the concatenating dimension) or be empty.</remarks>
public static Tensor concat(ReadOnlySpan<Tensor> tensors, long dim = 0)
{
    // Pure alias: delegate straight to the span overload of torch.cat().
    return torch.cat(tensors, dim);
}

// https://pytorch.org/docs/stable/generated/torch.conj
/// <summary>
/// Returns a view of input with a flipped conjugate bit. If input has a non-complex dtype, this function just returns input.
Expand Down Expand Up @@ -103,7 +153,7 @@ public static Tensor[] dsplit(Tensor input, (long, long, long, long) indices_or_
/// <returns></returns>
/// <remarks>This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d().</remarks>
public static Tensor dstack(params Tensor[] tensors)
=> dstack((IEnumerable<Tensor>)tensors);
=> dstack(tensors.ToHandleArray());

// https://pytorch.org/docs/stable/generated/torch.dstack
/// <summary>
Expand All @@ -113,31 +163,39 @@ public static Tensor dstack(params Tensor[] tensors)
/// <returns></returns>
/// <remarks>This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d().</remarks>
public static Tensor dstack(IList<Tensor> tensors)
{
using (var parray = new PinnedArray<IntPtr>()) {
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
=> dstack(tensors.ToHandleArray());

var res = THSTensor_dstack(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}
}
// https://pytorch.org/docs/stable/generated/torch.dstack
/// <summary>
/// Stack tensors in sequence depthwise (along third axis).
/// </summary>
/// <param name="tensors">The tensors to stack.</param>
/// <returns>The stacked tensor.</returns>
/// <remarks>This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d().</remarks>
public static Tensor dstack(ReadOnlySpan<Tensor> tensors)
{
    // Collect the native tensor handles, then run the shared native-handle implementation.
    return dstack(tensors.ToHandleArray());
}

// https://pytorch.org/docs/stable/generated/torch.dstack
/// <summary>
/// Stack tensors in sequence depthwise (along third axis).
/// </summary>
/// <param name="tensors">The tensors to stack.</param>
/// <returns>The stacked tensor.</returns>
/// <remarks>This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d().</remarks>
public static Tensor dstack(IEnumerable<Tensor> tensors)
{
    // Materialize the native handles once, then delegate to the common implementation.
    return dstack(tensors.ToHandleArray());
}

static Tensor dstack(IntPtr[] tensors)
{
using var parray = new PinnedArray<IntPtr>();
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
var res = THSTensor_dstack(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { CheckForErrors(); }
return new Tensor(res);
using (var parray = new PinnedArray<IntPtr>()) {
IntPtr tensorsRef = parray.CreateArray(tensors);

var res = THSTensor_dstack(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}
}

// https://pytorch.org/docs/stable/generated/torch.gather
/// <summary>
/// Gathers values along an axis specified by dim.
Expand Down Expand Up @@ -192,14 +250,7 @@ public static Tensor[] hsplit(Tensor input, (long, long, long, long) indices_or_
/// <param name="tensors"></param>
/// <returns></returns>
public static Tensor hstack(IList<Tensor> tensors)
{
using var parray = new PinnedArray<IntPtr>();
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());

var res = THSTensor_hstack(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}
=> hstack(tensors.ToHandleArray());

// https://pytorch.org/docs/stable/generated/torch.hstack
/// <summary>
Expand All @@ -208,9 +259,7 @@ public static Tensor hstack(IList<Tensor> tensors)
/// <param name="tensors"></param>
/// <returns></returns>
public static Tensor hstack(params Tensor[] tensors)
{
return hstack((IEnumerable<Tensor>)tensors);
}
=> hstack(tensors.ToHandleArray());

// https://pytorch.org/docs/stable/generated/torch.hstack
/// <summary>
/// Stack tensors in sequence horizontally (column wise).
/// </summary>
/// <param name="tensors">The tensors to stack.</param>
/// <returns>The stacked tensor.</returns>
public static Tensor hstack(IEnumerable<Tensor> tensors)
{
    // Materialize the native handles once, then delegate to the common implementation.
    return hstack(tensors.ToHandleArray());
}

// https://pytorch.org/docs/stable/generated/torch.hstack
/// <summary>
/// Stack tensors in sequence horizontally (column wise).
/// </summary>
/// <param name="tensors">The tensors to stack.</param>
/// <returns>The stacked tensor.</returns>
public static Tensor hstack(ReadOnlySpan<Tensor> tensors)
{
    // Collect the native tensor handles, then run the shared native-handle implementation.
    return hstack(tensors.ToHandleArray());
}

static Tensor hstack(IntPtr[] tensors)
{
using var parray = new PinnedArray<IntPtr>();
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
IntPtr tensorsRef = parray.CreateArray(tensors);

var res = THSTensor_hstack(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { CheckForErrors(); }
Expand Down Expand Up @@ -474,7 +535,7 @@ public static Tensor[] split(Tensor tensor, long[] split_size_or_sections, long
public static Tensor stack(IEnumerable<Tensor> tensors, long dim = 0)
{
using var parray = new PinnedArray<IntPtr>();
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray());

var res = THSTensor_stack(tensorsRef, parray.Array.Length, dim);
if (res == IntPtr.Zero) { CheckForErrors(); }
Expand Down Expand Up @@ -560,9 +621,30 @@ public static Tensor[] vsplit(Tensor input, long[] indices_or_sections)
/// <param name="tensors"></param>
/// <returns></returns>
public static Tensor vstack(IList<Tensor> tensors)
{
    // Convert to native handles and delegate to the shared implementation.
    return vstack(tensors.ToHandleArray());
}

// https://pytorch.org/docs/stable/generated/torch.vstack
/// <summary>
/// Stack tensors in sequence vertically (row wise).
/// </summary>
/// <param name="tensors">The tensors to stack.</param>
/// <returns>The stacked tensor.</returns>
public static Tensor vstack(Tensor[] tensors)
{
    // Materialize the native handles once, then delegate to the common implementation.
    return vstack(tensors.ToHandleArray());
}

// https://pytorch.org/docs/stable/generated/torch.vstack
/// <summary>
/// Stack tensors in sequence vertically (row wise).
/// </summary>
/// <param name="tensors">The tensors to stack.</param>
/// <returns>The stacked tensor.</returns>
public static Tensor vstack(ReadOnlySpan<Tensor> tensors)
{
    // Collect the native tensor handles, then run the shared native-handle implementation.
    return vstack(tensors.ToHandleArray());
}

static Tensor vstack(IntPtr[] tensors)
{
using var parray = new PinnedArray<IntPtr>();
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
IntPtr tensorsRef = parray.CreateArray(tensors);

var res = THSTensor_vstack(tensorsRef, parray.Array.Length);
if (res == IntPtr.Zero) { CheckForErrors(); }
Expand Down
Loading