Skip to content

Commit

Permalink
Add point clouds summary writer to kauldron interface with tensorboar…
Browse files Browse the repository at this point in the history
…d and jaxboard.

PiperOrigin-RevId: 665840429
  • Loading branch information
CLU Authors authored and copybara-github committed Aug 21, 2024
1 parent 7c22ddf commit 307b0bc
Show file tree
Hide file tree
Showing 10 changed files with 206 additions and 1 deletion.
16 changes: 16 additions & 0 deletions clu/metric_writers/async_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,22 @@ def write_histograms(self,
self._pool(self._writer.write_histograms)(
step=step, arrays=arrays, num_buckets=num_buckets)

@_wrap_exceptions
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
self._pool(self._writer.write_pointcloud)(
step=step,
point_clouds=point_clouds,
point_colors=point_colors,
configs=configs,
)

@_wrap_exceptions
def write_hparams(self, hparams: Mapping[str, Any]):
self._pool(self._writer.write_hparams)(hparams=hparams)
Expand Down
21 changes: 21 additions & 0 deletions clu/metric_writers/async_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,27 @@ def test_write_videos(self):
self.sync_writer.write_videos.assert_called_with(4,
{"input_videos": mock.ANY})

def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
self.writer.write_pointcloud(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
self.writer.flush()
self.sync_writer.write_pointcloud.assert_called_with(
step=0,
point_clouds={"pcd": mock.ANY},
point_colors={"pcd": mock.ANY},
configs={"config": mock.ANY},
)

def test_write_texts(self):
self.writer.write_texts(4, {"samples": "bla"})
self.writer.flush()
Expand Down
20 changes: 20 additions & 0 deletions clu/metric_writers/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,26 @@ def write_histograms(self,
of the MetricWriter.
"""

def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
"""Writes point cloud summaries.
Args:
step: Step at which the point cloud was generated.
point_clouds: Mapping from point clouds key to point cloud of shape [N, 3]
array of point coordinates.
point_colors: Mapping from point colors key to [N, 3] array of point
colors.
configs: A dictionary of configuration options for the point cloud.
"""
raise NotImplementedError()

@abc.abstractmethod
def write_hparams(self, hparams: Mapping[str, Any]):
"""Write hyper parameters.
Expand Down
21 changes: 21 additions & 0 deletions clu/metric_writers/logging_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ def write_histograms(self,
self._collection_str, key,
_get_histogram_as_string(histo, bins))

def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Any] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
logging.info(
"[%d]%s Got point clouds: %s, point_colors: %s, configs: %s.",
step,
self._collection_str,
{k: v.shape for k, v in point_clouds.items()},
(
{k: v.shape for k, v in point_colors.items()}
if point_colors is not None
else None
),
configs,
)

def write_hparams(self, hparams: Mapping[str, Any]):
logging.info("[Hyperparameters]%s %s", self._collection_str, hparams)

Expand Down
23 changes: 23 additions & 0 deletions clu/metric_writers/logging_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,29 @@ def test_write_histogram(self):
"INFO:absl:[4] Histogram for 'c' = {[-0.4, 0.6]: 5}",
])

def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
with self.assertLogs(level="INFO") as logs:
self.writer.write_pointcloud(
step=4,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"configs": config},
)
self.assertEqual(
logs.output,
[
"INFO:absl:[4] Got point clouds: {'pcd': (1, 1024, 3)},"
" point_colors: {'pcd': (1, 1024, 3)}, configs: {'configs':"
" {'material': 'PointCloudMaterial', 'size': 0.09}}."
],
)

def test_write_hparams(self):
with self.assertLogs(level="INFO") as logs:
self.writer.write_hparams({"learning_rate": 0.1, "batch_size": 128})
Expand Down
13 changes: 13 additions & 0 deletions clu/metric_writers/multi_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,19 @@ def write_histograms(self,
for w in self._writers:
w.write_histograms(step, arrays, num_buckets)

def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
for w in self._writers:
w.write_pointcloud(
step, point_clouds, point_colors=point_colors, configs=configs
)

def write_hparams(self, hparams: Mapping[str, Any]):
for w in self._writers:
w.write_hparams(hparams)
Expand Down
24 changes: 24 additions & 0 deletions clu/metric_writers/multi_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from clu.metric_writers import interface
from clu.metric_writers import multi_writer
import numpy as np
import tensorflow as tf


Expand Down Expand Up @@ -48,6 +49,29 @@ def test_write_scalars(self):
])
w.flush.assert_called()

def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
self.writer.write_pointcloud(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
self.writer.flush()
for w in self.writers:
w.write_pointcloud.assert_called_with(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
w.flush.assert_called()


if __name__ == "__main__":
tf.test.main()
21 changes: 21 additions & 0 deletions clu/metric_writers/tf/summary_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from tensorboard.plugins.hparams import api as hparams_api
from tensorboard.plugins.mesh import summary as mesh_summary # pylint: disable=line-too-long
# pylint: enable=g-import-not-at-top


Expand Down Expand Up @@ -97,6 +98,26 @@ def write_histograms(
buckets = None if num_buckets is None else num_buckets.get(key)
tf.summary.histogram(key, value, step=step, buckets=buckets)

def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
with self._summary_writer.as_default():
for key, vertices in point_clouds.items():
colors = None if point_colors is None else point_colors.get(key)
config = None if configs is None else configs.get(key)
mesh_summary.mesh(
key,
vertices=vertices,
colors=colors,
step=step,
config_dict=config,
)

def write_hparams(self, hparams: Mapping[str, Any]):
with self._summary_writer.as_default():
hparams_api.hparams(dict(utils.flatten_dict(hparams)))
Expand Down
33 changes: 33 additions & 0 deletions clu/metric_writers/tf/summary_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ def _load_scalars_data(logdir: str):
return data


def _load_pointcloud_data(logdir: str):
"""Loads pointcloud summaries from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = collections.defaultdict(dict)
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
if value.metadata.plugin_data.plugin_name == "mesh":
if "config" not in value.tag:
data[event.step][value.tag] = tf.make_ndarray(value.tensor)
else:
data[event.step][value.tag] = value.metadata.plugin_data.content
return data


def _load_hparams(logdir: str):
"""Loads hparams summaries from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
Expand Down Expand Up @@ -142,6 +157,24 @@ def test_write_histograms(self):
]
self.assertAllClose(data["b"], ([0, 2], expected_histograms_b))

def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
self.writer.write_pointcloud(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
self.writer.flush()
data = _load_pointcloud_data(self.logdir)
self.assertAllClose(data[0]["pcd_VERTEX"], point_clouds)
self.assertAllClose(data[0]["pcd_COLOR"], point_colors)

def test_hparams(self):
self.writer.write_hparams(dict(batch_size=512, num_epochs=90))
hparams = _load_hparams(self.logdir)
Expand Down
15 changes: 14 additions & 1 deletion clu/metric_writers/torch_tensorboard_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Any, Optional
from absl import logging


from clu.metric_writers import interface
from torch.utils import tensorboard

Expand Down Expand Up @@ -79,6 +78,20 @@ def write_histograms(self,
self._writer.add_histogram(
tag, values, global_step=step, bins="auto", max_bins=bins)

def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
logging.log_first_n(
logging.WARNING,
"TorchTensorBoardWriter does not support writing point clouds.",
1,
)

def write_hparams(self, hparams: Mapping[str, Any]):
self._writer.add_hparams(hparams, {})

Expand Down

0 comments on commit 307b0bc

Please sign in to comment.