1
+ import os
1
2
import unittest
2
- import ROOT
3
+
4
+ import numba # noqa: F401
3
5
import numpy as np
6
+ import ROOT
7
+ from rdf_filter_pyz_helper import TYPE_TO_SYMBOL , CreateData , filter_dict
8
+
9
+ # numba is not used directly, but tests can crash when ROOT is built with
10
+ # builtin_llvm=OFF and numba is not imported at the beginning:
4
11
5
- import os
6
- from rdf_filter_pyz_helper import CreateData , TYPE_TO_SYMBOL , filter_dict
7
12
8
13
class PyFilter (unittest .TestCase ):
9
14
"""
@@ -12,16 +17,17 @@ class PyFilter(unittest.TestCase):
12
17
13
18
def test_with_dtypes (self ):
14
19
"""
15
- Tests the pythonized filter with all the tree datatypes and
20
+ Tests the pythonized filter with all the tree datatypes and
16
21
"""
17
22
CreateData ()
18
23
rdf = ROOT .RDataFrame ("TestData" , "./RDF_Filter_Pyz_TestData.root" )
19
- test_cols = [str (c ) for c in rdf .GetColumnNames ()]
24
+ test_cols = [str (c ) for c in rdf .GetColumnNames ()]
20
25
for col_name in test_cols :
21
- func = filter_dict [TYPE_TO_SYMBOL [col_name ]] # filter function
26
+ func = filter_dict [TYPE_TO_SYMBOL [col_name ]] # filter function
22
27
x = rdf .Mean (col_name ).GetValue ()
23
- if col_name == 'Bool_t' : x = True
24
- filtered = rdf .Filter (func , extra_args = {'x' :x })
28
+ if col_name == "Bool_t" :
29
+ x = True
30
+ filtered = rdf .Filter (func , extra_args = {"x" : x })
25
31
res_root = filtered .AsNumpy ()[col_name ]
26
32
if not isinstance (x , bool ):
27
33
filtered2 = rdf .Filter (f"{ col_name } > { x } " )
@@ -31,19 +37,21 @@ def test_with_dtypes(self):
31
37
else :
32
38
filtered2 = rdf .Filter (f"{ col_name } == false" )
33
39
res_root2 = filtered2 .AsNumpy ()[col_name ]
34
- self .assertTrue (np .array_equal (res_root ,res_root2 ))
40
+ self .assertTrue (np .array_equal (res_root , res_root2 ))
35
41
36
42
os .remove ("./RDF_Filter_Pyz_TestData.root" )
37
43
38
- # CPP Overload 1: Filter(callable, col_list = [], name = "") => 3 Possibilities
44
+ # CPP Overload 1: Filter(callable, col_list = [], name = "") => 3 Possibilities
39
45
def test_filter_overload1_a (self ):
40
46
"""
41
47
Test to verify the first overload (1.a) of filter
42
48
Filter(callable, col_list, name)
43
49
"""
44
50
rdf = ROOT .RDataFrame (5 ).Define ("x" , "(double) rdfentry_" )
51
+
45
52
def x_greater_than_2 (x ):
46
- return x > 2
53
+ return x > 2
54
+
47
55
fil1 = rdf .Filter (x_greater_than_2 , ["x" ], "x is more than 2" )
48
56
self .assertTrue (np .array_equal (fil1 .AsNumpy ()["x" ], np .array ([3 , 4 ])))
49
57
@@ -53,37 +61,43 @@ def test_filter_overload1_b(self):
53
61
Filter(callable, col_list)
54
62
"""
55
63
rdf = ROOT .RDataFrame (5 ).Define ("x" , "(double) rdfentry_" )
56
- fil1 = rdf .Filter (lambda x : x > 2 , ["x" ])
64
+ fil1 = rdf .Filter (lambda x : x > 2 , ["x" ])
57
65
self .assertTrue (np .array_equal (fil1 .AsNumpy ()["x" ], np .array ([3 , 4 ])))
58
-
66
+
59
67
def test_filter_overload1_c (self ):
60
68
"""
61
69
Test to verify the first overload (1.c) of filter
62
70
Filter(callable)
63
71
"""
64
72
rdf = ROOT .RDataFrame (5 ).Define ("x" , "(double) rdfentry_" )
73
+
65
74
def x_greater_than_2 (x ):
66
- return x > 2
75
+ return x > 2
76
+
67
77
fil1 = rdf .Filter (x_greater_than_2 )
68
78
self .assertTrue (np .array_equal (fil1 .AsNumpy ()["x" ], np .array ([3 , 4 ])))
69
-
79
+
70
80
# CPP Overload 3: Filter(callable, name)
71
81
def test_filter_overload3 (self ):
72
82
"""
73
83
Test to verify the third overload of filter
74
84
Filter(callable, name)
75
85
"""
76
86
rdf = ROOT .RDataFrame (5 ).Define ("x" , "(double) rdfentry_" )
87
+
77
88
def x_greater_than_2 (x ):
78
- return x > 2
89
+ return x > 2
90
+
79
91
fil1 = rdf .Filter (x_greater_than_2 , "x is greater than 2" )
80
92
self .assertTrue (np .array_equal (fil1 .AsNumpy ()["x" ], np .array ([3 , 4 ])))
81
-
93
+
82
94
def test_capture_from_scope (self ):
83
95
rdf = ROOT .RDataFrame (5 ).Define ("x" , "(double) rdfentry_" )
84
96
y = 2
97
+
85
98
def x_greater_than_y (x ):
86
99
return x > y
100
+
87
101
fil1 = rdf .Filter (x_greater_than_y , "x is greater than 2" )
88
102
self .assertTrue (np .array_equal (fil1 .AsNumpy ()["x" ], np .array ([3 , 4 ])))
89
103
@@ -93,12 +107,14 @@ def test_cpp_functor(self):
93
107
Filter operation.
94
108
"""
95
109
96
- ROOT .gInterpreter .Declare ("""
110
+ ROOT .gInterpreter .Declare (
111
+ """
97
112
struct MyFunctor
98
113
{
99
114
bool operator()(ULong64_t l) { return l == 0; };
100
115
};
101
- """ )
116
+ """
117
+ )
102
118
f = ROOT .MyFunctor ()
103
119
104
120
rdf = ROOT .RDataFrame (5 )
@@ -112,15 +128,17 @@ def test_std_function(self):
112
128
Filter operation.
113
129
"""
114
130
115
- ROOT .gInterpreter .Declare ("""
131
+ ROOT .gInterpreter .Declare (
132
+ """
116
133
std::function<bool(ULong64_t)> myfun = [](ULong64_t l) { return l == 0; };
117
- """ )
134
+ """
135
+ )
118
136
119
137
rdf = ROOT .RDataFrame (5 )
120
138
c = rdf .Filter (ROOT .myfun , ["rdfentry_" ]).Count ().GetValue ()
121
139
122
140
self .assertEqual (c , 1 )
123
141
124
-
125
- if __name__ == ' __main__' :
142
+
143
+ if __name__ == " __main__" :
126
144
unittest .main ()
0 commit comments