Skip to content

Commit ef6c404

Browse files
authored
Adds tests (and fixes) for the dtype invariant. (#739)
That is, np.array(array).dtype == np.dtype(eltype(array)). For this invariant to hold, we had to restrict dtype creation to primitives, tuples and named tuples. Support for arbitrary structs has been removed (they are not supported by our implementation of the array interface or the buffer protocol). We also worked around a feature/bug/quirk of numpy: if you call numpy.dtype(descr), where descr is a list of (name, type) field descriptors of a struct, the dtype you get is not the same as the dtype of an array constructed from an object whose array interface has that same descr. In particular, if any item in descr is struct padding such as ("", "|V4"), then on conversion to a dtype the empty name is replaced with a generated one, e.g. "f2". Going the array route, the padding is ignored and does not appear in the resulting dtype. The fix here is to build the dtype from a different representation of the same information, namely numpy's dict form with names, formats, offsets and itemsize. Co-authored-by: Christopher Rowley <github.com/cjdoris>
1 parent f0365a9 commit ef6c404

File tree

4 files changed

+54
-20
lines changed

4 files changed

+54
-20
lines changed

docs/src/juliacall-reference.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@ jl.Vector[jl.Int]()
202202
```
203203
204204
Some Julia types can be converted to corresponding numpy dtypes like `numpy.dtype(jl.Int)`.
205-
Supports primitive types: `Bool`, `IntXX`, `UIntXX`, `FloatXX`, `ComplexFXX`,
206-
`NumpyDates.InlineDateTime64{unit}` and `NumpyDates.InlineTimeDelta64{unit}`. Also
207-
supports tuples, named tuples and structs of these.
205+
Supports `Bool`, `IntXX`, `UIntXX`, `FloatXX`, `ComplexFXX`,
206+
`NumpyDates.InlineDateTime64{unit}` and `NumpyDates.InlineTimeDelta64{unit}`, plus
207+
`Tuple`s and `NamedTuple`s of these.
208208
`````
209209

210210
`````@customdoc

src/JlWrap/array.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ pybufferformat(::Type{T}) where {T} =
187187
T == Complex{Cdouble} ? "Zd" :
188188
T == Bool ? "?" :
189189
T == Ptr{Cvoid} ? "P" :
190-
if isstructtype(T) && isconcretetype(T) && allocatedinline(T)
190+
if (T <: Union{Tuple,NamedTuple}) && isstructtype(T) && isconcretetype(T) && allocatedinline(T)
191191
n = fieldcount(T)
192192
flds = []
193193
for i = 1:n
@@ -234,7 +234,7 @@ pyjlarray_isarrayabletype(::Type{NamedTuple{names,types}}) where {names,types} =
234234

235235
const PYTYPESTRDESCR = IdDict{Type,Tuple{String,Py}}()
236236

237-
pytypestrdescr(::Type{T}) where {T} =
237+
function pytypestrdescr(::Type{T}) where {T}
238238
get!(PYTYPESTRDESCR, T) do
239239
c = Utils.islittleendian() ? '<' : '>'
240240
if T == Bool
@@ -275,7 +275,7 @@ pytypestrdescr(::Type{T}) where {T} =
275275
u == NumpyDates.UNBOUND_UNITS ? "" :
276276
m == 1 ? "[$(Symbol(u))]" : "[$(m)$(Symbol(u))]"
277277
("$(c)$(tc)8$(us)", PyNULL)
278-
elseif isstructtype(T) && isconcretetype(T) && Base.allocatedinline(T)
278+
elseif (T <: Union{Tuple,NamedTuple}) && isstructtype(T) && isconcretetype(T) && Base.allocatedinline(T)
279279
n = fieldcount(T)
280280
flds = []
281281
for i = 1:n
@@ -298,6 +298,7 @@ pytypestrdescr(::Type{T}) where {T} =
298298
("", PyNULL)
299299
end
300300
end
301+
end
301302

302303
pyjlarray_array__array(x::AbstractArray) = x isa Array ? Py(nothing) : pyjl(Array(x))
303304
pyjlarray_array__pyobjectarray(x::AbstractArray) = pyjl(PyObjectArray(x))

src/JlWrap/type.jl

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,49 @@ function pyjltype_getitem(self::Type, k_)
1111
end
1212
end
1313

14+
const PYNUMPYDTYPE = IdDict{Type,Py}()
15+
1416
function pyjltype_numpy_dtype(self::Type)
15-
typestr, descr = pytypestrdescr(self)
16-
if isempty(typestr)
17-
errset(pybuiltins.AttributeError, "__numpy_dtype__")
18-
return PyNULL
17+
ans = get!(PYNUMPYDTYPE, self) do
18+
typestr, descr = pytypestrdescr(self)
19+
# unsupported type
20+
if typestr == ""
21+
return PyNULL
22+
end
23+
np = pyimport("numpy")
24+
# simple scalar type
25+
if pyisnull(descr)
26+
return np.dtype(typestr)
27+
end
28+
# We could just use np.dtype(descr), but when there is padding, np.dtype(descr)
29+
# changes the names of the padding fields from "" to "f{N}". Using this other
30+
# dtype constructor avoids this issue and preserves the invariant:
31+
# np.dtype(eltype(array)) == np.array(array).dtype
32+
names = []
33+
formats = []
34+
offsets = []
35+
for i = 1:fieldcount(self)
36+
nm = fieldname(self, i)
37+
push!(names, nm isa Integer ? "f$(nm-1)" : String(nm))
38+
ts, ds = pytypestrdescr(fieldtype(self, i))
39+
push!(formats, pyisnull(ds) ? ts : ds)
40+
push!(offsets, fieldoffset(self, i))
41+
end
42+
return np.dtype(
43+
pydict(
44+
names = pylist(names),
45+
formats = pylist(formats),
46+
offsets = pylist(offsets),
47+
itemsize = sizeof(self),
48+
),
49+
)
1950
end
20-
np = pyimport("numpy")
21-
if pyisnull(descr)
22-
return np.dtype(typestr)
23-
else
24-
return np.dtype(descr)
51+
if pyisnull(ans)
52+
errset(pybuiltins.AttributeError, "__numpy_dtype__")
2553
end
54+
return ans
2655
end
2756

28-
pyjl_handle_error_type(::typeof(pyjltype_numpy_dtype), x, exc) = pybuiltins.AttributeError
29-
3057
function init_type()
3158
jl = pyjuliacallmodule
3259
pybuiltins.exec(

test/JlWrap.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,22 +510,28 @@ end
510510
(Tuple{Int32, Int32}, pylist([("f0", "int32"), ("f1", "int32")])),
511511
(@NamedTuple{}, pylist()),
512512
(@NamedTuple{x::Int32, y::Int32}, pylist([("x", "int32"), ("y", "int32")])),
513-
(Pair{Int32, Int32}, pylist([("first", "int32"), ("second", "int32")])),
514513
]
515514
@test pyeq(Bool, pygetattr(pyjl(t), "__numpy_dtype__"), np.dtype(d))
516-
@test pyeq(Bool, np.dtype(pyjl(t)), np.dtype(d))
515+
@test pyeq(Bool, np.dtype(t), np.dtype(d))
516+
# test the invariant np.dtype(eltype(array)) == np.array(array).dtype
517+
@test isequal(np.dtype(t), np.array(t[]).dtype)
517518
end
518519

519520
# unsupported cases
520521
@testset "$t -> AttributeError" for t in [
521-
# non-primitives or mutables
522+
# structs / mutables
523+
Pair,
524+
Pair{Int,Int},
522525
String,
523526
Vector{Int},
524527
# pointers
525528
Ptr{Cvoid},
526529
Ptr{Int},
527530
# PyPtr specifically should NOT be interpreted as np.dtype("O")
528531
PythonCall.C.PyPtr,
532+
# tuples containing illegal things
533+
Tuple{String},
534+
Tuple{Pair{Int,Int}},
529535
]
530536
err = try
531537
pygetattr(pyjl(t), "__numpy_dtype__")

0 commit comments

Comments
 (0)