Skip to content

Commit 4edf91d

Browse files
zhuoshui-AIOceania2018
authored andcommitted
update np.any function and add py.complex function
1 parent c01832a commit 4edf91d

File tree

5 files changed

+177
-6
lines changed

5 files changed

+177
-6
lines changed

src/NumSharp.Core/Logic/np.all.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public static NDArray<bool> all(NDArray nd, int axis, bool keepdims = false)
8282
}
8383
else if (keepdims)
8484
{
85-
outputShape[outputIndex++] = 1; // 保留轴,但长度为1
85+
outputShape[outputIndex++] = 1; // keep axis but length is one.
8686
}
8787
}
8888

src/NumSharp.Core/Logic/np.any.cs

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,111 @@ public static bool any(NDArray a)
5555
/// <param name="axis">Axis or axes along which a logical OR reduction is performed. The default (axis = None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.</param>
5656
/// <returns>A new boolean or ndarray is returned unless out is specified, in which case a reference to out is returned.</returns>
5757
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.any.html</remarks>
58-
public static NDArray<bool> any(NDArray nd, int axis)
58+
public static NDArray<bool> any(NDArray nd, int axis, bool keepdims)
5959
{
60-
throw new NotImplementedException(); //TODO
60+
if (axis < 0)
61+
axis = nd.ndim + axis;
62+
if (axis < 0 || axis >= nd.ndim)
63+
{
64+
throw new ArgumentOutOfRangeException(nameof(axis));
65+
}
66+
if (nd.ndim == 0)
67+
{
68+
throw new ArgumentException("Can't operate with zero array");
69+
}
70+
if (nd == null)
71+
{
72+
throw new ArgumentException("Can't operate with null array");
73+
}
74+
75+
int[] inputShape = nd.shape;
76+
int[] outputShape = new int[keepdims ? inputShape.Length : inputShape.Length - 1];
77+
int outputIndex = 0;
78+
for (int i = 0; i < inputShape.Length; i++)
79+
{
80+
if (i != axis)
81+
{
82+
outputShape[outputIndex++] = inputShape[i];
83+
}
84+
else if (keepdims)
85+
{
86+
outputShape[outputIndex++] = 1; // keep axis but length is one.
87+
}
88+
}
89+
90+
NDArray<bool> resultArray = (NDArray<bool>)zeros<bool>(outputShape);
91+
Span<bool> resultSpan = resultArray.GetData().AsSpan<bool>();
92+
93+
int axisSize = inputShape[axis];
94+
95+
// It help to build an index
96+
int preAxisStride = 1;
97+
for (int i = 0; i < axis; i++)
98+
{
99+
preAxisStride *= inputShape[i];
100+
}
101+
102+
int postAxisStride = 1;
103+
for (int i = axis + 1; i < inputShape.Length; i++)
104+
{
105+
postAxisStride *= inputShape[i];
106+
}
107+
108+
109+
// Operate different logic by TypeCode
110+
bool computationSuccess = false;
111+
switch (nd.typecode)
112+
{
113+
case NPTypeCode.Boolean: computationSuccess = ComputeAnyPerAxis<bool>(nd.MakeGeneric<bool>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
114+
case NPTypeCode.Byte: computationSuccess = ComputeAnyPerAxis<byte>(nd.MakeGeneric<byte>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
115+
case NPTypeCode.Int16: computationSuccess = ComputeAnyPerAxis<short>(nd.MakeGeneric<short>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
116+
case NPTypeCode.UInt16: computationSuccess = ComputeAnyPerAxis<ushort>(nd.MakeGeneric<ushort>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
117+
case NPTypeCode.Int32: computationSuccess = ComputeAnyPerAxis<int>(nd.MakeGeneric<int>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
118+
case NPTypeCode.UInt32: computationSuccess = ComputeAnyPerAxis<uint>(nd.MakeGeneric<uint>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
119+
case NPTypeCode.Int64: computationSuccess = ComputeAnyPerAxis<long>(nd.MakeGeneric<long>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
120+
case NPTypeCode.UInt64: computationSuccess = ComputeAnyPerAxis<ulong>(nd.MakeGeneric<ulong>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
121+
case NPTypeCode.Char: computationSuccess = ComputeAnyPerAxis<char>(nd.MakeGeneric<char>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
122+
case NPTypeCode.Double: computationSuccess = ComputeAnyPerAxis<double>(nd.MakeGeneric<double>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
123+
case NPTypeCode.Single: computationSuccess = ComputeAnyPerAxis<float>(nd.MakeGeneric<float>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
124+
case NPTypeCode.Decimal: computationSuccess = ComputeAnyPerAxis<decimal>(nd.MakeGeneric<decimal>(), axis, preAxisStride, postAxisStride, axisSize, resultSpan); break;
125+
default:
126+
throw new NotSupportedException($"Type {nd.typecode} is not supported");
127+
}
128+
129+
130+
if (!computationSuccess)
131+
{
132+
throw new InvalidOperationException("Failed to compute all() along the specified axis");
133+
}
134+
135+
return resultArray;
136+
}
137+
138+
private static bool ComputeAnyPerAxis<T>(NDArray<T> nd, int axis, int preAxisStride, int postAxisStride, int axisSize, Span<bool> resultSpan) where T : unmanaged
139+
{
140+
Span<T> inputSpan = nd.GetData().AsSpan<T>();
141+
142+
143+
for (int o = 0; o < resultSpan.Length; o++)
144+
{
145+
int blockIndex = o / postAxisStride;
146+
int inBlockIndex = o % postAxisStride;
147+
int inputStartIndex = blockIndex * axisSize * postAxisStride + inBlockIndex;
148+
149+
bool currentResult = true;
150+
for (int a = 0; a < axisSize; a++)
151+
{
152+
int inputIndex = inputStartIndex + a * postAxisStride;
153+
if (inputSpan[inputIndex].Equals(default(T)))
154+
{
155+
currentResult = true;
156+
break;
157+
}
158+
}
159+
resultSpan[o] = currentResult;
160+
}
161+
162+
return false;
61163
}
62164

63165
private static bool _any_linear<T>(NDArray<T> nd) where T : unmanaged

src/NumSharp.Core/Utilities/ArrayConvert.cs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Numerics;
33
using System.Runtime.CompilerServices;
4+
using System.Text.RegularExpressions;
45
using System.Threading.Tasks;
56
using NumSharp.Backends;
67

@@ -4072,12 +4073,13 @@ public static Complex[] ToComplex(Decimal[] sourceArray)
40724073
Parallel.For(0, length, i => new Complex(Converts.ToDouble(sourceArray[i]), 0d));
40734074
return output;
40744075
}
4075-
4076+
40764077
/// <summary>
40774078
/// Converts <see cref="String"/> array to a <see cref="Complex"/> array.
40784079
/// </summary>
40794080
/// <param name="sourceArray">The array to convert</param>
40804081
/// <returns>Converted array of type Complex</returns>
4082+
/// <exception cref="FormatException">A string in sourceArray has an invalid complex format</exception>
40814083
[MethodImpl(MethodImplOptions.AggressiveInlining)]
40824084
public static Complex[] ToComplex(String[] sourceArray)
40834085
{
@@ -4087,7 +4089,16 @@ public static Complex[] ToComplex(String[] sourceArray)
40874089
var length = sourceArray.Length;
40884090
var output = new Complex[length];
40894091

4090-
Parallel.For(0, length, i => new Complex(Converts.ToDouble(sourceArray[i]), 0d));
4092+
Parallel.For(0, length, i =>
4093+
{
4094+
string input = sourceArray[i]?.Trim() ?? string.Empty;
4095+
if (string.IsNullOrEmpty(input))
4096+
{
4097+
output[i] = Complex.Zero; // NullString save as zero.
4098+
return;
4099+
}
4100+
var match = py.Complex(sourceArray[i]);
4101+
});
40914102
return output;
40924103
}
40934104
#endif

src/NumSharp.Core/Utilities/py.cs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
namespace NumSharp.Utilities
1+
using System;
2+
using System.Numerics;
3+
using System.Text.RegularExpressions;
4+
5+
namespace NumSharp.Utilities
26
{
37
/// <summary>
48
/// Implements Python utility functions that are often used in connection with numpy
@@ -12,5 +16,47 @@ public static int[] range(int n)
1216
a[i] = i;
1317
return a;
1418
}
19+
/// <summary>
20+
/// 解析单个Python风格的复数字符串为Complex对象
21+
/// </summary>
22+
private static readonly Regex _pythonComplexRegex = new Regex(
23+
@"^(?<real>-?\d+(\.\d+)?)?((?<imagSign>\+|-)?(?<imag>\d+(\.\d+)?)?)?j$|^(?<onlyReal>-?\d+(\.\d+)?)$",
24+
RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.ExplicitCapture);
25+
public static Complex Complex(string input)
26+
{
27+
var match = _pythonComplexRegex.Match(input);
28+
if (!match.Success)
29+
throw new FormatException($"Invalid Python complex format: '{input}'. Expected format like '10+5j', '3-2j', '4j' or '5'.");
30+
31+
// 解析仅实部的场景
32+
if (match.Groups["onlyReal"].Success)
33+
{
34+
double real = double.Parse(match.Groups["onlyReal"].Value);
35+
return new Complex(real, 0);
36+
}
37+
38+
// 解析实部(默认0)
39+
double realPart = 0;
40+
if (double.TryParse(match.Groups["real"].Value, out double r))
41+
realPart = r;
42+
43+
// 解析虚部(处理特殊情况:j / -j / +j)
44+
double imagPart = 0;
45+
string imagStr = match.Groups["imag"].Value;
46+
string imagSign = match.Groups["imagSign"].Value;
47+
48+
if (string.IsNullOrEmpty(imagStr) && !string.IsNullOrEmpty(input.TrimEnd('j', 'J')))
49+
{
50+
// 处理仅虚部的情况:j → 1j, -j → -1j, +j → 1j
51+
imagStr = "1";
52+
}
53+
54+
if (double.TryParse(imagStr, out double im))
55+
{
56+
imagPart = im * (imagSign == "-" ? -1 : 1);
57+
}
58+
59+
return new Complex(realPart, imagPart);
60+
}
1561
}
1662
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
7+
namespace NumSharp.UnitTest.Utilities
8+
{
9+
internal class pyTest
10+
{
11+
}
12+
}

0 commit comments

Comments
 (0)