@@ -61,13 +61,13 @@ def test_plot_func(self, curve_class, dim):
61
61
62
62
@pytest .mark .parametrize ("batch_size" , [1 , 5 ])
63
63
def test_fit_func (self , curve_class , batch_size ):
64
- """ test fit function """
64
+ """test fit function"""
65
65
c = curve_class (torch .randn (batch_size , 2 ), torch .randn (batch_size , 2 ), 20 )
66
66
loss = c .fit (torch .linspace (0 , 1 , 10 ), torch .randn (5 , 10 , 2 ))
67
67
assert isinstance (loss , torch .Tensor )
68
68
69
69
def test_getindex_func (self , curve_class ):
70
- """ test __getidx__ function """
70
+ """test __getidx__ function"""
71
71
batched_c = curve_class (torch .randn (5 , 2 ), torch .randn (5 , 2 ))
72
72
for i in range (len (batched_c )):
73
73
c = batched_c [i ]
@@ -77,14 +77,14 @@ def test_getindex_func(self, curve_class):
77
77
assert c .device == batched_c .device
78
78
79
79
def test_setindex_func (self , curve_class ):
80
- """ test __setidx__ function """
80
+ """test __setidx__ function"""
81
81
batched_c = curve_class (torch .randn (5 , 2 ), torch .randn (5 , 2 ))
82
82
for i in range (len (batched_c )):
83
83
batched_c [i ] = curve_class (torch .randn (1 , 2 ), torch .randn (1 , 2 ))
84
84
assert batched_c [i ]
85
85
86
86
def test_to_other (self , curve_class ):
87
- """ test .tospline and .todiscrete """
87
+ """test .tospline and .todiscrete"""
88
88
c = curve_class (torch .randn (1 , 2 ), torch .randn (1 , 2 ), 20 )
89
89
if curve_class == curves .DiscreteCurve :
90
90
new_c = c .tospline ()
@@ -112,3 +112,19 @@ def test_constant_speed(self, curve_class):
112
112
assert isinstance (Ct , torch .Tensor )
113
113
assert new_t .shape == (batch_size , timesteps )
114
114
assert Ct .shape == (batch_size , timesteps , dim )
115
+
116
+ def test_plotting_in_axis (self , curve_class ):
117
+ batch_size = 5
118
+ dim = 2
119
+ begin = torch .randn (batch_size , dim )
120
+ end = torch .randn (batch_size , dim )
121
+ c = curve_class (begin , end , 20 )
122
+ try :
123
+ import torchplot as plt
124
+
125
+ fig , ax = plt .subplots (1 , 1 )
126
+ c .plot (ax = ax )
127
+ plt .close (fig )
128
+ assert True
129
+ except Exception as e :
130
+ assert False , e
0 commit comments