Skip to content

Commit c01832a

Browse files
committed
<feat>(Logics):add high performance logical AND function with axis and keepdim support also test Upgrade to .NET 8.
1 parent 004ce2b commit c01832a

File tree

2 files changed

+203
-20
lines changed

2 files changed

+203
-20
lines changed

src/NumSharp.Core/Logic/np.all.cs

Lines changed: 121 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,23 @@ public static bool all(NDArray a)
2626
#else
2727

2828
#region Compute
29-
switch (a.typecode)
30-
{
31-
case NPTypeCode.Boolean: return _all_linear<bool>(a.MakeGeneric<bool>());
32-
case NPTypeCode.Byte: return _all_linear<byte>(a.MakeGeneric<byte>());
33-
case NPTypeCode.Int16: return _all_linear<short>(a.MakeGeneric<short>());
34-
case NPTypeCode.UInt16: return _all_linear<ushort>(a.MakeGeneric<ushort>());
35-
case NPTypeCode.Int32: return _all_linear<int>(a.MakeGeneric<int>());
36-
case NPTypeCode.UInt32: return _all_linear<uint>(a.MakeGeneric<uint>());
37-
case NPTypeCode.Int64: return _all_linear<long>(a.MakeGeneric<long>());
38-
case NPTypeCode.UInt64: return _all_linear<ulong>(a.MakeGeneric<ulong>());
39-
case NPTypeCode.Char: return _all_linear<char>(a.MakeGeneric<char>());
40-
case NPTypeCode.Double: return _all_linear<double>(a.MakeGeneric<double>());
41-
case NPTypeCode.Single: return _all_linear<float>(a.MakeGeneric<float>());
42-
case NPTypeCode.Decimal: return _all_linear<decimal>(a.MakeGeneric<decimal>());
43-
default:
44-
throw new NotSupportedException();
45-
}
29+
switch (a.typecode)
30+
{
31+
case NPTypeCode.Boolean: return _all_linear<bool>(a.MakeGeneric<bool>());
32+
case NPTypeCode.Byte: return _all_linear<byte>(a.MakeGeneric<byte>());
33+
case NPTypeCode.Int16: return _all_linear<short>(a.MakeGeneric<short>());
34+
case NPTypeCode.UInt16: return _all_linear<ushort>(a.MakeGeneric<ushort>());
35+
case NPTypeCode.Int32: return _all_linear<int>(a.MakeGeneric<int>());
36+
case NPTypeCode.UInt32: return _all_linear<uint>(a.MakeGeneric<uint>());
37+
case NPTypeCode.Int64: return _all_linear<long>(a.MakeGeneric<long>());
38+
case NPTypeCode.UInt64: return _all_linear<ulong>(a.MakeGeneric<ulong>());
39+
case NPTypeCode.Char: return _all_linear<char>(a.MakeGeneric<char>());
40+
case NPTypeCode.Double: return _all_linear<double>(a.MakeGeneric<double>());
41+
case NPTypeCode.Single: return _all_linear<float>(a.MakeGeneric<float>());
42+
case NPTypeCode.Decimal: return _all_linear<decimal>(a.MakeGeneric<decimal>());
43+
default:
44+
throw new NotSupportedException();
45+
}
4646
#endregion
4747
#endif
4848
}
@@ -51,12 +51,113 @@ public static bool all(NDArray a)
5151
/// Test whether all array elements along a given axis evaluate to True.
5252
/// </summary>
5353
/// <param name="a">Input array or object that can be converted to an array.</param>
54-
/// <param name="axis">Axis or axes along which a logical OR reduction is performed. The default (axis = None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.</param>
54+
/// <param name="axis">Axis or axes along which a logical AND reduction is performed. The default (axis = None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.</param>
5555
/// <returns>A new boolean or ndarray is returned unless out is specified, in which case a reference to out is returned.</returns>
5656
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.all.html</remarks>
57-
public static NDArray<bool> all(NDArray nd, int axis)
57+
public static NDArray<bool> all(NDArray nd, int axis, bool keepdims = false)
5858
{
59-
throw new NotImplementedException(); //TODO
59+
if (axis < 0)
60+
axis = nd.ndim + axis;
61+
if (axis < 0 || axis >= nd.ndim)
62+
{
63+
throw new ArgumentOutOfRangeException(nameof(axis));
64+
}
65+
if (nd.ndim == 0)
66+
{
67+
throw new ArgumentException("Can't operate with zero array");
68+
}
69+
if (nd == null)
70+
{
71+
throw new ArgumentException("Can't operate with null array");
72+
}
73+
74+
int[] inputShape = nd.shape;
75+
int[] outputShape = new int[keepdims ? inputShape.Length : inputShape.Length - 1];
76+
int outputIndex = 0;
77+
for (int i = 0; i < inputShape.Length; i++)
78+
{
79+
if (i != axis)
80+
{
81+
outputShape[outputIndex++] = inputShape[i];
82+
}
83+
else if (keepdims)
84+
{
85+
outputShape[outputIndex++] = 1; // 保留轴,但长度为1
86+
}
87+
}
88+
89+
NDArray<bool> resultArray = (NDArray<bool>)zeros<bool>(outputShape);
90+
Span<bool> resultSpan = resultArray.GetData().AsSpan<bool>();
91+
92+
int axisSize = inputShape[axis];
93+
94+
// It help to build an index
95+
int preAxisStride = 1;
96+
for (int i = 0; i < axis; i++)
97+
{
98+
preAxisStride *= inputShape[i];
99+
}
100+
101+
int postAxisStride = 1;
102+
for (int i = axis + 1; i < inputShape.Length; i++)
103+
{
104+
postAxisStride *= inputShape[i];
105+
}
106+
107+
108+
// Operate different logic by TypeCode
109+
bool computationSuccess = false;
110+
switch (nd.typecode)
111+
{
112+
case NPTypeCode.Boolean: computationSuccess = ComputeAllPerAxis<bool>(nd.MakeGeneric<bool>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
113+
case NPTypeCode.Byte: computationSuccess = ComputeAllPerAxis<byte>(nd.MakeGeneric<byte>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
114+
case NPTypeCode.Int16: computationSuccess = ComputeAllPerAxis<short>(nd.MakeGeneric<short>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
115+
case NPTypeCode.UInt16: computationSuccess = ComputeAllPerAxis<ushort>(nd.MakeGeneric<ushort>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
116+
case NPTypeCode.Int32: computationSuccess = ComputeAllPerAxis<int>(nd.MakeGeneric<int>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
117+
case NPTypeCode.UInt32: computationSuccess = ComputeAllPerAxis<uint>(nd.MakeGeneric<uint>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
118+
case NPTypeCode.Int64: computationSuccess = ComputeAllPerAxis<long>(nd.MakeGeneric<long>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
119+
case NPTypeCode.UInt64: computationSuccess = ComputeAllPerAxis<ulong>(nd.MakeGeneric<ulong>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
120+
case NPTypeCode.Char: computationSuccess = ComputeAllPerAxis<char>(nd.MakeGeneric<char>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
121+
case NPTypeCode.Double: computationSuccess = ComputeAllPerAxis<double>(nd.MakeGeneric<double>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
122+
case NPTypeCode.Single: computationSuccess = ComputeAllPerAxis<float>(nd.MakeGeneric<float>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
123+
case NPTypeCode.Decimal: computationSuccess = ComputeAllPerAxis<decimal>(nd.MakeGeneric<decimal>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
124+
default:
125+
throw new NotSupportedException($"Type {nd.typecode} is not supported");
126+
}
127+
128+
if (!computationSuccess)
129+
{
130+
throw new InvalidOperationException("Failed to compute all() along the specified axis");
131+
}
132+
133+
return resultArray;
134+
}
135+
136+
private static bool ComputeAllPerAxis<T>(NDArray<T> nd, int axis, int preAxisStride, int postAxisStride, int axisSize, Span<bool> resultSpan) where T : unmanaged
137+
{
138+
Span<T> inputSpan = nd.GetData().AsSpan<T>();
139+
140+
141+
for (int o = 0; o < resultSpan.Length; o++)
142+
{
143+
int blockIndex = o / postAxisStride;
144+
int inBlockIndex = o % postAxisStride;
145+
int inputStartIndex = blockIndex * axisSize * postAxisStride + inBlockIndex;
146+
147+
bool currentResult = true;
148+
for (int a = 0; a < axisSize; a++)
149+
{
150+
int inputIndex = inputStartIndex + a * postAxisStride;
151+
if (inputSpan[inputIndex].Equals(default(T)))
152+
{
153+
currentResult = false;
154+
break;
155+
}
156+
}
157+
resultSpan[o] = currentResult;
158+
}
159+
160+
return true;
60161
}
61162

62163
private static bool _all_linear<T>(NDArray<T> nd) where T : unmanaged
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Microsoft.VisualStudio.TestTools.UnitTesting;
6+
using NumSharp;
7+
8+
namespace NumSharp.UnitTest.Logic
9+
{
10+
[TestClass]
11+
public class np_all_axis_Test
12+
{
13+
[TestMethod]
14+
public void np_all_axis_2D()
15+
{
16+
// Test array: [[true, false, true], [true, true, true]]
17+
var arr = np.array(new bool[,] { { true, false, true }, { true, true, true } });
18+
19+
// Test axis=0 (along columns): should be [true, false, true] (all in each column)
20+
var result_axis0 = np.all(arr, axis: 0);
21+
var expected_axis0 = np.array(new bool[] { true, false, true });
22+
Assert.IsTrue(np.array_equal(result_axis0, expected_axis0));
23+
24+
// Test axis=1 (along rows): should be [false, true] (all in each row)
25+
var result_axis1 = np.all(arr, axis: 1);
26+
var expected_axis1 = np.array(new bool[] { false, true });
27+
Assert.IsTrue(np.array_equal(result_axis1, expected_axis1));
28+
}
29+
30+
[TestMethod]
31+
public void np_all_axis_3D()
32+
{
33+
// Create a 3D array for testing
34+
var arr = np.ones(new int[] { 2, 3, 4 }); // All ones (truthy)
35+
arr[0, 1, 2] = 0; // Add one falsy value
36+
37+
// Test different axes
38+
var result_axis0 = np.all(arr, axis: 0); // Shape should be (3, 4)
39+
Assert.AreEqual(2, result_axis0.ndim);
40+
Assert.AreEqual(3, result_axis0.shape[0]);
41+
Assert.AreEqual(4, result_axis0.shape[1]);
42+
43+
var result_axis1 = np.all(arr, axis: 1); // Shape should be (2, 4)
44+
Assert.AreEqual(2, result_axis1.ndim);
45+
Assert.AreEqual(2, result_axis1.shape[0]);
46+
Assert.AreEqual(4, result_axis1.shape[1]);
47+
48+
var result_axis2 = np.all(arr, axis: 2); // Shape should be (2, 3)
49+
Assert.AreEqual(2, result_axis2.ndim);
50+
Assert.AreEqual(2, result_axis2.shape[0]);
51+
Assert.AreEqual(3, result_axis2.shape[1]);
52+
}
53+
54+
[TestMethod]
55+
public void np_all_keepdims()
56+
{
57+
var arr = np.array(new bool[,] { { true, false, true }, { true, true, true } });
58+
59+
// Test with keepdims=true
60+
var result_keepdims = np.all(arr, axis: 0, keepdims: true);
61+
Assert.AreEqual(2, result_keepdims.ndim); // Should maintain original number of dimensions
62+
Assert.AreEqual(1, result_keepdims.shape[0]); // The reduced axis becomes size 1
63+
Assert.AreEqual(3, result_keepdims.shape[1]); // Other dimensions remain
64+
65+
var result_keepdims1 = np.all(arr, axis: 1, keepdims: true);
66+
Assert.AreEqual(2, result_keepdims1.ndim); // Should maintain original number of dimensions
67+
Assert.AreEqual(2, result_keepdims1.shape[0]); // Other dimensions remain
68+
Assert.AreEqual(1, result_keepdims1.shape[1]); // The reduced axis becomes size 1
69+
}
70+
71+
[TestMethod]
72+
public void np_all_different_types()
73+
{
74+
// Test with integer array
75+
var int_arr = np.array(new int[,] { { 1, 2, 3 }, { 4, 0, 6 } }); // Contains a zero (falsy value)
76+
var int_result = np.all(int_arr, axis: 1);
77+
// First row: all non-zero -> true, Second row: contains zero -> false
78+
Assert.AreEqual(true, int_result[0]);
79+
Assert.AreEqual(false, int_result[1]);
80+
}
81+
}
82+
}

0 commit comments

Comments
 (0)