Skip to content

Commit 8639ee2

Browse files
authored
Merge pull request #12 from MachineLearningLifeScience/plotting-in-curves
Adds plotting in individual axes
2 parents 734162e + 2264f5b commit 8639ee2

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

stochman/curves.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Optional, Tuple
44

55
import torch
6+
from matplotlib.axis import Axis
67
from torch import nn
78

89

@@ -47,7 +48,9 @@ def __len__(self):
4748
"""Returns the batch dimension e.g. the number of curves"""
4849
return self.begin.shape[0]
4950

50-
def plot(self, t0: float = 0.0, t1: float = 1.0, N: int = 100, *plot_args, **plot_kwargs):
51+
def plot(
52+
self, t0: float = 0.0, t1: float = 1.0, N: int = 100, ax: Axis = None, *plot_args, **plot_kwargs
53+
):
5154
"""Plot the curve.
5255
5356
Args:
@@ -70,15 +73,20 @@ def plot(self, t0: float = 0.0, t1: float = 1.0, N: int = 100, *plot_args, **plo
7073
if len(points.shape) == 2:
7174
points.unsqueeze_(0) # 1xNxD
7275

76+
plot_in = ax or plt
77+
if ax is not None:
78+
t = t.detach().numpy()
79+
points = points.detach().numpy()
80+
7381
figs = []
7482
if points.shape[-1] == 1:
7583
for b in range(points.shape[0]):
76-
fig = plt.plot(t, points[b], *plot_args, **plot_kwargs)
84+
fig = plot_in.plot(t, points[b], *plot_args, **plot_kwargs)
7785
figs.append(fig)
7886
return figs
7987
if points.shape[-1] == 2:
8088
for b in range(points.shape[0]):
81-
fig = plt.plot(points[b, :, 0], points[b, :, 1], *plot_args, **plot_kwargs)
89+
fig = plot_in.plot(points[b, :, 0], points[b, :, 1], *plot_args, **plot_kwargs)
8290
figs.append(fig)
8391
return figs
8492

tests/test_curves.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,13 @@ def test_plot_func(self, curve_class, dim):
6161

6262
@pytest.mark.parametrize("batch_size", [1, 5])
6363
def test_fit_func(self, curve_class, batch_size):
64-
""" test fit function """
64+
"""test fit function"""
6565
c = curve_class(torch.randn(batch_size, 2), torch.randn(batch_size, 2), 20)
6666
loss = c.fit(torch.linspace(0, 1, 10), torch.randn(5, 10, 2))
6767
assert isinstance(loss, torch.Tensor)
6868

6969
def test_getindex_func(self, curve_class):
70-
""" test __getidx__ function """
70+
"""test __getidx__ function"""
7171
batched_c = curve_class(torch.randn(5, 2), torch.randn(5, 2))
7272
for i in range(len(batched_c)):
7373
c = batched_c[i]
@@ -77,14 +77,14 @@ def test_getindex_func(self, curve_class):
7777
assert c.device == batched_c.device
7878

7979
def test_setindex_func(self, curve_class):
80-
""" test __setidx__ function """
80+
"""test __setidx__ function"""
8181
batched_c = curve_class(torch.randn(5, 2), torch.randn(5, 2))
8282
for i in range(len(batched_c)):
8383
batched_c[i] = curve_class(torch.randn(1, 2), torch.randn(1, 2))
8484
assert batched_c[i]
8585

8686
def test_to_other(self, curve_class):
87-
""" test .tospline and .todiscrete """
87+
"""test .tospline and .todiscrete"""
8888
c = curve_class(torch.randn(1, 2), torch.randn(1, 2), 20)
8989
if curve_class == curves.DiscreteCurve:
9090
new_c = c.tospline()
@@ -112,3 +112,19 @@ def test_constant_speed(self, curve_class):
112112
assert isinstance(Ct, torch.Tensor)
113113
assert new_t.shape == (batch_size, timesteps)
114114
assert Ct.shape == (batch_size, timesteps, dim)
115+
116+
def test_plotting_in_axis(self, curve_class):
117+
batch_size = 5
118+
dim = 2
119+
begin = torch.randn(batch_size, dim)
120+
end = torch.randn(batch_size, dim)
121+
c = curve_class(begin, end, 20)
122+
try:
123+
import torchplot as plt
124+
125+
fig, ax = plt.subplots(1, 1)
126+
c.plot(ax=ax)
127+
plt.close(fig)
128+
assert True
129+
except Exception as e:
130+
assert False, e

0 commit comments

Comments
 (0)