Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Shared Execution Context fixed when having Multi-Threaded Java applicatons #557

Merged
merged 7 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ tornado-examples/target/
tornado-runtime/target/
tornado.iml
tornado_unittests.log

OpenCL-Headers/
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public interface TornadoDeviceContext {

boolean isFP64Supported();

boolean isCached(String methodName, SchedulableTask task);
boolean isCached(long executionPlanId, String methodName, SchedulableTask task);

int getDeviceIndex();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public class TornadoExecutionPlan implements AutoCloseable {
*/
public TornadoExecutionPlan(ImmutableTaskGraph... immutableTaskGraphs) {
this.tornadoExecutor = new TornadoExecutor(immutableTaskGraphs);
long id = globalExecutionPlanCounter.incrementAndGet();
final long id = globalExecutionPlanCounter.incrementAndGet();
executionPackage = new ExecutorFrame(id);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,25 @@ public class OCLDeviceContext implements OCLDeviceContextInterface {
/**
* Table to represent {@link uk.ac.manchester.tornado.api.TornadoExecutionPlan} -> {@link OCLCommandQueueTable}
*/
private Map<Long, OCLCommandQueueTable> commandQueueTable;

private final Map<Long, OCLCommandQueueTable> commandQueueTable;
private final OCLContext context;
private final PowerMetric powerMetric;
private final OCLMemoryManager memoryManager;
private final OCLCodeCache codeCache;
private final Map<Long, OCLEventPool> oclEventPool;
private final TornadoBufferProvider bufferProvider;
private boolean wasReset;
private Set<Long> executionIDs;
private final Set<Long> executionIDs;

/**
* Map table to represent the compiled-code per execution plan. Each entry in the execution plan has its own
* code cache. The code cache manages the compilation and the cache for each task within an execution plan.
*/
private final Map<Long, OCLCodeCache> codeCache;

public OCLDeviceContext(OCLTargetDevice device, OCLContext context) {
this.device = device;
this.context = context;
this.memoryManager = new OCLMemoryManager(this);
this.codeCache = new OCLCodeCache(this);
this.oclEventPool = new ConcurrentHashMap<>();
this.bufferProvider = new OCLBufferProvider(this);
this.commandQueueTable = new ConcurrentHashMap<>();
Expand All @@ -85,6 +88,7 @@ public OCLDeviceContext(OCLTargetDevice device, OCLContext context) {
} else {
this.powerMetric = new OCLEmptyPowerMetric();
}
codeCache = new ConcurrentHashMap<>();
}

private boolean isDeviceContextOfNvidia() {
Expand Down Expand Up @@ -523,7 +527,9 @@ public void reset(long executionPlanId) {
executionIDs.remove(executionPlanId);
}
getMemoryManager().releaseKernelStackFrame(executionPlanId);
codeCache.reset();
OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId);
oclCodeCache.reset();
codeCache.remove(executionPlanId);
wasReset = true;
}

Expand Down Expand Up @@ -588,11 +594,6 @@ public int getDevicePlatform() {
return context.getPlatformIndex();
}

public void retainEvent(long executionPlanId, int localEventId) {
OCLEventPool eventPool = getOCLEventPool(executionPlanId);
eventPool.retainEvent(localEventId);
}

@Override
public Event resolveEvent(long executionPlanId, int event) {
if (event == -1) {
Expand All @@ -609,57 +610,66 @@ public void flush(long executionPlanId) {
commandQueue.flush();
}

public void finish(long executionPlanId) {
OCLCommandQueue commandQueue = getCommandQueue(executionPlanId);
commandQueue.finish();
}

@Override
public void flushEvents(long executionPlanId) {
OCLCommandQueue commandQueue = getCommandQueue(executionPlanId);
commandQueue.flushEvents();
}

private OCLCodeCache getOCLCodeCache(long executionPlanId) {
if (!codeCache.containsKey(executionPlanId)) {
codeCache.put(executionPlanId, new OCLCodeCache(this));
}
return codeCache.get(executionPlanId);
}

@Override
public boolean isKernelAvailable() {
return codeCache.isKernelAvailable();
public boolean isKernelAvailable(long executionPlanId) {
OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId);
return oclCodeCache.isKernelAvailable();
}

public OCLInstalledCode installCode(OCLCompilationResult result) {
return installCode(result.getMeta(), result.getId(), result.getName(), result.getTargetCode());
public OCLInstalledCode installCode(long executionPlanId, OCLCompilationResult result) {
return installCode(executionPlanId, result.getMeta(), result.getId(), result.getName(), result.getTargetCode());
}

@Override
public OCLInstalledCode installCode(TaskDataContext meta, String id, String entryPoint, byte[] code) {
public OCLInstalledCode installCode(long executionPlanId, TaskDataContext meta, String id, String entryPoint, byte[] code) {
entryPoint = checkKernelName(entryPoint);
return codeCache.installSource(meta, id, entryPoint, code);
OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId);
return oclCodeCache.installSource(meta, id, entryPoint, code);
}

@Override
public OCLInstalledCode installCode(String id, String entryPoint, byte[] code, boolean printKernel) {
return codeCache.installFPGASource(id, entryPoint, code, printKernel);
public OCLInstalledCode installCode(long executionPlanId, String id, String entryPoint, byte[] code, boolean printKernel) {
OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId);
return oclCodeCache.installFPGASource(id, entryPoint, code, printKernel);
}

@Override
public boolean isCached(String id, String entryPoint) {
public boolean isCached(long executionPlanId, String id, String entryPoint) {
entryPoint = checkKernelName(entryPoint);
return codeCache.isCached(id + "-" + entryPoint);
OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId);
return oclCodeCache.isCached(id + "-" + entryPoint);
}

@Override
public boolean isCached(String methodName, SchedulableTask task) {
public boolean isCached(long executionPlanId, String methodName, SchedulableTask task) {
methodName = checkKernelName(methodName);
return codeCache.isCached(task.getId() + "-" + methodName);
OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId);
return oclCodeCache.isCached(task.getId() + "-" + methodName);
}

public OCLInstalledCode getInstalledCode(String id, String entryPoint) {
@Override
public OCLInstalledCode getInstalledCode(long executionPlanId, String id, String entryPoint) {
entryPoint = checkKernelName(entryPoint);
return codeCache.getInstalledCode(id, entryPoint);
OCLCodeCache oclCodeCache = getOCLCodeCache(executionPlanId);
return oclCodeCache.getInstalledCode(id, entryPoint);
}

@Override
public OCLCodeCache getCodeCache() {
return this.codeCache;
public OCLCodeCache getCodeCache(long executionPlanId) {
return getOCLCodeCache(executionPlanId);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@ public interface OCLDeviceContextInterface extends TornadoDeviceContext {

OCLTargetDevice getDevice();

OCLCodeCache getCodeCache();
OCLCodeCache getCodeCache(long executionPlanId);

boolean isCached(String id, String entryPoint);
boolean isCached(long executionPlanId, String id, String entryPoint);

OCLInstalledCode getInstalledCode(String id, String entryPoint);
OCLInstalledCode getInstalledCode(long executionPlanId, String id, String entryPoint);

OCLInstalledCode installCode(String id, String entryPoint, byte[] code, boolean printKernel);
OCLInstalledCode installCode(long executionPlanId, OCLCompilationResult result);

OCLInstalledCode installCode(OCLCompilationResult result);
OCLInstalledCode installCode(long executionPlanId, TaskDataContext meta, String id, String entryPoint, byte[] code);

OCLInstalledCode installCode(TaskDataContext meta, String id, String entryPoint, byte[] code);
OCLInstalledCode installCode(long executionPlanId, String id, String entryPoint, byte[] code, boolean printKernel);

boolean isKernelAvailable();
boolean isKernelAvailable(long executionPlanId);

void reset(long executionPlanId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ public static void main(String[] args) {

OCLCompilationResult result = OCLCompiler.compileCodeForDevice(resolvedMethod, new Object[] {}, meta, (OCLProviders) backend.getProviders(), backend, new EmptyProfiler());

OCLInstalledCode code = OpenCL.defaultDevice().getDeviceContext().installCode(result);
final long executionPlanId = 0;

OCLInstalledCode code = OpenCL.defaultDevice().getDeviceContext().installCode(executionPlanId, result);

for (byte b : code.getCode()) {
System.out.printf("%c", b);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,20 +238,20 @@ public XPUBuffer createOrReuseAtomicsBuffer(int[] array) {
return reuseBuffer;
}

private boolean isOpenCLPreLoadBinary(OCLDeviceContextInterface deviceContext, String deviceInfo) {
OCLCodeCache installedCode = deviceContext.getCodeCache();
private boolean isOpenCLPreLoadBinary(long executionPlanId, OCLDeviceContextInterface deviceContext, String deviceInfo) {
OCLCodeCache installedCode = deviceContext.getCodeCache(executionPlanId);
return (installedCode.isLoadBinaryOptionEnabled() && (installedCode.getOpenCLBinary(deviceInfo) != null));
}

private TornadoInstalledCode compileTask(SchedulableTask task) {
private TornadoInstalledCode compileTask(long executionPlanId, SchedulableTask task) {
final OCLDeviceContextInterface deviceContext = getDeviceContext();
final CompilableTask executable = (CompilableTask) task;
final ResolvedJavaMethod resolvedMethod = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(executable.getMethod());
final Sketch sketch = TornadoSketcher.lookup(resolvedMethod, task.meta().getBackendIndex(), task.meta().getDeviceIndex());

// Return the code from the cache
if (!task.shouldCompile() && deviceContext.isCached(task.getId(), resolvedMethod.getName())) {
return deviceContext.getInstalledCode(task.getId(), resolvedMethod.getName());
if (!task.shouldCompile() && deviceContext.isCached(executionPlanId, task.getId(), resolvedMethod.getName())) {
return deviceContext.getInstalledCode(executionPlanId, task.getId(), resolvedMethod.getName());
}

// copy meta data into task
Expand Down Expand Up @@ -289,10 +289,10 @@ private TornadoInstalledCode compileTask(SchedulableTask task) {
OCLInstalledCode installedCode;
if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) {
// A) for FPGA
installedCode = deviceContext.installCode(result.getId(), result.getName(), result.getTargetCode(), task.meta().isPrintKernelEnabled());
installedCode = deviceContext.installCode(executionPlanId, result.getId(), result.getName(), result.getTargetCode(), task.meta().isPrintKernelEnabled());
} else {
// B) for CPU multi-core or GPU
installedCode = deviceContext.installCode(result);
installedCode = deviceContext.installCode(executionPlanId, result);
}
profiler.stop(ProfilerType.TASK_COMPILE_DRIVER_TIME, taskMeta.getId());
profiler.sum(ProfilerType.TOTAL_DRIVER_COMPILE_TIME, profiler.getTaskTimer(ProfilerType.TASK_COMPILE_DRIVER_TIME, taskMeta.getId()));
Expand All @@ -310,11 +310,11 @@ private TornadoInstalledCode compileTask(SchedulableTask task) {
}
}

private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) {
private TornadoInstalledCode compilePreBuiltTask(long executionPlanId, SchedulableTask task) {
final OCLDeviceContextInterface deviceContext = getDeviceContext();
final PrebuiltTask executable = (PrebuiltTask) task;
if (deviceContext.isCached(task.getId(), executable.getEntryPoint())) {
return deviceContext.getInstalledCode(task.getId(), executable.getEntryPoint());
if (deviceContext.isCached(executionPlanId, task.getId(), executable.getEntryPoint())) {
return deviceContext.getInstalledCode(executionPlanId, task.getId(), executable.getEntryPoint());
}

final Path path = Paths.get(executable.getFilename());
Expand All @@ -325,10 +325,10 @@ private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) {
OCLInstalledCode installedCode;
if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) {
// A) for FPGA
installedCode = deviceContext.installCode(task.getId(), executable.getEntryPoint(), source, task.meta().isPrintKernelEnabled());
installedCode = deviceContext.installCode(executionPlanId, task.getId(), executable.getEntryPoint(), source, task.meta().isPrintKernelEnabled());
} else {
// B) for CPU multi-core or GPU
installedCode = deviceContext.installCode(executable.meta(), task.getId(), executable.getEntryPoint(), source);
installedCode = deviceContext.installCode(executionPlanId, executable.meta(), task.getId(), executable.getEntryPoint(), source);
}
return installedCode;
} catch (IOException e) {
Expand All @@ -337,11 +337,11 @@ private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) {
return null;
}

private TornadoInstalledCode compileJavaToAccelerator(SchedulableTask task) {
private TornadoInstalledCode compileJavaToAccelerator(long executionPlanId, SchedulableTask task) {
if (task instanceof CompilableTask) {
return compileTask(task);
return compileTask(executionPlanId, task);
} else if (task instanceof PrebuiltTask) {
return compilePreBuiltTask(task);
return compilePreBuiltTask(executionPlanId, task);
}
TornadoInternalError.shouldNotReachHere("task of unknown type: " + task.getClass().getSimpleName());
return null;
Expand All @@ -351,15 +351,15 @@ private String getTaskEntryName(SchedulableTask task) {
return task.getTaskName();
}

private TornadoInstalledCode loadPreCompiledBinaryForTask(SchedulableTask task) {
private TornadoInstalledCode loadPreCompiledBinaryForTask(long executionPlanId, SchedulableTask task) {
final OCLDeviceContextInterface deviceContext = getDeviceContext();
final OCLCodeCache codeCache = deviceContext.getCodeCache();
final OCLCodeCache codeCache = deviceContext.getCodeCache(executionPlanId);
final String deviceFullName = getFullTaskIdDevice(task);
final Path lookupPath = Paths.get(codeCache.getOpenCLBinary(deviceFullName));
String entry = getTaskEntryName(task);

if (deviceContext.getInstalledCode(task.getId(), entry) != null) {
return deviceContext.getInstalledCode(task.getId(), entry);
if (deviceContext.getInstalledCode(executionPlanId, task.getId(), entry) != null) {
return deviceContext.getInstalledCode(executionPlanId, task.getId(), entry);
} else {
return codeCache.installEntryPointForBinaryForFPGAs(task.getId(), lookupPath, entry);
}
Expand All @@ -376,16 +376,16 @@ private String getFullTaskIdDevice(SchedulableTask task) {
}

@Override
public boolean isFullJITMode(SchedulableTask task) {
public boolean isFullJITMode(long executionPlanId, SchedulableTask task) {
final OCLDeviceContextInterface deviceContext = getDeviceContext();
final String deviceFullName = getFullTaskIdDevice(task);
return (!isOpenCLPreLoadBinary(deviceContext, deviceFullName) && deviceContext.isPlatformFPGA());
return (!isOpenCLPreLoadBinary(executionPlanId, deviceContext, deviceFullName) && deviceContext.isPlatformFPGA());
}

@Override
public TornadoInstalledCode getCodeFromCache(SchedulableTask task) {
public TornadoInstalledCode getCodeFromCache(long executionPlanId, SchedulableTask task) {
String entry = getTaskEntryName(task);
return getDeviceContext().getInstalledCode(task.getId(), entry);
return getDeviceContext().getInstalledCode(executionPlanId, task.getId(), entry);
}

@Override
Expand Down Expand Up @@ -441,34 +441,34 @@ public boolean checkAtomicsParametersForTask(SchedulableTask task) {
return TornadoAtomicIntegerNode.globalAtomicsParameters.containsKey(task.meta().getCompiledResolvedJavaMethod());
}

private boolean isJITTaskForFGPA(SchedulableTask task) {
private boolean isJITTaskForFGPA(long executionPlanId, SchedulableTask task) {
final OCLDeviceContextInterface deviceContext = getDeviceContext();
final String deviceFullName = getFullTaskIdDevice(task);
return !isOpenCLPreLoadBinary(deviceContext, deviceFullName) && deviceContext.isPlatformFPGA();
return !isOpenCLPreLoadBinary(executionPlanId, deviceContext, deviceFullName) && deviceContext.isPlatformFPGA();
}

private boolean isJITTaskForGPUsAndCPUs(SchedulableTask task) {
private boolean isJITTaskForGPUsAndCPUs(long executionplanId, SchedulableTask task) {
final OCLDeviceContextInterface deviceContext = getDeviceContext();
final String deviceFullName = getFullTaskIdDevice(task);
return !isOpenCLPreLoadBinary(deviceContext, deviceFullName) && !deviceContext.isPlatformFPGA();
return !isOpenCLPreLoadBinary(executionplanId, deviceContext, deviceFullName) && !deviceContext.isPlatformFPGA();
}

private TornadoInstalledCode compileJavaForFPGAs(SchedulableTask task) {
TornadoInstalledCode tornadoInstalledCode = compileJavaToAccelerator(task);
private TornadoInstalledCode compileJavaForFPGAs(long executionPlanId, SchedulableTask task) {
TornadoInstalledCode tornadoInstalledCode = compileJavaToAccelerator(executionPlanId, task);
if (tornadoInstalledCode != null) {
return loadPreCompiledBinaryForTask(task);
return loadPreCompiledBinaryForTask(executionPlanId, task);
}
return null;
}

@Override
public TornadoInstalledCode installCode(SchedulableTask task) {
if (isJITTaskForFGPA(task)) {
return compileJavaForFPGAs(task);
} else if (isJITTaskForGPUsAndCPUs(task)) {
return compileJavaToAccelerator(task);
public TornadoInstalledCode installCode(long executionPlanId, SchedulableTask task) {
if (isJITTaskForFGPA(executionPlanId, task)) {
return compileJavaForFPGAs(executionPlanId, task);
} else if (isJITTaskForGPUsAndCPUs(executionPlanId, task)) {
return compileJavaToAccelerator(executionPlanId, task);
}
return loadPreCompiledBinaryForTask(task);
return loadPreCompiledBinaryForTask(executionPlanId, task);
}

private XPUBuffer createArrayWrapper(Class<?> type, OCLDeviceContext device, long batchSize) {
Expand Down
Loading