Skip to content

Commit 9fb6bbc

Browse files
authored
fix: Fix ArcLayer with numpy coords input (#989)
### Change list - Fix numpy input to `PointAccessor` for ArcLayer - Reformat imports across `traits` files - Add test from #901 Closes #901
1 parent 7937c49 commit 9fb6bbc

File tree

8 files changed

+53
-54
lines changed

8 files changed

+53
-54
lines changed

lonboard/traits/_float.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,9 @@
55
from typing import TYPE_CHECKING, Any
66

77
import numpy as np
8-
from arro3.core import (
9-
Array,
10-
ChunkedArray,
11-
DataType,
12-
)
13-
14-
from lonboard._serialization import (
15-
ACCESSOR_SERIALIZATION,
16-
)
8+
from arro3.core import Array, ChunkedArray, DataType
9+
10+
from lonboard._serialization import ACCESSOR_SERIALIZATION
1711
from lonboard.traits._base import FixedErrorTraitType
1812

1913
if TYPE_CHECKING:

lonboard/traits/_map.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
import traitlets
77

88
from lonboard._environment import DEFAULT_HEIGHT
9-
from lonboard._serialization import (
10-
serialize_view_state,
11-
)
9+
from lonboard._serialization import serialize_view_state
1210
from lonboard.models import ViewState
1311
from lonboard.traits._base import FixedErrorTraitType
1412

lonboard/traits/_normal.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,9 @@
66
from typing import TYPE_CHECKING, Any
77

88
import numpy as np
9-
from arro3.core import (
10-
Array,
11-
ChunkedArray,
12-
DataType,
13-
Field,
14-
fixed_size_list_array,
15-
)
16-
17-
from lonboard._serialization import (
18-
ACCESSOR_SERIALIZATION,
19-
)
9+
from arro3.core import Array, ChunkedArray, DataType, Field, fixed_size_list_array
10+
11+
from lonboard._serialization import ACCESSOR_SERIALIZATION
2012
from lonboard.traits._base import FixedErrorTraitType
2113

2214
if TYPE_CHECKING:

lonboard/traits/_point.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,11 @@
55
from typing import TYPE_CHECKING, Any
66

77
import numpy as np
8-
from arro3.core import (
9-
Array,
10-
ChunkedArray,
11-
DataType,
12-
fixed_size_list_array,
13-
)
8+
from arro3.core import Array, ChunkedArray, DataType, Field, fixed_size_list_array
149

10+
from lonboard._geoarrow.extension_types import CoordinateDimension, coord_storage_type
1511
from lonboard._geoarrow.ops.coord_layout import convert_struct_column_to_interleaved
16-
from lonboard._serialization import (
17-
ACCESSOR_SERIALIZATION,
18-
)
12+
from lonboard._serialization import ACCESSOR_SERIALIZATION
1913
from lonboard.traits._base import FixedErrorTraitType
2014

2115
if TYPE_CHECKING:
@@ -63,8 +57,27 @@ def _numpy_to_arrow(self, obj: BaseArrowLayer, value: np.ndarray) -> ChunkedArra
6357
info="Point array to have 2 or 3 as its second dimension",
6458
)
6559

66-
assert np.issubdtype(value.dtype, np.float64)
67-
array = fixed_size_list_array(value.ravel("C"), list_size)
60+
if not np.issubdtype(value.dtype, np.float64):
61+
self.error(obj, value, info="Point array to have float64 type.")
62+
63+
# Set geoarrow extension metadata
64+
field = Field(
65+
"",
66+
coord_storage_type(
67+
interleaved=True,
68+
dims=CoordinateDimension.XY
69+
if list_size == 2
70+
else CoordinateDimension.XYZ,
71+
),
72+
nullable=True,
73+
metadata={"ARROW:extension:name": "geoarrow.point"},
74+
)
75+
array = fixed_size_list_array(
76+
value.ravel("C"),
77+
list_size,
78+
type=field,
79+
)
80+
6881
return ChunkedArray([array])
6982

7083
def validate(

lonboard/traits/_table.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,11 @@
55
from typing import TYPE_CHECKING, Any
66
from typing import cast as type_cast
77

8-
from arro3.core import (
9-
DataType,
10-
Table,
11-
)
8+
from arro3.core import DataType, Table
129

1310
from lonboard._constants import EXTENSION_NAME
1411
from lonboard._geoarrow.box_to_polygon import parse_box_encoded_table
15-
from lonboard._serialization import (
16-
TABLE_SERIALIZATION,
17-
)
12+
from lonboard._serialization import TABLE_SERIALIZATION
1813
from lonboard._utils import get_geometry_column_index
1914
from lonboard.traits._base import FixedErrorTraitType
2015

lonboard/traits/_text.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,9 @@
55
from typing import TYPE_CHECKING, Any
66

77
import numpy as np
8-
from arro3.core import (
9-
Array,
10-
ChunkedArray,
11-
DataType,
12-
)
13-
14-
from lonboard._serialization import (
15-
ACCESSOR_SERIALIZATION,
16-
)
8+
from arro3.core import Array, ChunkedArray, DataType
9+
10+
from lonboard._serialization import ACCESSOR_SERIALIZATION
1711
from lonboard.traits._base import FixedErrorTraitType
1812

1913
if TYPE_CHECKING:

lonboard/traits/_timestamp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@
2020
from traitlets.traitlets import TraitType
2121

2222
from lonboard._constants import MAX_INTEGER_FLOAT32, MIN_INTEGER_FLOAT32
23-
from lonboard._serialization import (
24-
TIMESTAMP_ACCESSOR_SERIALIZATION,
25-
)
23+
from lonboard._serialization import TIMESTAMP_ACCESSOR_SERIALIZATION
2624
from lonboard._utils import get_geometry_column_index
2725
from lonboard.traits import FixedErrorTraitType
2826

tests/layers/test_arc_layer.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
import pyarrow as pa
3-
from arro3.core import Table
3+
from arro3.core import ChunkedArray, Table
44
from geoarrow.rust.core import point, points
55

66
from lonboard import ArcLayer, Map
@@ -40,3 +40,18 @@ def test_arc_layer_geoarrow_separated():
4040
)
4141
m = Map(layer)
4242
assert isinstance(m.layers[0], ArcLayer)
43+
44+
45+
def test_arc_layer_numpy():
46+
data = {
47+
"source": ["London", "Manchester", "Bristol"],
48+
"target": ["Manchester", "Bristol", "London"],
49+
}
50+
51+
source = np.array([(51.5072, 0.1276), (53.4808, 2.2426), (51.4545, 2.5879)])
52+
target = np.array([(53.4808, 2.2426), (51.4545, 2.5879), (51.5072, 0.1276)])
53+
54+
table = pa.table(data)
55+
56+
layer = ArcLayer(table, get_source_position=source, get_target_position=target)
57+
assert isinstance(layer.get_source_position, ChunkedArray)

0 commit comments

Comments
 (0)