diff --git a/python/BUILD.bazel b/python/BUILD.bazel index a63df92..482ccbf 100644 --- a/python/BUILD.bazel +++ b/python/BUILD.bazel @@ -1,4 +1,4 @@ -load("@rules_python//python:defs.bzl", "py_binary") +load("@rules_python//python:defs.bzl", "py_binary", "py_library") load("@rules_python//python:pip.bzl", "compile_pip_requirements") load("@pypi//:requirements.bzl", "all_requirements") @@ -8,8 +8,21 @@ compile_pip_requirements( requirements_txt = "requirements_lock.txt", ) +py_library( + name = "geom", + srcs = ["geom.py"], + imports = ["rigid_geometric_algebra.python"], + deps = all_requirements, +) + py_binary( name = "plot", srcs = ["plot.py"], - deps = all_requirements, + deps = [":geom"] + all_requirements, +) + +py_binary( + name = "example_plot", + srcs = ["example_plot.py"], + deps = [":geom"] + all_requirements, ) diff --git a/python/example_plot.py b/python/example_plot.py new file mode 100755 index 0000000..b7783b1 --- /dev/null +++ b/python/example_plot.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import matplotlib.pyplot as plt + +from python import geom + +if __name__ == "__main__": + elements = [ + geom.Point([0, 0, 0]), + geom.Point([1, 0, 0]), + geom.Point([0, 1, 0]), + geom.Point([0, 0, 1]), + ] + + fig = geom.plot(elements) + plt.show() diff --git a/python/geom.py b/python/geom.py new file mode 100644 index 0000000..8c7a6c2 --- /dev/null +++ b/python/geom.py @@ -0,0 +1,37 @@ +from collections.abc import Iterable +from dataclasses import dataclass + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.axes import Axes +from matplotlib.collections import PathCollection + + +@dataclass +class _Data: + data: np.ndarray + + +class Point(_Data): + @property + def data(self) -> np.ndarray: + return self._data + + @data.setter + def data(self, value: Iterable[float]) -> None: + self._data = np.array(value, dtype=np.float64) + + def add_to(self, ax: Axes) -> PathCollection: + return ax.scatter(*self.data) + + +def plot(data: Iterable[Point]) -> plt.Figure: + fig, ax = plt.subplots(subplot_kw={"projection": "3d"}) + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + + for p in data: + p.add_to(ax) + + return fig diff --git a/python/plot.py b/python/plot.py index 0cdb6af..c0ceaf1 100755 --- a/python/plot.py +++ b/python/plot.py @@ -3,28 +3,13 @@ import json import sys from collections.abc import Iterable -from dataclasses import dataclass import matplotlib.pyplot as plt -import numpy as np +from python import geom -@dataclass -class _Point: - data: np.ndarray - -class Point(_Point): - @property - def data(self) -> np.ndarray: - return self._data - - @data.setter - def data(self, value: Iterable[float]) -> None: - self._data = np.array(value, dtype=np.float64) - - -def parse(raw: dict[str, Iterable[float]]) -> Point: +def parse(raw: dict[str, Iterable[float]]) -> geom.Point: if len(raw) != 1: msg = "`raw` must contain a single item" raise ValueError(msg) @@ -32,22 +17,10 @@ def parse(raw: dict[str, Iterable[float]]) -> Point: k, v = next(iter(raw.items())) return { - "point": Point, + "point": geom.Point, }[k](v) -def plot(data: Iterable[Point]) -> plt.Figure: - points = np.stack([d.data for d in data if isinstance(d, Point)]) - - fig, ax = plt.subplots(subplot_kw={"projection": "3d"}) - ax.set_xlabel("X") - ax.set_ylabel("Y") - ax.set_zlabel("Z") - ax.scatter(*np.unstack(points, axis=1)) - - return fig - - if __name__ == "__main__": try: line = fileinput.input().readline() @@ -67,5 +40,5 @@ def plot(data: Iterable[Point]) -> plt.Figure: sys.exit(1) raw_values = json.loads(line) - fig = plot(map(parse, raw_values)) + fig = geom.plot(map(parse, raw_values)) plt.show()