Skip to content

Commit 05fcc76

Browse files
author
Menooker
authored
Add kmp_* wrapper for gomp environment (#79)
1 parent bd9e32e commit 05fcc76

File tree

8 files changed

+271
-7
lines changed

8 files changed

+271
-7
lines changed

lib/gc/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ include(functions)
66

77
add_subdirectory(CAPI)
88
add_subdirectory(Dialect)
9-
add_subdirectory(Transforms)
9+
add_subdirectory(Transforms)
10+
add_subdirectory(ExecutionEngine)

lib/gc/ExecutionEngine/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(CPURuntime)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
find_package(OpenMP REQUIRED)

# Parallel.cpp provides weak __kmpc_* wrappers implemented on top of the
# standard OpenMP pragmas/API. They are only needed when the OpenMP runtime we
# link against is not an LLVM/Intel one. FindOpenMP reports the runtime
# libraries in OpenMP_C_LIB_NAMES; Intel's runtime is reported as "iomp5" and
# LLVM's as "omp". The historical spellings "iomp" and "omp5" are kept so that
# configurations that matched before still match.
if(NOT ("iomp" IN_LIST OpenMP_C_LIB_NAMES
        OR "iomp5" IN_LIST OpenMP_C_LIB_NAMES
        OR "omp" IN_LIST OpenMP_C_LIB_NAMES
        OR "omp5" IN_LIST OpenMP_C_LIB_NAMES))
    add_definitions("-DGC_NEEDS_OMP_WRAPPER=1")
endif()

# The runtime sources use OpenMP pragmas, so OpenMP has to be enabled for them.
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
add_mlir_library(GCCpuRuntime
    SHARED
    Parallel.cpp

    EXCLUDE_FROM_LIBMLIR
)
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
//===-- Parallel.cpp - parallel ---------------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <assert.h>
10+
#include <atomic>
11+
#include <chrono>
12+
#include <immintrin.h>
13+
#include <omp.h>
14+
#include <stdarg.h>
15+
16+
#define likely(x) __builtin_expect(!!(x), 1)
17+
#define unlikely(x) __builtin_expect(!!(x), 0)
18+
19+
#define WEAK_SYMBOL __attribute__((weak))
20+
21+
namespace {
// One barrier instance, laid out to occupy exactly one 64-byte cache line so
// that adjacent barriers in an array never share a line.
struct barrier_t {
  // Number of threads that still have to arrive in the current round.
  alignas(64) std::atomic<int32_t> pending_;
  // Round counter; waiting threads spin until it changes.
  std::atomic<int32_t> rounds_;
  // Value pending_ is reset to once a round completes.
  uint64_t total_;
  // Fill the remainder of the cache line to avoid false sharing.
  char padding_[64 - (2 * sizeof(std::atomic<int32_t>) + sizeof(uint64_t))];
};

// Callback a waiting thread may run while the barrier is not yet released.
// gc_arrive_at_barrier() invokes it with, in order: the address of the
// barrier's round counter, the round value that signals release, a thread id
// (always -1 at that call site) and the opaque user argument.
using barrier_idle_func = uint64_t (*)(std::atomic<int32_t> *remaining,
                                       int32_t expected_remain, int32_t tid,
                                       void *args);
} // namespace
} // namespace
34+
35+
extern "C" {
36+
// Exported dummy variable. gc-cpu-runner's main() assigns to it so that this
// runtime library stays linked into the tool.
int gc_runtime_keep_alive = 0;
37+
// Signals that the calling thread has reached barrier `b` and blocks until
// every participating thread has done the same.
//
// b         - barrier to arrive at.
// idle_func - optional callback run once by a waiting thread before it starts
//             spinning; may be null. Its return value is ignored.
// idle_args - opaque pointer forwarded to idle_func.
void gc_arrive_at_barrier(barrier_t *b, barrier_idle_func idle_func,
                          void *idle_args) {
  // Remember the round we arrived in before announcing the arrival.
  const auto round_at_entry = b->rounds_.load(std::memory_order_acquire);
  const auto still_pending = --b->pending_;
  assert(still_pending >= 0);

  if (still_pending == 0) {
    // Last thread in: re-arm the counter for the next round, then release the
    // waiters by advancing the round number.
    b->pending_.store(b->total_);
    b->rounds_.store(round_at_entry + 1);
    return;
  }

  if (idle_func) {
    // The round may already be over; in that case there is nothing to wait
    // for and the idle callback is skipped.
    if (b->rounds_.load() != round_at_entry) {
      return;
    }
    idle_func(&b->rounds_, round_at_entry + 1, -1, idle_args);
  }

  // Busy-wait until the last arriver bumps the round counter.
  while (b->rounds_.load() == round_at_entry) {
    _mm_pause();
  }
}
57+
58+
static_assert(sizeof(barrier_t) == 64, "size of barrier_t should be 64-byte");
59+
60+
// Initializes `num_barriers` consecutive barriers starting at `b` so that each
// one waits for `thread_count` arrivals and starts at round 0.
void gc_init_barrier(barrier_t *b, int num_barriers, uint64_t thread_count) {
  int idx = 0;
  while (idx < num_barriers) {
    barrier_t &slot = b[idx];
    slot.total_ = thread_count;
    slot.pending_.store(thread_count);
    slot.rounds_.store(0);
    ++idx;
  }
}
67+
68+
#if GC_NEEDS_OMP_WRAPPER
69+
// Weak replacement for the __kmpc_barrier runtime entry point, compiled only
// when GC_NEEDS_OMP_WRAPPER is set. Executes an OpenMP barrier through the
// compiler's own `omp barrier` pragma; both arguments are ignored.
void WEAK_SYMBOL __kmpc_barrier(void *loc, int32_t global_tid) {
70+
#pragma omp barrier
71+
}
72+
73+
// Weak replacement for the __kmpc_global_thread_num runtime entry point.
// Returns omp_get_thread_num() for the calling thread; `loc` is ignored.
int WEAK_SYMBOL __kmpc_global_thread_num(void *loc) {
74+
return omp_get_thread_num();
75+
}
76+
77+
// The implementation was extracted and simplified from LLVM libomp
78+
// at openmp/runtime/src/kmp_sched.cpp
79+
void WEAK_SYMBOL __kmpc_for_static_init_8u(void *loc, int32_t gtid,
80+
int32_t schedtype,
81+
int32_t *plastiter, uint64_t *plower,
82+
uint64_t *pupper, int64_t *pstride,
83+
int64_t incr, int64_t chunk) {
84+
if (unlikely(schedtype != 34)) {
85+
std::abort();
86+
}
87+
const int32_t FALSE = 0;
88+
const int32_t TRUE = 1;
89+
using UT = uint64_t;
90+
// using ST = int64_t;
91+
/* this all has to be changed back to TID and such.. */
92+
uint32_t tid = gtid;
93+
uint32_t nth = omp_get_num_threads();
94+
UT trip_count;
95+
96+
/* special handling for zero-trip loops */
97+
if (incr > 0 ? (*pupper < *plower) : (*plower < *pupper)) {
98+
if (plastiter != nullptr)
99+
*plastiter = FALSE;
100+
/* leave pupper and plower set to entire iteration space */
101+
*pstride = incr; /* value should never be used */
102+
return;
103+
}
104+
105+
if (nth == 1) {
106+
if (plastiter != nullptr)
107+
*plastiter = TRUE;
108+
*pstride =
109+
(incr > 0) ? (*pupper - *plower + 1) : (-(*plower - *pupper + 1));
110+
return;
111+
}
112+
113+
/* compute trip count */
114+
if (incr == 1) {
115+
trip_count = *pupper - *plower + 1;
116+
} else if (incr == -1) {
117+
trip_count = *plower - *pupper + 1;
118+
} else if (incr > 0) {
119+
// upper-lower can exceed the limit of signed type
120+
trip_count = (UT)(*pupper - *plower) / incr + 1;
121+
} else {
122+
trip_count = (UT)(*plower - *pupper) / (-incr) + 1;
123+
}
124+
if (trip_count < nth) {
125+
if (tid < trip_count) {
126+
*pupper = *plower = *plower + tid * incr;
127+
} else {
128+
// set bounds so non-active threads execute no iterations
129+
*plower = *pupper + (incr > 0 ? 1 : -1);
130+
}
131+
if (plastiter != nullptr)
132+
*plastiter = (tid == trip_count - 1);
133+
} else {
134+
UT small_chunk = trip_count / nth;
135+
UT extras = trip_count % nth;
136+
*plower += incr * (tid * small_chunk + (tid < extras ? tid : extras));
137+
*pupper = *plower + small_chunk * incr - (tid < extras ? 0 : incr);
138+
if (plastiter != nullptr)
139+
*plastiter = (tid == nth - 1);
140+
}
141+
*pstride = trip_count;
142+
}
143+
144+
// Weak replacement for the __kmpc_for_static_fini runtime entry point.
// Intentionally a no-op; both arguments are ignored.
void WEAK_SYMBOL __kmpc_for_static_fini(void *ptr, int32_t v) {}
145+
146+
// Per-thread team size requested through __kmpc_push_num_threads() for the
// next __kmpc_fork_call() issued by this thread. 0 means "no request"; the
// fork call resets it to 0 after using it.
static thread_local int next_num_threads = 0;
147+
148+
/*!
149+
@ingroup PARALLEL
150+
The type for a microtask which gets passed to @ref __kmpc_fork_call().
151+
The arguments to the outlined function are
152+
@param global_tid the global thread identity of the thread executing the
153+
function.
154+
@param bound_tid the local identity of the thread executing the function
155+
@param ... pointers to shared variables accessed by the function.
156+
*/
157+
using kmpc_micro = void (*)(int32_t *global_tid, int32_t *bound_tid, ...);
158+
void WEAK_SYMBOL __kmpc_fork_call(void *loc, int32_t argc, void *pfunc, ...) {
159+
if (unlikely(argc != 1 && argc != 0)) {
160+
std::abort();
161+
}
162+
va_list ap;
163+
va_start(ap, pfunc);
164+
void *c = va_arg(ap, void *);
165+
int32_t global_tid = 0;
166+
if (unlikely(next_num_threads)) {
167+
#pragma omp parallel num_threads(next_num_threads)
168+
{
169+
kmpc_micro func = (kmpc_micro)(pfunc);
170+
func(&global_tid, nullptr, c);
171+
}
172+
next_num_threads = 0;
173+
} else {
174+
#pragma omp parallel
175+
{
176+
kmpc_micro func = (kmpc_micro)(pfunc);
177+
func(&global_tid, nullptr, c);
178+
}
179+
}
180+
va_end(ap);
181+
}
182+
183+
// Weak replacement for the __kmpc_push_num_threads runtime entry point.
// Records `num_threads` in the calling thread's next_num_threads so that the
// next __kmpc_fork_call() on this thread uses it as the team size. `loc` and
// `global_tid` are ignored.
void WEAK_SYMBOL __kmpc_push_num_threads(void *loc, int32_t global_tid,
184+
int32_t num_threads) {
185+
next_num_threads = num_threads;
186+
}
187+
#endif
188+
}

scripts/license.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
# SPDX-License-Identifier: Apache-2.0
1616

1717
import datetime, sys, re, argparse
18-
from typing import Dict, Set
18+
from typing import Dict, Set, List
1919

2020
WIDTH: int = 80
21-
intel_license: list[str] = [
21+
intel_license: List[str] = [
2222
'Copyright \\(C\\) (\\d\\d\\d\\d-)?$YEAR Intel Corporation',
2323
'',
2424
'Licensed under the Apache License, Version 2.0 (the "License");',
@@ -35,7 +35,7 @@
3535
'SPDX-License-Identifier: Apache-2.0',
3636
]
3737

38-
llvm_license: list[str] = [
38+
llvm_license: List[str] = [
3939
"===-{1,2} $FILE - .* -*\\*- $LANG -\\*-===",
4040
'',
4141
'This file is licensed under the Apache License v2.0 with LLVM Exceptions.',
@@ -45,7 +45,7 @@
4545
"===-*===",
4646
]
4747

48-
def check_license(filepath: str, license: list[str], var: Dict[str, str], re_line: Set[int]):
48+
def check_license(filepath: str, license: List[str], var: Dict[str, str], re_line: Set[int]):
4949
with open(filepath, 'r') as f:
5050
idx: int = 0
5151
for line in f.readlines():
@@ -117,7 +117,7 @@ def use_llvm_license(path: str) -> bool:
117117
var: Dict[str, str] = {}
118118
re_line: Set[int] = set()
119119

120-
lic = list[str]
120+
lic = List[str]
121121

122122
if filepath.startswith("test/") or filepath.startswith("./test/"):
123123
continue

src/gc-cpu-runner/CMakeLists.txt

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
################################################################################
2+
# Copyright (C) 2024 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing,
11+
# software distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions
14+
# and limitations under the License.
15+
# SPDX-License-Identifier: Apache-2.0
16+
################################################################################
17+
118
if(GC_DEV_LINK_LLVM_DYLIB)
219
set(LLVM_LINK_COMPONENTS
320
LLVM
@@ -36,7 +53,8 @@ endif()
3653

3754
#LLVM_LINK_COMPONENTS is processed by LLVM cmake in add_llvm_executable
3855
set(gc_cpu_runner_libs
39-
${MLIR_LINK_COMPONENTS})
56+
${MLIR_LINK_COMPONENTS}
57+
GCCpuRuntime)
4058
add_mlir_tool(gc-cpu-runner
4159
gc-cpu-runner.cpp
4260
)

src/gc-cpu-runner/gc-cpu-runner.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
#include "llvm/Support/TargetSelect.h"
2828
#include <stdio.h>
2929

30+
extern int gc_runtime_keep_alive;
31+
3032
int main(int argc, char **argv) {
33+
// keeps GCCPURuntime linked
34+
gc_runtime_keep_alive = 0;
3135
llvm::InitLLVM y(argc, argv);
3236
llvm::InitializeNativeTarget();
3337
llvm::InitializeNativeTargetAsmPrinter();

test/gc/cpu-runner/tid.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: gc-opt %s --convert-cpuruntime-to-llvm --convert-openmp-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --reconcile-unrealized-casts | gc-cpu-runner -e main -entry-point-result=void | FileCheck %s
2+
module {
3+
func.func private @omp_get_thread_num() -> i32
4+
5+
func.func @check_parallel() {
6+
%c64 = arith.constant 64 : index
7+
%c1 = arith.constant 1 : index
8+
%c0 = arith.constant 0 : index
9+
%c8 = arith.constant 8 : index
10+
%0 = llvm.mlir.constant(1 : i64) : i64
11+
omp.parallel num_threads(%c8: index) {
12+
omp.wsloop {
13+
omp.loop_nest (%arg1, %arg2) : index = (%c0, %c0) to (%c1, %c64) step (%c1, %c1) {
14+
cpuruntime.printf "ITR %zu\n" %arg2 : index
15+
omp.yield
16+
}
17+
omp.terminator
18+
}
19+
%tid = func.call @omp_get_thread_num() : () -> i32
20+
cpuruntime.printf "EXIT %d\n" %tid : i32
21+
omp.terminator
22+
}
23+
return
24+
}
25+
26+
func.func @main() {
27+
%0 = func.call @omp_get_thread_num() : () -> i32
28+
cpuruntime.printf "TID %d\n" %0 : i32
29+
call @check_parallel() : ()->()
30+
return
31+
}
32+
// CHECK: TID 0
33+
// CHECK-COUNT-64: ITR {{[0-9]+}}
34+
// CHECK-NOT: ITR
35+
// CHECK-COUNT-8: EXIT {{[0-9]+}}
36+
// CHECK-NOT: EXIT
37+
}

0 commit comments

Comments
 (0)