@@ -185,7 +185,9 @@ def test_rst_role(self, doctype, expected):
185185 )
186186 @pytest .mark .parametrize ("name" , ["array" , "ndarray" , "array-like" , "array_like" ])
187187 @pytest .mark .parametrize ("dtype" , ["int" , "np.int8" ])
188- @pytest .mark .parametrize ("shape" , ["(2, 3)" , "(N, m)" , "3D" , "2-D" , "(N, ...)" ])
188+ @pytest .mark .parametrize ("shape" ,
189+ ["(2, 3)" , "(N, m)" , "3D" , "2-D" , "(N, ...)" , "([P,] M, N)" ]
190+ )
189191 def test_natlang_array (self , fmt , expected_fmt , name , dtype , shape ):
190192
191193 def escape (name : str ) -> str :
@@ -202,6 +204,18 @@ def escape(name: str) -> str:
202204 assert annotation .value == expected
203205 # fmt: on
204206
207+ @pytest .mark .parametrize (
208+ ("doctype" , "expected" ),
209+ [
210+ ("ndarray of dtype (int or float)" , "ndarray[int | float]" ),
211+ ("([P,] M, N) (int or float) array" , "array[int | float]" ),
212+ ],
213+ )
214+ def test_natlang_array_specific (self , doctype , expected ):
215+ transformer = DoctypeTransformer ()
216+ annotation , _ = transformer .doctype_to_annotation (doctype )
217+ assert annotation .value == expected
218+
205219 @pytest .mark .parametrize ("shape" , ["(-1, 3)" , "(1.0, 2)" , "-3D" , "-2-D" ])
206220 def test_natlang_array_invalid_shape (self , shape ):
207221 doctype = f"array of shape { shape } "
0 commit comments