Skip to content

Commit 825f450

Browse files
committed
test: make tests more flexible
1 parent c6a7626 commit 825f450

File tree

5 files changed

+150
-9
lines changed

5 files changed

+150
-9
lines changed

tests/test_module1.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,26 @@ def test_rename_columns_module1():
5353

5454
@pytest.mark.test_concatenate_identifier_columns_module1
5555
def test_concatenate_identifier_columns_module1():
56-
assert 'games:pd:concat:games:identifiers:axis:1:sort:False' in get_assignments(data), 'Concatenate the `games` and `identifiers` DataFrames.'
56+
concat = False
57+
frames = False
58+
axis = False
59+
sort = False
60+
61+
for string in get_assignments(data):
62+
if 'games:pd:concat' in string:
63+
concat = True
64+
if 'games:identifiers' in string:
65+
frames = True
66+
if 'axis:1' in string:
67+
axis = True
68+
if 'sort:False' in string:
69+
sort = True
70+
71+
assert concat, 'Are you calling `pd.concat()`?'
72+
assert frames, 'Does the call to `pd.concat()` have a list of DataFrames to concatenate as the first argument? Make sure the frames are in the correct order.'
73+
assert axis, 'Does the call to `pd.concat()` have a keyword argument of `axis` set to `1`?'
74+
assert sort, 'Does the call to `pd.concat()` have a keyword argument of `sort` set to `False`?'
75+
5776

5877
@pytest.mark.test_fill_nan_values_module1
5978
def test_fill_nan_values_module1():

tests/test_module2.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
22
import matplotlib
3+
import numpy as np
4+
import pandas as pd
35
matplotlib.use('Agg')
46

57
from .utils import get_assignments, get_calls
@@ -27,8 +29,10 @@ def test_select_attendance_module2():
2729
if 'multi3' not in attendance.attendance.columns:
2830
local_attendance.columns = ['year', 'attendance']
2931

30-
assert attendance.attendance.equals(local_attendance), 'Have you selected the attendance rows with `loc[]`?'
32+
if attendance.attendance['attendance'].dtype == np.int64:
33+
local_attendance.loc[:, 'attendance'] = pd.to_numeric(local_attendance.loc[:, 'attendance'])
3134

35+
assert attendance.attendance.equals(local_attendance), 'Have you selected the attendance rows with `loc[]`?'
3236
except ImportError:
3337
print('It looks as if `data.py` is incomplete.')
3438

@@ -42,7 +46,28 @@ def test_convert_to_numeric_module2():
4246

4347
@pytest.mark.test_plot_dataframe_module2
4448
def test_plot_dataframe_module2():
45-
assert 'attendance:plot:x:year:y:attendance:figsize:15:7:kind:bar' in get_calls(attendance), 'Plot the `year` on the x-axis and the `attendance` on the y-axis of a bar plot. Adjust the size of the plot.'
49+
plot = False
50+
x = False
51+
y = False
52+
figsize = False
53+
kind = False
54+
for string in get_calls(attendance):
55+
if 'attendance:plot' in string:
56+
plot = True
57+
if 'x:year' in string:
58+
x = True
59+
if 'y:attendance' in string:
60+
y = True
61+
if 'figsize:15:7' in string:
62+
figsize = True
63+
if 'kind:bar' in string:
64+
kind = True
65+
66+
assert plot, 'Are you calling `plot()` on the `attendance` DataFrame?'
67+
assert x, 'Does the call to `plot()` have a keyword argument of `x` set to `\'year\'`?'
68+
assert y, 'Does the call to `plot()` have a keyword argument of `y` set to `\'attendance\'`?'
69+
assert figsize, 'Does the call to `plot()` have a keyword argument of `figsize` set to `(15, 7)`?'
70+
assert kind, 'Does the call to `plot()` have a keyword argument of `kind` set to `\'bar\'`?'
4671
assert 'plt:show' in get_calls(attendance), 'Have you shown the plot?'
4772

4873
@pytest.mark.test_axis_labels_module2
@@ -52,4 +77,25 @@ def test_axis_labels_module2():
5277

5378
@pytest.mark.test_mean_line_module2
5479
def test_mean_line_module2():
55-
assert 'plt:axhline:y:attendance:attendance:mean:label:Mean:linestyle:--:color:green' in get_calls(attendance), 'Plot a green dashed line at the mean.'
80+
axhline = False
81+
y = False
82+
label = False
83+
linestyle = False
84+
color = False
85+
for string in get_calls(attendance):
86+
if 'plt:axhline' in string:
87+
axhline = True
88+
if 'y:attendance:attendance:mean' in string:
89+
y = True
90+
if 'label:Mean' in string:
91+
label = True
92+
if 'linestyle:--' in string:
93+
linestyle = True
94+
if 'color:green' in string:
95+
color = True
96+
97+
assert axhline, 'Are you calling `plt.axhline()`?'
98+
assert y, 'Does the call to `plt.axhline()` have a keyword argument of `y` set to `attendance[\'attendance\'].mean()`?'
99+
assert label, 'Does the call to `plt.axhline()` have a keyword argument of `label` set to `\'Mean\'`?'
100+
assert linestyle, 'Does the call to `plt.axhline()` have a keyword argument of `linestyle` set to `\'--\'`?'
101+
assert color, 'Does the call to `plt.axhline()` have a keyword argument of `color` set to `\'green\'`?'

tests/test_module3.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
@pytest.mark.test_select_all_plays_module3
99
def test_select_all_plays_module3():
1010
assert 'games' in dir(pitching), 'Have you imported `games` from `data`?'
11-
assert 'plays:games:games:type:play' in get_assignments(pitching), 'Are you selecting just the rows that have a `type` of `play`?'
11+
assert 'plays:games:games:type:play' in get_assignments(pitching), 'Are you selecting just the rows that have a `type` of `play`? Make sure you are using the shortcut method of selection.'
1212

1313
@pytest.mark.test_select_all_strike_outs_module3
1414
def test_select_all_strike_outs_module3():
@@ -28,5 +28,26 @@ def test_apply_an_operation_to_multiple_columns_module3():
2828

2929
@pytest.mark.test_change_plot_formatting_module3
3030
def test_change_plot_formatting_module3():
31-
assert 'strike_outs:plot:x:year:y:strike_outs:kind:scatter:legend:Strike Outs' in get_calls(pitching), 'Create a scatter plot with the `year` as the x-axis and the number of `strikes_outs` on the y-axis.'
31+
plot = False
32+
x = False
33+
y = False
34+
kind = False
35+
legend = False
36+
for string in get_calls(pitching):
37+
if 'strike_outs:plot' in string:
38+
plot = True
39+
if 'x:year' in string:
40+
x = True
41+
if 'y:strike_outs' in string:
42+
y = True
43+
if 'kind:scatter' in string:
44+
kind = True
45+
if 'legend:Strike Outs' in string:
46+
legend = True
47+
48+
assert plot, 'Are you calling `plot()` on the `strike_outs` DataFrame?'
49+
assert x, 'Does the call to `plot()` have a keyword argument of `x` set to `\'year\'`?'
50+
assert y, 'Does the call to `plot()` have a keyword argument of `y` set to `\'strike_outs\'`?'
51+
assert kind, 'Does the call to `plot()` have a keyword argument of `kind` set to `\'scatter\'`?'
52+
assert legend, 'Have you chained a call to `legend()`?'
3253
assert 'plt:show' in get_calls(pitching), 'Show the scatter plot.'

tests/test_module4.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,25 @@ def test_sort_values_module4():
4545

4646
@pytest.mark.test_reshape_with_pivot_module4
4747
def test_reshape_with_pivot_module4():
48-
assert 'hits:hits:pivot:index:inning:columns:hit_type:values:count' in get_assignments(offense), 'Make sure to reshape the `hits` DataFrame with the `pivot()` function. Ensure you have set the correct keyword arguments.'
48+
pivot = False
49+
index = False
50+
columns = False
51+
values = False
52+
53+
for string in get_assignments(offense):
54+
if 'hits:hits:pivot' in string:
55+
pivot = True
56+
if 'index:inning' in string:
57+
index = True
58+
if 'columns:hit_type' in string:
59+
columns = True
60+
if 'values:count' in string:
61+
values = True
62+
63+
assert pivot, 'Are you calling `pivot()` on the `hits` DataFrame?'
64+
assert index, 'Does the call to `pivot()` have a keyword argument of `index` set to `\'inning\'`?'
65+
assert columns, 'Does the call to `pivot()` have a keyword argument of `columns` set to `\'strike_outs\'`?'
66+
assert values, 'Does the call to `pivot()` have a keyword argument of `values` set to `\'count\'`?'
4967

5068
@pytest.mark.test_stacked_bar_plot_module4
5169
def test_stacked_bar_plot_module4():

tests/test_module5.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,25 @@ def test_manage_column_labels_module5():
5454

5555
@pytest.mark.test_merge_plate_appearances_module5
5656
def test_merge_plate_appearances_module5():
57-
assert 'events_plus_pa:pd:merge:events:pa:how:outer:left_on:year:game_id:team:right_on:year:game_id:team' in get_assignments(defense), 'Have the `events` DataFrame and the `pa` DataFrame been merged with the correct keyword arguments?'
57+
merge = False
58+
how = False
59+
left_on = False
60+
right_on = False
61+
62+
for string in get_assignments(defense):
63+
if 'events_plus_pa:pd:merge:events:pa' in string:
64+
merge = True
65+
if 'how:outer' in string:
66+
how = True
67+
if 'left_on:year:game_id:team' in string:
68+
left_on = True
69+
if 'right_on:year:game_id:team' in string:
70+
right_on = True
71+
72+
assert merge, 'Are you calling `pd.merge()` to merge the `events` and `pa` DataFrames?'
73+
assert how, 'Does the call to `pd.merge()` have a keyword argument of `how` set to `\'outer\'`?'
74+
assert left_on, "Does the call to `pd.merge()` have a keyword argument of `left_on` set to `\'['year', 'game_id', 'team']\'`?"
75+
assert right_on, "Does the call to `pd.merge()` have a keyword argument of `right_on` set to `\'['year', 'game_id', 'team']\'`?"
5876

5977
@pytest.mark.test_merge_team_module5
6078
def test_merge_team_module5():
@@ -73,7 +91,26 @@ def test_calculate_der_module5():
7391
@pytest.mark.test_reshape_with_pivot_module5
7492
def test_reshape_with_pivot_module5():
7593
assert 'der:Name:defense:Name:loc:Attribute:defense:Name:year:Str:Index:Subscript:GtE:1978:Num:Compare:year:Str:defense:Str:DER:Str:List:Tuple:Index:Subscript:Assign' in get_assignments(defense, include_type=True), 'Select just the rows of the `defense` DataFrame with a year greater than 1978.'
76-
assert 'der:der:pivot:index:year:columns:defense:values:DER' in get_assignments(defense), 'Reshape the `defense` DataFrame with the `pivot()` function and the correct keyword arguments.'
94+
95+
pivot = False
96+
index = False
97+
columns = False
98+
values = False
99+
100+
for string in get_assignments(defense):
101+
if 'der:der:pivot' in string:
102+
pivot = True
103+
if 'index:year' in string:
104+
index = True
105+
if 'columns:defense' in string:
106+
columns = True
107+
if 'values:DER' in string:
108+
values = True
109+
110+
assert pivot, 'Are you calling `pivot()` on the `der` DataFrame?'
111+
assert index, 'Does the call to `pivot()` have a keyword argument of `index` set to `\'year\'`?'
112+
assert columns, 'Does the call to `pivot()` have a keyword argument of `columns` set to `\'defense\'`?'
113+
assert values, 'Does the call to `pivot()` have a keyword argument of `values` set to `\'DER\'`?'
77114

78115
@pytest.mark.test_plot_formatting_xticks_module5
79116
def test_plot_formatting_xticks_module5():

0 commit comments

Comments
 (0)