Skip to content

Commit a4f6cdd

Browse files
authored
fix arrow type id mapping (#742)
1 parent 7019b0f commit a4f6cdd

File tree

5 files changed

+112
-21
lines changed

5 files changed

+112
-21
lines changed

python/src/dataframe.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,7 @@ impl DataFrame {
159159
}
160160
};
161161

162-
let builder = errors::wrap(builder.join(
163-
&right.plan,
164-
join_type,
165-
on.clone(),
166-
on,
167-
))?;
162+
let builder = errors::wrap(builder.join(&right.plan, join_type, on.clone(), on))?;
168163

169164
let plan = errors::wrap(builder.build())?;
170165

python/src/to_rust.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ pub fn to_rust(ob: &PyAny) -> PyResult<ArrayRef> {
4848
Ok(array)
4949
}
5050

51+
/// converts a pyarrow batch into a RecordBatch
5152
pub fn to_rust_batch(batch: &PyAny) -> PyResult<RecordBatch> {
5253
let schema = batch.getattr("schema")?;
5354
let names = schema.getattr("names")?.extract::<Vec<String>>()?;

python/src/types.rs

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,13 @@ fn data_type_id(id: &i32) -> Result<DataType, errors::DataFusionError> {
4848
7 => DataType::Int32,
4949
8 => DataType::UInt64,
5050
9 => DataType::Int64,
51-
5251
10 => DataType::Float16,
5352
11 => DataType::Float32,
5453
12 => DataType::Float64,
55-
56-
//13 => DataType::Decimal,
57-
58-
// 14 => DataType::Date32(),
59-
// 15 => DataType::Date64(),
60-
// 16 => DataType::Timestamp(),
61-
// 17 => DataType::Time32(),
62-
// 18 => DataType::Time64(),
63-
// 19 => DataType::Duration()
64-
20 => DataType::Binary,
65-
21 => DataType::Utf8,
66-
22 => DataType::LargeBinary,
67-
23 => DataType::LargeUtf8,
68-
54+
13 => DataType::Utf8,
55+
14 => DataType::Binary,
56+
34 => DataType::LargeUtf8,
57+
35 => DataType::LargeBinary,
6958
other => {
7059
return Err(errors::DataFusionError::Common(format!(
7160
"The type {} is not valid",

python/tests/test_pa_types.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pyarrow as pa
19+
20+
21+
def test_type_ids():
22+
"""having this fixed is very important because internally we rely on this id to parse from
23+
python"""
24+
for idx, arrow_type in [
25+
(0, pa.null()),
26+
(1, pa.bool_()),
27+
(2, pa.uint8()),
28+
(3, pa.int8()),
29+
(4, pa.uint16()),
30+
(5, pa.int16()),
31+
(6, pa.uint32()),
32+
(7, pa.int32()),
33+
(8, pa.uint64()),
34+
(9, pa.int64()),
35+
(10, pa.float16()),
36+
(11, pa.float32()),
37+
(12, pa.float64()),
38+
(13, pa.string()),
39+
(13, pa.utf8()),
40+
(14, pa.binary()),
41+
(16, pa.date32()),
42+
(17, pa.date64()),
43+
(18, pa.timestamp("us")),
44+
(19, pa.time32("s")),
45+
(20, pa.time64("us")),
46+
(23, pa.decimal128(8, 1)),
47+
(34, pa.large_utf8()),
48+
(35, pa.large_binary()),
49+
]:
50+
51+
assert idx == arrow_type.id
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pyarrow as pa
19+
import pytest
20+
from datafusion import ExecutionContext
21+
from datafusion import functions as f
22+
23+
24+
@pytest.fixture
25+
def df():
26+
ctx = ExecutionContext()
27+
28+
# create a RecordBatch and a new DataFrame from it
29+
batch = pa.RecordBatch.from_arrays(
30+
[pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])],
31+
names=["a", "b"],
32+
)
33+
34+
return ctx.create_dataframe([[batch]])
35+
36+
37+
def test_string_functions(df):
38+
df = df.select(f.md5(f.col("a")), f.lower(f.col("a")))
39+
result = df.collect()
40+
assert len(result) == 1
41+
result = result[0]
42+
assert result.column(0) == pa.array(
43+
[
44+
"8b1a9953c4611296a827abf8c47804d7",
45+
"f5a7924e621e84c9280a9a27e1bcb7f6",
46+
"9033e0e305f247c0c3c80d0c7848c8b3",
47+
]
48+
)
49+
assert result.column(1) == pa.array(
50+
[
51+
"hello",
52+
"world",
53+
"!",
54+
]
55+
)

0 commit comments

Comments
 (0)