Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 1fa0b32

Browse files
authored
Merge pull request #37 from Python-for-HPC/add-reductions-and-tests
Add reductions and tests
2 parents 5a84f2e + 8a341d5 commit 1fa0b32

File tree

2 files changed

+355
-1
lines changed

2 files changed

+355
-1
lines changed

numba/openmp.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5401,12 +5401,26 @@ def PLUS(self, args):
54015401
print("visit PLUS", args, type(args))
54025402
return "+"
54035403

5404+
def MINUS(self, args):
    """Handle a MINUS token: return the literal "-" operator string.

    When config.DEBUG_OPENMP is at least 1, the received ``args`` and its
    type are printed first.
    """
    debug_enabled = config.DEBUG_OPENMP >= 1
    if debug_enabled:
        print("visit MINUS", args, type(args))
    return "-"
5408+
5409+
def STAR(self, args):
    """Handle a STAR token: return the literal "*" operator string.

    When config.DEBUG_OPENMP is at least 1, the received ``args`` and its
    type are printed first.
    """
    debug_enabled = config.DEBUG_OPENMP >= 1
    if debug_enabled:
        print("visit STAR", args, type(args))
    return "*"
5413+
54045414
def reduction_operator(self, args):
    """Map a reduction operator symbol to its reduction name.

    ``args[0]`` is compared against the supported operator symbols:
    "+" gives "ADD", "-" gives "SUB" and "*" gives "MUL".

    Raises:
        AssertionError: if ``args[0]`` is none of the symbols above.
    """
    arg = args[0]
    if config.DEBUG_OPENMP >= 1:
        print("visit reduction_operator", args, type(args), arg, type(arg))
    if arg == "+":
        return "ADD"
    elif arg == "-":
        return "SUB"
    elif arg == "*":
        return "MUL"
    # Raise explicitly instead of ``assert(0)``: a bare assert is stripped
    # under ``python -O`` (the method would then silently return None) and
    # gives no hint about which operator was rejected.  The exception type
    # is kept as AssertionError so existing handlers behave the same.
    raise AssertionError("unsupported reduction operator: %r" % (arg,))
54115425

54125426
def threadprivate_directive(self, args):
@@ -6150,7 +6164,9 @@ def NUMBER(self, args):
61506164
var_list: name_slice | var_list "," name_slice
61516165
number_list: NUMBER | number_list "," NUMBER
61526166
PLUS: "+"
6153-
reduction_operator: PLUS | "\\" | "*" | "-" | "&" | "^" | "|" | "&&" | "||"
6167+
MINUS: "-"
6168+
STAR: "*"
6169+
reduction_operator: PLUS | "\\" | STAR | MINUS | "&" | "^" | "|" | "&&" | "||"
61546170
threadprivate_directive: "threadprivate" "(" var_list ")"
61556171
cancellation_point_directive: "cancellation point" construct_type_clause
61566172
construct_type_clause: PARALLEL

numba/tests/test_openmp.py

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,345 @@ def test_impl():
996996
return a
997997
self.check(test_impl)
998998

999+
class TestReductions(TestOpenmpBase):
    """Tests for the OpenMP ``reduction`` clause with +, - and * operators.

    Two shapes of test are used:

    * ``parallel reduction(...)``: every thread assigns a fixed value to its
      private copy of ``redux``; the combined result is checked against the
      thread count reported by thread 0.
    * ``parallel for reduction(...)``: a 10-iteration loop updates ``redux``;
      the result is checked against the sequential answer.

    The ``*_10`` variants start ``redux`` at 10 to check that the value held
    before the region is folded into the final result.
    """

    def __init__(self, *args):
        TestOpenmpBase.__init__(self, *args)

    # ---- parallel region, int ----

    def test_parallel_reduction_add_int(self):
        @njit
        def test_impl():
            redux = 0
            nthreads = 0
            with openmp("parallel reduction(+:redux)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = 1
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, nthreads)

    def test_parallel_reduction_sub_int(self):
        @njit
        def test_impl():
            redux = 0
            nthreads = 0
            with openmp("parallel reduction(-:redux)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = 1
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, nthreads)

    def test_parallel_reduction_mul_int(self):
        @njit
        def test_impl():
            redux = 1
            nthreads = 0
            with openmp("parallel reduction(*:redux) num_threads(8)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = 2
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, 2**nthreads)

    # ---- parallel region, float64 ----

    def test_parallel_reduction_add_fp64(self):
        @njit
        def test_impl():
            redux = np.float64(0.0)
            nthreads = np.float64(0.0)
            with openmp("parallel reduction(+:redux)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = np.float64(1.0)
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, 1.0*nthreads)

    def test_parallel_reduction_sub_fp64(self):
        @njit
        def test_impl():
            redux = np.float64(0.0)
            nthreads = np.float64(0.0)
            with openmp("parallel reduction(-:redux)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = np.float64(1.0)
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, 1.0*nthreads)

    def test_parallel_reduction_mul_fp64(self):
        @njit
        def test_impl():
            redux = np.float64(1.0)
            nthreads = np.float64(0.0)
            with openmp("parallel reduction(*:redux) num_threads(8)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = np.float64(2.0)
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, 2.0**nthreads)

    # ---- parallel region, float32 ----

    def test_parallel_reduction_add_fp32(self):
        @njit
        def test_impl():
            redux = np.float32(0.0)
            nthreads = np.float32(0.0)
            with openmp("parallel reduction(+:redux)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = np.float32(1.0)
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, 1.0*nthreads)

    def test_parallel_reduction_sub_fp32(self):
        @njit
        def test_impl():
            redux = np.float32(0.0)
            nthreads = np.float32(0.0)
            with openmp("parallel reduction(-:redux)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = np.float32(1.0)
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, 1.0*nthreads)

    def test_parallel_reduction_mul_fp32(self):
        @njit
        def test_impl():
            redux = np.float32(1.0)
            nthreads = np.float32(0.0)
            with openmp("parallel reduction(*:redux) num_threads(8)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = np.float32(2.0)
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, 2.0**nthreads)

    # ---- parallel for, int ----

    def test_parallel_for_reduction_add_int(self):
        @njit
        def test_impl():
            redux = 0
            with openmp("parallel for reduction(+:redux)"):
                for i in range(10):
                    redux += 1
            return redux

        redux = test_impl()
        self.assertEqual(redux, 10)

    def test_parallel_for_reduction_sub_int(self):
        @njit
        def test_impl():
            redux = 0
            with openmp("parallel for reduction(-:redux)"):
                for i in range(10):
                    redux += 1
            return redux

        redux = test_impl()
        self.assertEqual(redux, 10)

    def test_parallel_for_reduction_mul_int(self):
        @njit
        def test_impl():
            redux = 1
            with openmp("parallel for reduction(*:redux)"):
                for i in range(10):
                    redux *= 2
            return redux

        redux = test_impl()
        self.assertEqual(redux, 2**10)

    # ---- parallel for, float64 ----

    def test_parallel_for_reduction_add_fp64(self):
        @njit
        def test_impl():
            redux = np.float64(0.0)
            with openmp("parallel for reduction(+:redux)"):
                for i in range(10):
                    redux += np.float64(1.0)
            return redux

        redux = test_impl()
        self.assertEqual(redux, 10.0)

    def test_parallel_for_reduction_sub_fp64(self):
        @njit
        def test_impl():
            redux = np.float64(0.0)
            with openmp("parallel for reduction(-:redux)"):
                for i in range(10):
                    redux += np.float64(1.0)
            return redux

        redux = test_impl()
        self.assertEqual(redux, 10.0)

    def test_parallel_for_reduction_mul_fp64(self):
        @njit
        def test_impl():
            redux = np.float64(1.0)
            with openmp("parallel for reduction(*:redux)"):
                for i in range(10):
                    redux *= np.float64(2.0)
            return redux

        redux = test_impl()
        self.assertEqual(redux, 2.0**10)

    # ---- parallel for, float32 ----

    def test_parallel_for_reduction_add_fp32(self):
        @njit
        def test_impl():
            redux = np.float32(0.0)
            with openmp("parallel for reduction(+:redux)"):
                for i in range(10):
                    redux += np.float32(1.0)
            return redux

        redux = test_impl()
        self.assertEqual(redux, 10.0)

    def test_parallel_for_reduction_sub_fp32(self):
        @njit
        def test_impl():
            redux = np.float32(0.0)
            with openmp("parallel for reduction(-:redux)"):
                for i in range(10):
                    redux += np.float32(1.0)
            return redux

        redux = test_impl()
        self.assertEqual(redux, 10.0)

    def test_parallel_for_reduction_mul_fp32(self):
        @njit
        def test_impl():
            redux = np.float32(1.0)
            with openmp("parallel for reduction(*:redux)"):
                for i in range(10):
                    redux *= np.float32(2.0)
            return redux

        redux = test_impl()
        self.assertEqual(redux, 2.0**10)

    # ---- parallel region, non-zero starting value ----

    def test_parallel_reduction_add_int_10(self):
        @njit
        def test_impl():
            redux = 10
            nthreads = 0
            with openmp("parallel reduction(+:redux)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = 1
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, nthreads+10)

    def test_parallel_reduction_add_fp32_10(self):
        @njit
        def test_impl():
            redux = np.float32(10.0)
            nthreads = np.float32(0.0)
            with openmp("parallel reduction(+:redux)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = np.float32(1.0)
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, 1.0*nthreads+10.0)

    def test_parallel_reduction_add_fp64_10(self):
        @njit
        def test_impl():
            redux = np.float64(10.0)
            nthreads = np.float64(0.0)
            with openmp("parallel reduction(+:redux)"):
                thread_id = omp_get_thread_num()
                if thread_id == 0:
                    nthreads = omp_get_num_threads()
                redux = np.float64(1.0)
            return redux, nthreads

        redux, nthreads = test_impl()
        self.assertGreater(nthreads, 1)
        self.assertEqual(redux, 1.0*nthreads+10.0)

    # ---- parallel for, non-zero starting value ----

    def test_parallel_for_reduction_add_int_10(self):
        @njit
        def test_impl():
            redux = 10
            with openmp("parallel for reduction(+:redux)"):
                for i in range(10):
                    redux += 1
            return redux

        redux = test_impl()
        self.assertEqual(redux, 10+10)

    # This test was previously a second definition of
    # test_parallel_for_reduction_add_fp32, which replaced the earlier method
    # of the same name and left the float32 "start at 10" case untested.
    def test_parallel_for_reduction_add_fp32_10(self):
        @njit
        def test_impl():
            redux = np.float32(10.0)
            with openmp("parallel for reduction(+:redux)"):
                for i in range(10):
                    redux += np.float32(1.0)
            return redux

        redux = test_impl()
        self.assertEqual(redux, 10.0+10.0)

    def test_parallel_for_reduction_add_fp64_10(self):
        @njit
        def test_impl():
            redux = np.float64(10.0)
            with openmp("parallel for reduction(+:redux)"):
                for i in range(10):
                    redux += np.float64(1.0)
            return redux

        redux = test_impl()
        self.assertEqual(redux, 10.0+10.0)
10001338

10011339
class TestOpenmpDataClauses(TestOpenmpBase):
10021340

0 commit comments

Comments
 (0)