Skip to content

Commit f0365a9

Browse files
authored
Add numpy dtype access for TypeValue (#738)
* Add numpy dtype mapping for TypeValue * Move numpy dtype mapping to Julia * tidying * add support for datetime64 and timedelta64 * add tests for struct types * document struct support * test simplifications and comments * update changelog --------- Co-authored-by: Christopher Rowley <github.com/cjdoris>
1 parent be16342 commit f0365a9

File tree

4 files changed

+88
-1
lines changed

4 files changed

+88
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## Unreleased
4+
* Added `juliacall.TypeValue.__numpy_dtype__` attribute to allow converting Julia types
5+
to the corresponding NumPy dtype, like `numpy.dtype(jl.Int)`.
6+
37
## 0.9.31 (2025-12-17)
48
* Restore support for Python 3.14+.
59

docs/src/juliacall-reference.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ from juliacall import Main as jl
200200
# equivalent to Vector{Int}() in Julia
201201
jl.Vector[jl.Int]()
202202
```
203+
204+
Some Julia types can be converted to corresponding numpy dtypes like `numpy.dtype(jl.Int)`.
205+
Supports primitive types: `Bool`, `IntXX`, `UIntXX`, `FloatXX`, `ComplexFXX`,
206+
`NumpyDates.InlineDateTime64{unit}` and `NumpyDates.InlineTimeDelta64{unit}`. Also
207+
supports tuples, named tuples and structs of these.
203208
`````
204209

205210
`````@customdoc

src/JlWrap/type.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@ function pyjltype_getitem(self::Type, k_)
1111
end
1212
end
1313

14+
function pyjltype_numpy_dtype(self::Type)
15+
typestr, descr = pytypestrdescr(self)
16+
if isempty(typestr)
17+
errset(pybuiltins.AttributeError, "__numpy_dtype__")
18+
return PyNULL
19+
end
20+
np = pyimport("numpy")
21+
if pyisnull(descr)
22+
return np.dtype(typestr)
23+
else
24+
return np.dtype(descr)
25+
end
26+
end
27+
28+
pyjl_handle_error_type(::typeof(pyjltype_numpy_dtype), x, exc) = pybuiltins.AttributeError
29+
1430
function init_type()
1531
jl = pyjuliacallmodule
1632
pybuiltins.exec(
@@ -25,6 +41,9 @@ class TypeValue(AnyValue):
2541
raise TypeError("not supported")
2642
def __delitem__(self, k):
2743
raise TypeError("not supported")
44+
@property
45+
def __numpy_dtype__(self):
46+
return self._jl_callmethod($(pyjl_methodnum(pyjltype_numpy_dtype)))
2847
""",
2948
@__FILE__(),
3049
"exec",

test/JlWrap.jl

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,13 +472,72 @@ end
472472
end
473473
end
474474

475-
@testitem "type" begin
475+
@testitem "type" setup = [Setup] begin
476+
using PythonCall.NumpyDates
476477
@testset "type" begin
477478
@test pyis(pytype(pyjl(Int)), PythonCall.pyjltypetype)
478479
end
479480
@testset "bool" begin
480481
@test pytruth(pyjl(Int))
481482
end
483+
@testset "numpy dtype" begin
484+
if Setup.devdeps
485+
np = pyimport("numpy")
486+
487+
# success cases
488+
@testset "$t -> $d" for (t, d) in [
489+
(Bool, "bool"),
490+
(Int8, "int8"),
491+
(Int16, "int16"),
492+
(Int32, "int32"),
493+
(Int64, "int64"),
494+
(UInt8, "uint8"),
495+
(UInt16, "uint16"),
496+
(UInt32, "uint32"),
497+
(UInt64, "uint64"),
498+
(Float16, "float16"),
499+
(Float32, "float32"),
500+
(Float64, "float64"),
501+
(ComplexF32, "complex64"),
502+
(ComplexF64, "complex128"),
503+
(InlineDateTime64{SECONDS}, "datetime64[s]"),
504+
(InlineDateTime64{(SECONDS, 5)}, "datetime64[5s]"),
505+
(InlineDateTime64{NumpyDates.UNBOUND_UNITS}, "datetime64"),
506+
(InlineTimeDelta64{MINUTES}, "timedelta64[m]"),
507+
(InlineTimeDelta64{(SECONDS, 5)}, "timedelta64[5s]"),
508+
(InlineTimeDelta64{NumpyDates.UNBOUND_UNITS}, "timedelta64"),
509+
(Tuple{}, pylist()),
510+
(Tuple{Int32, Int32}, pylist([("f0", "int32"), ("f1", "int32")])),
511+
(@NamedTuple{}, pylist()),
512+
(@NamedTuple{x::Int32, y::Int32}, pylist([("x", "int32"), ("y", "int32")])),
513+
(Pair{Int32, Int32}, pylist([("first", "int32"), ("second", "int32")])),
514+
]
515+
@test pyeq(Bool, pygetattr(pyjl(t), "__numpy_dtype__"), np.dtype(d))
516+
@test pyeq(Bool, np.dtype(pyjl(t)), np.dtype(d))
517+
end
518+
519+
# unsupported cases
520+
@testset "$t -> AttributeError" for t in [
521+
# non-primitives or mutables
522+
String,
523+
Vector{Int},
524+
# pointers
525+
Ptr{Cvoid},
526+
Ptr{Int},
527+
# PyPtr specifically should NOT be interpreted as np.dtype("O")
528+
PythonCall.C.PyPtr,
529+
]
530+
err = try
531+
pygetattr(pyjl(t), "__numpy_dtype__")
532+
nothing
533+
catch err
534+
err
535+
end
536+
@test err isa PythonCall.PyException
537+
@test pyis(err.t, pybuiltins.AttributeError)
538+
end
539+
end
540+
end
482541
end
483542

484543
@testitem "vector" begin

0 commit comments

Comments
 (0)