|
20 | 20 | import datetime |
21 | 21 | import decimal |
22 | 22 | import string |
| 23 | +from typing import Tuple, Protocol |
23 | 24 |
|
24 | 25 | import pytest |
25 | 26 | import pyarrow as pa |
@@ -120,28 +121,50 @@ def assert_pyarrow_leak(): |
120 | 121 | # This defines that Arrow consumers should allow any object that has specific "dunder" |
121 | 122 | # methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle |
122 | 123 | # _any_ class, without pyarrow-specific handling. |
123 | | -class SchemaWrapper: |
124 | | - def __init__(self, schema): |
| 124 | + |
| 125 | + |
| 126 | +class ArrowSchemaExportable(Protocol): |
| 127 | + def __arrow_c_schema__(self) -> object: ... |
| 128 | + |
| 129 | + |
| 130 | +class ArrowArrayExportable(Protocol): |
| 131 | + def __arrow_c_array__( |
| 132 | + self, |
| 133 | + requested_schema: object | None = None |
| 134 | + ) -> Tuple[object, object]: |
| 135 | + ... |
| 136 | + |
| 137 | + |
| 138 | +class ArrowStreamExportable(Protocol): |
| 139 | + def __arrow_c_stream__( |
| 140 | + self, |
| 141 | + requested_schema: object | None = None |
| 142 | + ) -> object: |
| 143 | + ... |
| 144 | + |
| 145 | + |
| 146 | +class SchemaWrapper(ArrowSchemaExportable): |
| 147 | + def __init__(self, schema: ArrowSchemaExportable) -> None: |
125 | 148 | self.schema = schema |
126 | 149 |
|
127 | | - def __arrow_c_schema__(self): |
| 150 | + def __arrow_c_schema__(self) -> object: |
128 | 151 | return self.schema.__arrow_c_schema__() |
129 | 152 |
|
130 | 153 |
|
131 | | -class ArrayWrapper: |
132 | | - def __init__(self, array): |
| 154 | +class ArrayWrapper(ArrowArrayExportable): |
| 155 | + def __init__(self, array: ArrowArrayExportable) -> None: |
133 | 156 | self.array = array |
134 | 157 |
|
135 | | - def __arrow_c_array__(self): |
136 | | - return self.array.__arrow_c_array__() |
| 158 | + def __arrow_c_array__(self, requested_schema: object | None = None) -> Tuple[object, object]: |
| 159 | + return self.array.__arrow_c_array__(requested_schema=requested_schema) |
137 | 160 |
|
138 | 161 |
|
139 | | -class StreamWrapper: |
140 | | - def __init__(self, stream): |
| 162 | +class StreamWrapper(ArrowStreamExportable): |
| 163 | + def __init__(self, stream: ArrowStreamExportable) -> None: |
141 | 164 | self.stream = stream |
142 | 165 |
|
143 | | - def __arrow_c_stream__(self): |
144 | | - return self.stream.__arrow_c_stream__() |
| 166 | + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: |
| 167 | + return self.stream.__arrow_c_stream__(requested_schema=requested_schema) |
145 | 168 |
|
146 | 169 |
|
147 | 170 | @pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str) |
|
0 commit comments