Skip to content

Commit

Permalink
create geom lib for plotting elements
Browse files Browse the repository at this point in the history
Change-Id: I056ea529207457f51f4be22380f8e2d3f140b404
  • Loading branch information
oliverlee committed Oct 3, 2024
1 parent 77ae341 commit 292e182
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 33 deletions.
17 changes: 15 additions & 2 deletions python/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@rules_python//python:defs.bzl", "py_binary")
load("@rules_python//python:defs.bzl", "py_binary", "py_library")
load("@rules_python//python:pip.bzl", "compile_pip_requirements")
load("@pypi//:requirements.bzl", "all_requirements")

Expand All @@ -8,8 +8,21 @@ compile_pip_requirements(
requirements_txt = "requirements_lock.txt",
)

py_library(
name = "geom",
srcs = ["geom.py"],
imports = ["rigid_geometric_algebra.python"],
deps = all_requirements,
)

py_binary(
name = "plot",
srcs = ["plot.py"],
deps = all_requirements,
deps = [":geom"] + all_requirements,
)

py_binary(
name = "example_plot",
srcs = ["example_plot.py"],
deps = [":geom"] + all_requirements,
)
16 changes: 16 additions & 0 deletions python/example_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/usr/bin/env python

import matplotlib.pyplot as plt

from python import geom

if __name__ == "__main__":
elements = [
geom.Point([0, 0, 0]),
geom.Point([1, 0, 0]),
geom.Point([0, 1, 0]),
geom.Point([0, 0, 1]),
]

fig = geom.plot(elements)
plt.show()
37 changes: 37 additions & 0 deletions python/geom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from collections.abc import Iterable
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.collections import PathCollection


@dataclass
class _Data:
data: np.ndarray


class Point(_Data):
@property
def data(self) -> np.ndarray:
return self._data

@data.setter
def data(self, value: Iterable[float]) -> None:
self._data = np.array(value, dtype=np.float64)

def add_to(self, ax: Axes) -> PathCollection:
return ax.scatter(*self.data)


def plot(data: Iterable[Point]) -> plt.Figure:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")

for p in data:
p.add_to(ax)

return fig
35 changes: 4 additions & 31 deletions python/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,24 @@
import json
import sys
from collections.abc import Iterable
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np

from python import geom

@dataclass
class _Point:
data: np.ndarray


class Point(_Point):
@property
def data(self) -> np.ndarray:
return self._data

@data.setter
def data(self, value: Iterable[float]) -> None:
self._data = np.array(value, dtype=np.float64)


def parse(raw: dict[str, Iterable[float]]) -> Point:
def parse(raw: dict[str, Iterable[float]]) -> geom.Point:
if len(raw) != 1:
msg = "`raw` must contain a single item"
raise ValueError(msg)

k, v = next(iter(raw.items()))

return {
"point": Point,
"point": geom.Point,
}[k](v)


def plot(data: Iterable[Point]) -> plt.Figure:
points = np.stack([d.data for d in data if isinstance(d, Point)])

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.scatter(*np.unstack(points, axis=1))

return fig


if __name__ == "__main__":
try:
line = fileinput.input().readline()
Expand All @@ -67,5 +40,5 @@ def plot(data: Iterable[Point]) -> plt.Figure:
sys.exit(1)

raw_values = json.loads(line)
fig = plot(map(parse, raw_values))
fig = geom.plot(map(parse, raw_values))
plt.show()

0 comments on commit 292e182

Please sign in to comment.