Skip to content

Commit bf20c77

Browse files
committed
Add typehints & add required requested_schema parameter to the PyCapsule wrappers
1 parent 12b4236 commit bf20c77

File tree

1 file changed

+34
-11
lines changed

1 file changed

+34
-11
lines changed

arrow-pyarrow-integration-testing/tests/test_sql.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import datetime
2121
import decimal
2222
import string
23+
from typing import Tuple, Protocol
2324

2425
import pytest
2526
import pyarrow as pa
@@ -120,28 +121,50 @@ def assert_pyarrow_leak():
120121
# This defines that Arrow consumers should allow any object that has specific "dunder"
121122
# methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle
122123
# _any_ class, without pyarrow-specific handling.
123-
class SchemaWrapper:
124-
def __init__(self, schema):
124+
125+
126+
class ArrowSchemaExportable(Protocol):
127+
def __arrow_c_schema__(self) -> object: ...
128+
129+
130+
class ArrowArrayExportable(Protocol):
131+
def __arrow_c_array__(
132+
self,
133+
requested_schema: object | None = None
134+
) -> Tuple[object, object]:
135+
...
136+
137+
138+
class ArrowStreamExportable(Protocol):
139+
def __arrow_c_stream__(
140+
self,
141+
requested_schema: object | None = None
142+
) -> object:
143+
...
144+
145+
146+
class SchemaWrapper(ArrowSchemaExportable):
147+
def __init__(self, schema: ArrowSchemaExportable) -> None:
125148
self.schema = schema
126149

127-
def __arrow_c_schema__(self):
150+
def __arrow_c_schema__(self) -> object:
128151
return self.schema.__arrow_c_schema__()
129152

130153

131-
class ArrayWrapper:
132-
def __init__(self, array):
154+
class ArrayWrapper(ArrowArrayExportable):
155+
def __init__(self, array: ArrowArrayExportable) -> None:
133156
self.array = array
134157

135-
def __arrow_c_array__(self):
136-
return self.array.__arrow_c_array__()
158+
def __arrow_c_array__(self, requested_schema: object | None = None) -> Tuple[object, object]:
159+
return self.array.__arrow_c_array__(requested_schema=requested_schema)
137160

138161

139-
class StreamWrapper:
140-
def __init__(self, stream):
162+
class StreamWrapper(ArrowStreamExportable):
163+
def __init__(self, stream: ArrowStreamExportable) -> None:
141164
self.stream = stream
142165

143-
def __arrow_c_stream__(self):
144-
return self.stream.__arrow_c_stream__()
166+
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
167+
return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
145168

146169

147170
@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)

0 commit comments

Comments
 (0)