Skip to content
This repository was archived by the owner on Dec 20, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.1...HEAD)

### Added

- feat: Support for multi-dimensional node attributes in plots (#86)

## [0.4.1 - ICON graphs, multiple edge builders and post processors](https://github.com/ecmwf/anemoi-graphs/compare/0.4.0...0.4.1) - 2024-11-26

### Added
Expand Down
40 changes: 21 additions & 19 deletions src/anemoi/graphs/plotting/interactive_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
from matplotlib.colors import rgb2hex
from torch_geometric.data import HeteroData

Expand Down Expand Up @@ -197,25 +198,26 @@ def plot_interactive_nodes(graph: HeteroData, nodes_name: str, out_file: Optiona
for node_attr in node_attrs:
node_attr_values = graph[nodes_name][node_attr].float().numpy()

# Skip multi-dimensional attributes. Supported only: (N, 1) or (N,) tensors
if node_attr_values.ndim > 1 and node_attr_values.shape[1] > 1:
continue

node_traces[node_attr] = go.Scattergeo(
lat=node_latitudes,
lon=node_longitudes,
name=" ".join(node_attr.split("_")).capitalize(),
mode="markers",
hoverinfo="text",
marker={
"color": node_attr_values.squeeze().tolist(),
"showscale": True,
"colorscale": "RdBu",
"colorbar": {"thickness": 15, "title": node_attr, "xanchor": "left"},
"size": 5,
},
visible=False,
)
if node_attr_values.ndim == 1:
node_attr_values = torch.unsqueeze(node_attr_values, -1)

for attr_dim in range(node_attr_values.shape[1]):
suffix = "" if node_attr_values.shape[1] == 1 else f"_[{attr_dim}]"
node_traces[node_attr + suffix] = go.Scattergeo(
lat=node_latitudes,
lon=node_longitudes,
name=" ".join((node_attr + suffix).split("_")).capitalize(),
mode="markers",
hoverinfo="text",
marker={
"color": node_attr_values[:, attr_dim].squeeze().tolist(),
"showscale": True,
"colorscale": "RdBu",
"colorbar": {"thickness": 15, "title": node_attr + suffix, "xanchor": "left"},
"size": 5,
},
visible=False,
)

# Create and add slider
slider_steps = []
Expand Down
Loading