3
3
4
4
using System . Collections . Generic ;
5
5
using System . Linq ;
6
+ using System . Runtime . InteropServices ;
6
7
using Microsoft . VisualBasic ;
7
8
8
9
#pragma warning disable CS8601 // Possible null reference assignment.
@@ -19,7 +20,6 @@ public static partial class Tensor
19
20
/// <param name="lengths">A <see cref="ReadOnlySpan{T}"/> indicating the lengths of each dimension.</param>
20
21
/// <param name="pinned">A <see cref="bool"/> whether the underlying data should be pinned or not.</param>
21
22
public static Tensor < T > Create < T > ( scoped ReadOnlySpan < nint > lengths , bool pinned = false )
22
- where T : IEquatable < T >
23
23
{
24
24
nint linearLength = TensorSpanHelpers . CalculateTotalLength ( lengths ) ;
25
25
T [ ] values = pinned ? GC . AllocateArray < T > ( ( int ) linearLength , pinned ) : ( new T [ linearLength ] ) ;
@@ -33,7 +33,6 @@ public static Tensor<T> Create<T>(scoped ReadOnlySpan<nint> lengths, bool pinned
33
33
/// <param name="strides">A <see cref="ReadOnlySpan{T}"/> indicating the strides of each dimension.</param>
34
34
/// <param name="pinned">A <see cref="bool"/> whether the underlying data should be pinned or not.</param>
35
35
public static Tensor < T > Create < T > ( scoped ReadOnlySpan < nint > lengths , scoped ReadOnlySpan < nint > strides , bool pinned = false )
36
- where T : IEquatable < T >
37
36
{
38
37
nint linearLength = TensorSpanHelpers . CalculateTotalLength ( lengths ) ;
39
38
T [ ] values = pinned ? GC . AllocateArray < T > ( ( int ) linearLength , pinned ) : ( new T [ linearLength ] ) ;
@@ -48,7 +47,7 @@ public static Tensor<T> Create<T>(scoped ReadOnlySpan<nint> lengths, scoped Read
48
47
/// <param name="lengths">A <see cref="ReadOnlySpan{T}"/> indicating the lengths of each dimension.</param>
49
48
/// <exception cref="ArgumentOutOfRangeException"></exception>
50
49
public static Tensor < T > Create < T > ( T [ ] values , scoped ReadOnlySpan < nint > lengths )
51
- where T : IEquatable < T > => Create ( values , lengths , [ ] ) ;
50
+ => Create ( values , lengths , [ ] ) ;
52
51
53
52
/// <summary>
54
53
/// Creates a <see cref="Tensor{T}"/> from the provided <paramref name="values"/>. If the product of the
@@ -60,51 +59,40 @@ public static Tensor<T> Create<T>(T[] values, scoped ReadOnlySpan<nint> lengths)
60
59
/// <param name="isPinned">A <see cref="bool"/> indicating whether the <paramref name="values"/> were pinned or not.</param>
61
60
/// <exception cref="ArgumentOutOfRangeException"></exception>
62
61
public static Tensor < T > Create < T > ( T [ ] values , scoped ReadOnlySpan < nint > lengths , scoped ReadOnlySpan < nint > strides , bool isPinned = false )
63
- where T : IEquatable < T >
64
62
{
65
63
return new Tensor < T > ( values , lengths , strides , isPinned ) ;
66
64
}
67
65
68
66
/// <summary>
69
- /// Creates a <see cref="Tensor{T}"/> and does not initialize it. If <paramref name="pinned"/> is true, the memory will be pinned.
70
- /// </summary>
71
- /// <param name="lengths">A <see cref="ReadOnlySpan{T}"/> indicating the lengths of each dimension.</param>
72
- /// <param name="pinned">A <see cref="bool"/> whether the underlying data should be pinned or not.</param>
73
- public static Tensor < T > CreateUninitialized < T > ( scoped ReadOnlySpan < nint > lengths , bool pinned = false )
74
- where T : IEquatable < T > => CreateUninitialized < T > ( lengths , [ ] , pinned ) ;
75
-
76
-
77
- /// <summary>
78
- /// Creates a <see cref="Tensor{T}"/> and does not initialize it. If <paramref name="pinned"/> is true, the memory will be pinned.
67
+ /// Creates a <see cref="Tensor{T}"/> and initializes it with the data from <paramref name="data"/>.
79
68
/// </summary>
80
- /// <param name="lengths">A <see cref="ReadOnlySpan{T}"/> indicating the lengths of each dimension.</param>
81
- /// <param name="strides">A <see cref="ReadOnlySpan{T}"/> indicating the strides of each dimension.</param>
82
- /// <param name="pinned">A <see cref="bool"/> whether the underlying data should be pinned or not.</param>
83
- public static Tensor < T > CreateUninitialized < T > ( scoped ReadOnlySpan < nint > lengths , scoped ReadOnlySpan < nint > strides , bool pinned = false )
84
- where T : IEquatable < T >
69
+ /// <param name="data">A <see cref="IEnumerable{T}"/> with the data to use for the initialization.</param>
70
+ /// <param name="lengths"></param>
71
+ public static Tensor < T > Create < T > ( IEnumerable < T > data , scoped ReadOnlySpan < nint > lengths )
85
72
{
86
- nint linearLength = TensorSpanHelpers . CalculateTotalLength ( lengths ) ;
87
- T [ ] values = GC . AllocateUninitializedArray < T > ( ( int ) linearLength , pinned ) ;
88
- return new Tensor < T > ( values , lengths , strides , pinned ) ;
73
+ T [ ] values = data . ToArray ( ) ;
74
+ return new Tensor < T > ( values , lengths . IsEmpty ? [ values . Length ] : lengths , false ) ;
89
75
}
90
76
91
77
/// <summary>
92
78
/// Creates a <see cref="Tensor{T}"/> and initializes it with the data from <paramref name="data"/>.
93
79
/// </summary>
94
80
/// <param name="data">A <see cref="IEnumerable{T}"/> with the data to use for the initialization.</param>
95
- public static Tensor < T > CreateFromEnumerable < T > ( IEnumerable < T > data )
96
- where T : IEquatable < T > , IEqualityOperators < T , T , bool >
81
+ /// <param name="lengths"></param>
82
+ /// <param name="strides"></param>
83
+ /// <param name="isPinned"></param>
84
+ public static Tensor < T > Create < T > ( IEnumerable < T > data , scoped ReadOnlySpan < nint > lengths , scoped ReadOnlySpan < nint > strides , bool isPinned = false )
97
85
{
98
86
T [ ] values = data . ToArray ( ) ;
99
- return new Tensor < T > ( values , [ values . Length ] , false ) ;
87
+ return new Tensor < T > ( values , lengths . IsEmpty ? [ values . Length ] : lengths , strides , isPinned ) ;
100
88
}
101
89
102
90
/// <summary>
103
91
/// Creates a <see cref="Tensor{T}"/> and initializes it with random data uniformly distributed.
104
92
/// </summary>
105
93
/// <param name="lengths">A <see cref="ReadOnlySpan{T}"/> indicating the lengths of each dimension.</param>
106
94
public static Tensor < T > CreateAndFillUniformDistribution < T > ( params scoped ReadOnlySpan < nint > lengths )
107
- where T : IEquatable < T > , IEqualityOperators < T , T , bool > , IFloatingPoint < T >
95
+ where T : IFloatingPoint < T >
108
96
{
109
97
nint linearLength = TensorSpanHelpers . CalculateTotalLength ( lengths ) ;
110
98
T [ ] values = new T [ linearLength ] ;
@@ -121,16 +109,16 @@ public static Tensor<T> CreateAndFillUniformDistribution<T>(params scoped ReadOn
121
109
/// </summary>
122
110
/// <param name="lengths">A <see cref="ReadOnlySpan{T}"/> indicating the lengths of each dimension.</param>
123
111
public static Tensor < T > CreateAndFillGaussianNormalDistribution < T > ( params scoped ReadOnlySpan < nint > lengths )
124
- where T : IEquatable < T > , IEqualityOperators < T , T , bool > , IFloatingPoint < T >
112
+ where T : IFloatingPoint < T >
125
113
{
126
114
nint linearLength = TensorSpanHelpers . CalculateTotalLength ( lengths ) ;
127
115
T [ ] values = new T [ linearLength ] ;
128
- GaussianDistribution ( ref values , linearLength ) ;
116
+ GaussianDistribution < T > ( values , linearLength ) ;
129
117
return new Tensor < T > ( values , lengths , false ) ;
130
118
}
131
119
132
- private static void GaussianDistribution < T > ( ref T [ ] values , nint linearLength )
133
- where T : IEquatable < T > , IEqualityOperators < T , T , bool > , IFloatingPoint < T >
120
+ private static void GaussianDistribution < T > ( in Span < T > values , nint linearLength )
121
+ where T : IFloatingPoint < T >
134
122
{
135
123
Random rand = Random . Shared ;
136
124
for ( int i = 0 ; i < linearLength ; i ++ )
@@ -141,5 +129,45 @@ private static void GaussianDistribution<T>(ref T[] values, nint linearLength)
141
129
}
142
130
}
143
131
#endregion
132
+
133
+ /// <summary>
134
+ /// Creates a <see cref="Tensor{T}"/> and does not initialize it. If <paramref name="pinned"/> is true, the memory will be pinned.
135
+ /// </summary>
136
+ /// <param name="lengths">A <see cref="ReadOnlySpan{T}"/> indicating the lengths of each dimension.</param>
137
+ /// <param name="pinned">A <see cref="bool"/> whether the underlying data should be pinned or not.</param>
138
+ public static Tensor < T > CreateUninitialized < T > ( scoped ReadOnlySpan < nint > lengths , bool pinned = false )
139
+ => CreateUninitialized < T > ( lengths , [ ] , pinned ) ;
140
+
141
+ /// <summary>
142
+ /// Creates a <see cref="Tensor{T}"/> and does not initialize it. If <paramref name="pinned"/> is true, the memory will be pinned.
143
+ /// </summary>
144
+ /// <param name="lengths">A <see cref="ReadOnlySpan{T}"/> indicating the lengths of each dimension.</param>
145
+ /// <param name="strides">A <see cref="ReadOnlySpan{T}"/> indicating the strides of each dimension.</param>
146
+ /// <param name="pinned">A <see cref="bool"/> whether the underlying data should be pinned or not.</param>
147
+ public static Tensor < T > CreateUninitialized < T > ( scoped ReadOnlySpan < nint > lengths , scoped ReadOnlySpan < nint > strides , bool pinned = false )
148
+ {
149
+ nint linearLength = TensorSpanHelpers . CalculateTotalLength ( lengths ) ;
150
+ T [ ] values = GC . AllocateUninitializedArray < T > ( ( int ) linearLength , pinned ) ;
151
+ return new Tensor < T > ( values , lengths , strides , pinned ) ;
152
+ }
153
+
154
+ public static ref readonly TensorSpan < T > FillGaussianNormalDistribution < T > ( in TensorSpan < T > destination ) where T : IFloatingPoint < T >
155
+ {
156
+ Span < T > span = MemoryMarshal . CreateSpan < T > ( ref destination . _reference , ( int ) destination . _flattenedLength ) ;
157
+
158
+ GaussianDistribution < T > ( span , destination . _flattenedLength ) ;
159
+
160
+ return ref destination ;
161
+ }
162
+
163
+ public static ref readonly TensorSpan < T > FillUniformDistribution < T > ( in TensorSpan < T > destination ) where T : IFloatingPoint < T >
164
+ {
165
+ Span < T > span = MemoryMarshal . CreateSpan < T > ( ref destination . _reference , ( int ) destination . _flattenedLength ) ;
166
+
167
+ for ( int i = 0 ; i < span . Length ; i ++ )
168
+ span [ i ] = T . CreateChecked ( Random . Shared . NextDouble ( ) ) ;
169
+
170
+ return ref destination ;
171
+ }
144
172
}
145
173
}
0 commit comments