@@ -26,23 +26,23 @@ public static bool all(NDArray a)
2626#else
2727
2828 #region Compute
29- switch ( a . typecode )
30- {
31- case NPTypeCode . Boolean : return _all_linear < bool > ( a . MakeGeneric < bool > ( ) ) ;
32- case NPTypeCode . Byte : return _all_linear < byte > ( a . MakeGeneric < byte > ( ) ) ;
33- case NPTypeCode . Int16 : return _all_linear < short > ( a . MakeGeneric < short > ( ) ) ;
34- case NPTypeCode . UInt16 : return _all_linear < ushort > ( a . MakeGeneric < ushort > ( ) ) ;
35- case NPTypeCode . Int32 : return _all_linear < int > ( a . MakeGeneric < int > ( ) ) ;
36- case NPTypeCode . UInt32 : return _all_linear < uint > ( a . MakeGeneric < uint > ( ) ) ;
37- case NPTypeCode . Int64 : return _all_linear < long > ( a . MakeGeneric < long > ( ) ) ;
38- case NPTypeCode . UInt64 : return _all_linear < ulong > ( a . MakeGeneric < ulong > ( ) ) ;
39- case NPTypeCode . Char : return _all_linear < char > ( a . MakeGeneric < char > ( ) ) ;
40- case NPTypeCode . Double : return _all_linear < double > ( a . MakeGeneric < double > ( ) ) ;
41- case NPTypeCode . Single : return _all_linear < float > ( a . MakeGeneric < float > ( ) ) ;
42- case NPTypeCode . Decimal : return _all_linear < decimal > ( a . MakeGeneric < decimal > ( ) ) ;
43- default :
44- throw new NotSupportedException ( ) ;
45- }
29+ switch ( a . typecode )
30+ {
31+ case NPTypeCode . Boolean : return _all_linear < bool > ( a . MakeGeneric < bool > ( ) ) ;
32+ case NPTypeCode . Byte : return _all_linear < byte > ( a . MakeGeneric < byte > ( ) ) ;
33+ case NPTypeCode . Int16 : return _all_linear < short > ( a . MakeGeneric < short > ( ) ) ;
34+ case NPTypeCode . UInt16 : return _all_linear < ushort > ( a . MakeGeneric < ushort > ( ) ) ;
35+ case NPTypeCode . Int32 : return _all_linear < int > ( a . MakeGeneric < int > ( ) ) ;
36+ case NPTypeCode . UInt32 : return _all_linear < uint > ( a . MakeGeneric < uint > ( ) ) ;
37+ case NPTypeCode . Int64 : return _all_linear < long > ( a . MakeGeneric < long > ( ) ) ;
38+ case NPTypeCode . UInt64 : return _all_linear < ulong > ( a . MakeGeneric < ulong > ( ) ) ;
39+ case NPTypeCode . Char : return _all_linear < char > ( a . MakeGeneric < char > ( ) ) ;
40+ case NPTypeCode . Double : return _all_linear < double > ( a . MakeGeneric < double > ( ) ) ;
41+ case NPTypeCode . Single : return _all_linear < float > ( a . MakeGeneric < float > ( ) ) ;
42+ case NPTypeCode . Decimal : return _all_linear < decimal > ( a . MakeGeneric < decimal > ( ) ) ;
43+ default :
44+ throw new NotSupportedException ( ) ;
45+ }
4646 #endregion
4747#endif
4848 }
@@ -51,12 +51,113 @@ public static bool all(NDArray a)
5151 /// Test whether all array elements along a given axis evaluate to True.
5252 /// </summary>
5353 /// <param name="a">Input array or object that can be converted to an array.</param>
54- /// <param name="axis">Axis or axes along which a logical OR reduction is performed. The default (axis = None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.</param>
54+ /// <param name="axis">Axis or axes along which a logical AND reduction is performed. The default (axis = None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.</param>
5555 /// <returns>A new boolean or ndarray is returned unless out is specified, in which case a reference to out is returned.</returns>
5656 /// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.all.html</remarks>
57- public static NDArray < bool > all ( NDArray nd , int axis )
57+ public static NDArray < bool > all ( NDArray nd , int axis , bool keepdims = false )
5858 {
59- throw new NotImplementedException ( ) ; //TODO
59+ if ( axis < 0 )
60+ axis = nd . ndim + axis ;
61+ if ( axis < 0 || axis >= nd . ndim )
62+ {
63+ throw new ArgumentOutOfRangeException ( nameof ( axis ) ) ;
64+ }
65+ if ( nd . ndim == 0 )
66+ {
67+ throw new ArgumentException ( "Can't operate with zero array" ) ;
68+ }
69+ if ( nd == null )
70+ {
71+ throw new ArgumentException ( "Can't operate with null array" ) ;
72+ }
73+
74+ int [ ] inputShape = nd . shape ;
75+ int [ ] outputShape = new int [ keepdims ? inputShape . Length : inputShape . Length - 1 ] ;
76+ int outputIndex = 0 ;
77+ for ( int i = 0 ; i < inputShape . Length ; i ++ )
78+ {
79+ if ( i != axis )
80+ {
81+ outputShape [ outputIndex ++ ] = inputShape [ i ] ;
82+ }
83+ else if ( keepdims )
84+ {
85+ outputShape [ outputIndex ++ ] = 1 ; // 保留轴,但长度为1
86+ }
87+ }
88+
89+ NDArray < bool > resultArray = ( NDArray < bool > ) zeros < bool > ( outputShape ) ;
90+ Span < bool > resultSpan = resultArray . GetData ( ) . AsSpan < bool > ( ) ;
91+
92+ int axisSize = inputShape [ axis ] ;
93+
94+ // It help to build an index
95+ int preAxisStride = 1 ;
96+ for ( int i = 0 ; i < axis ; i ++ )
97+ {
98+ preAxisStride *= inputShape [ i ] ;
99+ }
100+
101+ int postAxisStride = 1 ;
102+ for ( int i = axis + 1 ; i < inputShape . Length ; i ++ )
103+ {
104+ postAxisStride *= inputShape [ i ] ;
105+ }
106+
107+
108+ // Operate different logic by TypeCode
109+ bool computationSuccess = false ;
110+ switch ( nd . typecode )
111+ {
112+ case NPTypeCode . Boolean : computationSuccess = ComputeAllPerAxis < bool > ( nd . MakeGeneric < bool > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
113+ case NPTypeCode . Byte : computationSuccess = ComputeAllPerAxis < byte > ( nd . MakeGeneric < byte > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
114+ case NPTypeCode . Int16 : computationSuccess = ComputeAllPerAxis < short > ( nd . MakeGeneric < short > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
115+ case NPTypeCode . UInt16 : computationSuccess = ComputeAllPerAxis < ushort > ( nd . MakeGeneric < ushort > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
116+ case NPTypeCode . Int32 : computationSuccess = ComputeAllPerAxis < int > ( nd . MakeGeneric < int > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
117+ case NPTypeCode . UInt32 : computationSuccess = ComputeAllPerAxis < uint > ( nd . MakeGeneric < uint > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
118+ case NPTypeCode . Int64 : computationSuccess = ComputeAllPerAxis < long > ( nd . MakeGeneric < long > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
119+ case NPTypeCode . UInt64 : computationSuccess = ComputeAllPerAxis < ulong > ( nd . MakeGeneric < ulong > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
120+ case NPTypeCode . Char : computationSuccess = ComputeAllPerAxis < char > ( nd . MakeGeneric < char > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
121+ case NPTypeCode . Double : computationSuccess = ComputeAllPerAxis < double > ( nd . MakeGeneric < double > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
122+ case NPTypeCode . Single : computationSuccess = ComputeAllPerAxis < float > ( nd . MakeGeneric < float > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
123+ case NPTypeCode . Decimal : computationSuccess = ComputeAllPerAxis < decimal > ( nd . MakeGeneric < decimal > ( ) , axis , preAxisStride , postAxisStride , axisSize , resultSpan ) ; break ;
124+ default :
125+ throw new NotSupportedException ( $ "Type { nd . typecode } is not supported") ;
126+ }
127+
128+ if ( ! computationSuccess )
129+ {
130+ throw new InvalidOperationException ( "Failed to compute all() along the specified axis" ) ;
131+ }
132+
133+ return resultArray ;
134+ }
135+
136+ private static bool ComputeAllPerAxis < T > ( NDArray < T > nd , int axis , int preAxisStride , int postAxisStride , int axisSize , Span < bool > resultSpan ) where T : unmanaged
137+ {
138+ Span < T > inputSpan = nd . GetData ( ) . AsSpan < T > ( ) ;
139+
140+
141+ for ( int o = 0 ; o < resultSpan . Length ; o ++ )
142+ {
143+ int blockIndex = o / postAxisStride ;
144+ int inBlockIndex = o % postAxisStride ;
145+ int inputStartIndex = blockIndex * axisSize * postAxisStride + inBlockIndex ;
146+
147+ bool currentResult = true ;
148+ for ( int a = 0 ; a < axisSize ; a ++ )
149+ {
150+ int inputIndex = inputStartIndex + a * postAxisStride ;
151+ if ( inputSpan [ inputIndex ] . Equals ( default ( T ) ) )
152+ {
153+ currentResult = false ;
154+ break ;
155+ }
156+ }
157+ resultSpan [ o ] = currentResult ;
158+ }
159+
160+ return true ;
60161 }
61162
62163 private static bool _all_linear < T > ( NDArray < T > nd ) where T : unmanaged
0 commit comments