Skip to content

Commit 21fb81f

Browse files
author
Matthieu Ancellin
committed
Clean up some tests.
1 parent 670b196 commit 21fb81f

File tree

2 files changed

+53
-51
lines changed

2 files changed

+53
-51
lines changed

test/test_labels.py

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,25 @@
22
# coding: utf-8
33

44
import pytest
5-
import xarray as xr
65
from copy import copy
76

8-
from hypothesis import given, settings
9-
from hypothesis.strategies import floats
10-
11-
from labelled_functions.labels import *
12-
7+
from labelled_functions.labels import LabelledFunction
138
from example_functions import *
149

15-
number = floats(allow_nan=False, allow_infinity=False)
16-
1710

1811
def test_labelled_class():
19-
ldt = LabelledFunction(compute_pi)
20-
assert ldt() == compute_pi(), "The LabelledFunction is not callable"
21-
assert ldt.__wrapped__.__code__ == compute_pi.__code__, "Attributes are not passed to encapsulated function"
22-
assert ldt.name == compute_pi.__name__, "Name is wrong"
23-
assert str(ldt) == "compute_pi() -> (pi)", "String representation is wrong"
12+
lpi = LabelledFunction(compute_pi)
13+
assert lpi() == compute_pi()
14+
assert lpi.__wrapped__.__code__ == compute_pi.__code__
15+
assert lpi.name == compute_pi.__name__
16+
assert list(lpi.input_names) == []
17+
assert list(lpi.output_names) == ['pi']
18+
assert str(lpi) == "compute_pi() -> (pi)"
2419

2520
lc = LabelledFunction(cube)
2621
assert lc(0) == cube(0)
27-
assert lc.input_names == ['x']
28-
assert lc.output_names == ['length', 'area', 'volume']
22+
assert list(lc.input_names) == ['x']
23+
assert list(lc.output_names) == ['length', 'area', 'volume']
2924
assert str(lc) == "cube(x) -> (length, area, volume)"
3025

3126

@@ -35,15 +30,6 @@ def test_idempotence():
3530
assert llc is lc
3631

3732

38-
def test_output_checking():
39-
la = LabelledFunction(add)
40-
assert la._has_never_been_run
41-
42-
la.output_names = ['moose', 'llama']
43-
with pytest.raises(TypeError):
44-
la(1, 2)
45-
46-
4733
def test_method():
4834
class A:
4935
def __init__(self):
@@ -53,63 +39,77 @@ def f(self, x):
5339
y = 2*x + self.a
5440
return y
5541

56-
lab_Af = label(A().f)
42+
lab_Af = LabelledFunction(A().f)
5743
assert list(lab_Af.input_names) == ['x']
5844
assert list(lab_Af.output_names) == ['y']
5945
assert lab_Af(2) == 14
6046
assert lab_Af(x=2) == 14
6147

62-
lab_f = label(A.f)
48+
lab_f = LabelledFunction(A.f)
6349
assert list(lab_f.input_names) == ['self', 'x']
6450
assert list(lab_f.output_names) == ['y']
6551
assert lab_f(A(), 2) == 14
6652
assert lab_f(A(), x=2) == 14
6753

6854

55+
def test_output_checking():
56+
la = LabelledFunction(add, output_names=['shrubbery'])
57+
assert la._has_never_been_run
58+
assert la(1, 2) == 3
59+
60+
# Wrong number of outputs
61+
la = LabelledFunction(add, output_names=['moose', 'llama'])
62+
assert la._has_never_been_run
63+
with pytest.raises(TypeError):
64+
la(1, 2)
65+
66+
6967
def test_recorder():
7068
a, b = 1, 2
7169

72-
assert label(compute_pi).recorded_call() == {'pi': 3.14159}
70+
assert LabelledFunction(compute_pi).recorded_call() == {'pi': 3.14159}
7371

74-
assert label(double).recorded_call(a) == {'x': a, '2*x': 2*a}
72+
assert LabelledFunction(double).recorded_call(a) == {'x': a, '2*x': 2*a}
7573

76-
assert label(optional_double).recorded_call(a) == {'x': a, '2*x': 2*a}
77-
assert label(optional_double).recorded_call(x=a) == {'x': a, '2*x': 2*a}
78-
assert label(optional_double).recorded_call() == {'x': 0, '2*x': 0}
74+
assert LabelledFunction(optional_double).recorded_call(a) == {'x': a, '2*x': 2*a}
75+
assert LabelledFunction(optional_double).recorded_call(x=a) == {'x': a, '2*x': 2*a}
76+
assert LabelledFunction(optional_double).recorded_call() == {'x': 0, '2*x': 0}
7977

80-
assert label(add).recorded_call(a, b) == {'x': a, 'y': b, 'x+y': a+b}
78+
assert LabelledFunction(add).recorded_call(a, b) == {'x': a, 'y': b, 'x+y': a+b}
8179

82-
assert label(optional_add).recorded_call(a, b) == {'x': a, 'y': b, 'x+y': a+b}
83-
assert label(optional_add).recorded_call(a) == {'x': a, 'y': 0, 'x+y': a}
84-
assert label(optional_add).recorded_call() == {'x': 0, 'y': 0, 'x+y': 0}
85-
assert label(optional_add).recorded_call(x=a, y=b) == {'x': a, 'y': b, 'x+y': a+b}
86-
assert label(optional_add).recorded_call(y=a, x=b) == {'x': b, 'y': a, 'x+y': a+b}
87-
assert label(optional_add).recorded_call(y=a) == {'x': 0, 'y': a, 'x+y': a}
80+
assert LabelledFunction(optional_add).recorded_call(a, b) == {'x': a, 'y': b, 'x+y': a+b}
81+
assert LabelledFunction(optional_add).recorded_call(a) == {'x': a, 'y': 0, 'x+y': a}
82+
assert LabelledFunction(optional_add).recorded_call() == {'x': 0, 'y': 0, 'x+y': 0}
83+
assert LabelledFunction(optional_add).recorded_call(x=a, y=b) == {'x': a, 'y': b, 'x+y': a+b}
84+
assert LabelledFunction(optional_add).recorded_call(y=a, x=b) == {'x': b, 'y': a, 'x+y': a+b}
85+
assert LabelledFunction(optional_add).recorded_call(y=a) == {'x': 0, 'y': a, 'x+y': a}
8886

89-
assert label(cube).recorded_call(a) == {'x': a, 'length': 12*a, 'area': 6*a**2, 'volume': a**3}
87+
assert LabelledFunction(cube).recorded_call(a) == {'x': a, 'length': 12*a, 'area': 6*a**2, 'volume': a**3}
9088

9189
with pytest.raises(TypeError):
92-
label(all_kinds_of_args).recorded_call(0, 1, 2, 3)
90+
LabelledFunction(all_kinds_of_args).recorded_call(0, 1, 2, 3)
9391

94-
assert label(all_kinds_of_args).recorded_call(0, 1, z=2, t=3) == {'x': 0, 'y': 1, 'z': 2, 't': 3}
95-
assert label(all_kinds_of_args).recorded_call(x=0, y=1, z=2, t=3) == {'x': 0, 'y': 1, 'z': 2, 't': 3}
96-
assert label(all_kinds_of_args).recorded_call(x=0, z=2, t=3) == {'x': 0, 'y': 1, 'z': 2, 't': 3}
97-
assert label(all_kinds_of_args).recorded_call(x=0, z=2) == {'x': 0, 'y': 1, 'z': 2, 't': 3}
98-
assert label(all_kinds_of_args).recorded_call(z=2, t=3, x=0) == {'x': 0, 'y': 1, 'z': 2, 't': 3}
92+
assert LabelledFunction(all_kinds_of_args).recorded_call(0, 1, z=2, t=3) == {'x': 0, 'y': 1, 'z': 2, 't': 3}
93+
assert LabelledFunction(all_kinds_of_args).recorded_call(x=0, y=1, z=2, t=3) == {'x': 0, 'y': 1, 'z': 2, 't': 3}
94+
assert LabelledFunction(all_kinds_of_args).recorded_call(x=0, z=2, t=3) == {'x': 0, 'y': 1, 'z': 2, 't': 3}
95+
assert LabelledFunction(all_kinds_of_args).recorded_call(x=0, z=2) == {'x': 0, 'y': 1, 'z': 2, 't': 3}
96+
assert LabelledFunction(all_kinds_of_args).recorded_call(z=2, t=3, x=0) == {'x': 0, 'y': 1, 'z': 2, 't': 3}
9997

10098

10199
def test_namespace():
102100
namespace = {'x': 0, 'y': 3, 'z': 2}
103101
new_namespace = LabelledFunction(add).apply_in_namespace(namespace)
104102
assert new_namespace == {'x+y': 3, 'x': 0, 'y': 3, 'z': 2}
105103

104+
import xarray as xr
106105
in_ds = xr.Dataset(coords={'radius': np.linspace(0, 1, 10),
107106
'length': np.linspace(0, 1, 10)})
108107
out_ds = LabelledFunction(cylinder_volume).apply_in_namespace(in_ds)
109108
assert 'volume' in out_ds.data_vars
110109

110+
111111
def test_set_default():
112-
lc = label(cylinder_volume)
112+
lc = LabelledFunction(cylinder_volume)
113113
llc = lc.set_default(radius=1.0)
114114
assert llc(length=1.0) == np.pi
115115
assert copy(llc)(length=1.0) == np.pi
@@ -122,15 +122,17 @@ def test_set_default():
122122
with pytest.raises(TypeError):
123123
rllc()
124124

125+
125126
def test_hide():
126127
a = np.random.rand(1)[0]
127-
assert label(optional_add).hide('x').recorded_call(y=a) == {'y': a, 'x+y': a}
128-
assert label(optional_add).hide('y').recorded_call(y=a) == {'x': 0, 'x+y': a}
129-
assert label(optional_add).hide('y').recorded_call() == {'x': 0, 'x+y': 0}
130-
assert label(optional_add).hide_all_but('y').recorded_call() == {'y': 0, 'x+y': 0}
128+
assert LabelledFunction(optional_add).hide('x').recorded_call(y=a) == {'y': a, 'x+y': a}
129+
assert LabelledFunction(optional_add).hide('y').recorded_call(y=a) == {'x': 0, 'x+y': a}
130+
assert LabelledFunction(optional_add).hide('y').recorded_call() == {'x': 0, 'x+y': 0}
131+
assert LabelledFunction(optional_add).hide_all_but('y').recorded_call() == {'y': 0, 'x+y': 0}
132+
131133

132134
def test_fix():
133-
lc = label(cylinder_volume)
135+
lc = LabelledFunction(cylinder_volume)
134136
llc = lc.fix(radius=1.0)
135137
assert llc.default_values == {'radius': 1.0}
136138
assert llc.hidden_inputs == {'radius'}
File renamed without changes.

0 commit comments

Comments
 (0)