Skip to content

Commit

Permalink
define geom::Line type and plot method (#99)
Browse files Browse the repository at this point in the history
Change-Id: I73f509c32238441fd4219203231a30059dcb2e42
  • Loading branch information
oliverlee authored Oct 3, 2024
1 parent b120c86 commit 8044708
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/example_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
geom.Point([1, 0, 0]),
geom.Point([0, 1, 0]),
geom.Point([0, 0, 1]),
geom.Line([0, 1, 0, 0, 0, 1]),
geom.Line([0, 1, 0, 0, 0, 0.5]),
]

fig = geom.plot(elements)
Expand Down
73 changes: 69 additions & 4 deletions python/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,23 @@
import numpy as np
from matplotlib.axes import Axes
from matplotlib.collections import PathCollection
from mpl_toolkits.mplot3d.art3d import Line3DCollection


class GeomError(ArithmeticError):
pass


def require(cond: bool, message: str) -> None: # noqa: FBT001
# https://github.com/astral-sh/ruff/issues/9497
if not cond:
raise GeomError(message)


@dataclass
class _Data:
data: np.ndarray


class Point(_Data):
@property
def data(self) -> np.ndarray:
return self._data
Expand All @@ -21,11 +30,67 @@ def data(self) -> np.ndarray:
def data(self, value: Iterable[float]) -> None:
self._data = np.array(value, dtype=np.float64)

@property
def view(self) -> np.ndarray:
v = self.data.view()
v.setflags(write=False)
return v

def _invariant(self) -> None:
pass

def __post_init__(self) -> None:
self._invariant()


class Point(_Data):
def add_to(self, ax: Axes) -> PathCollection:
return ax.scatter(*self.view)


class Line(_Data):
def _invariant(self) -> None:
require(
self.view.size % 2 == 0, "`Line` must contain an even number of elements"
)
require(
np.dot(self.direction, self.moment) == 0,
"The `direction` and `moment` of a `Line` must be orthogonal",
)

@property
def direction(self) -> np.ndarray:
return self.view[: (self.view.size // 2)]

@property
def moment(self) -> np.ndarray:
return self.view[(self.view.size // 2) :]

def add_to(self, ax: Axes) -> PathCollection:
return ax.scatter(*self.data)
norm = np.linalg.norm
u = np.cross(self.direction, self.moment)
p = norm(self.moment) / norm(self.direction) * u / norm(u)

# these lines should "extend" to infinity
lines = Line3DCollection(
[
np.vstack(
[
p - 1000 * self.direction,
p - 10 * self.direction,
p - self.direction,
p,
p + self.direction,
p + 10 * self.direction,
p + 1000 * self.direction,
]
)
]
)
return ax.add_collection(lines)


def plot(data: Iterable[Point]) -> plt.Figure:
def plot(data: Iterable[_Data]) -> plt.Figure:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.set_xlabel("x")
ax.set_ylabel("y")
Expand Down

0 comments on commit 8044708

Please sign in to comment.