Skip to content

Commit

Permalink
[wip][api] New API calls to pass compiler options to the correspondin…
Browse files Browse the repository at this point in the history
…g driver/s
  • Loading branch information
jjfumero committed Aug 23, 2024
1 parent 35e7e55 commit ffddb98
Show file tree
Hide file tree
Showing 21 changed files with 262 additions and 287 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.enums.ProfilerMode;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.runtime.ExecutorFrame;

/**
Expand Down Expand Up @@ -197,6 +198,10 @@ void withoutPrintKernel() {
taskGraph.withoutPrintKernel();
}

void withCompilerFlags(TornadoVMBackendType backendType, String compilerFlags) {
taskGraph.withCompilerFlags(backendType, compilerFlags);
}

void withGridScheduler(GridScheduler gridScheduler) {
taskGraph.withGridScheduler(gridScheduler);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import uk.ac.manchester.tornado.api.common.TornadoFunctions.Task8;
import uk.ac.manchester.tornado.api.common.TornadoFunctions.Task9;
import uk.ac.manchester.tornado.api.enums.ProfilerMode;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoTaskRuntimeException;
import uk.ac.manchester.tornado.api.runtime.ExecutorFrame;
import uk.ac.manchester.tornado.api.runtime.TornadoAPIProvider;
Expand Down Expand Up @@ -889,6 +890,10 @@ void withoutPrintKernel() {
taskGraphImpl.withoutPrintKernel();
}

void withCompilerFlags(TornadoVMBackendType backendType, String compilerFlags) {
taskGraphImpl.withCompilerFlags(backendType, compilerFlags);
}

void withGridScheduler(GridScheduler gridScheduler) {
taskGraphImpl.withGridScheduler(gridScheduler);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.enums.ProfilerMode;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.api.runtime.ExecutorFrame;
Expand Down Expand Up @@ -382,6 +383,11 @@ public TornadoExecutionPlan withoutPrintKernel() {
return this;
}

public TornadoExecutionPlan withCompilerFlags(TornadoVMBackendType backend, String compilerFlags) {
tornadoExecutor.withCompilerFlags(backend, compilerFlags);
return this;
}

@Override
public void close() throws TornadoExecutionPlanException {
tornadoExecutor.freeDeviceMemory();
Expand Down Expand Up @@ -579,6 +585,10 @@ void withoutPrintKernel() {
immutableTaskGraphList.forEach(ImmutableTaskGraph::withoutPrintKernel);
}

void withCompilerFlags(TornadoVMBackendType backendType, String compilerFlags) {
immutableTaskGraphList.forEach(immutableTaskGraph -> immutableTaskGraph.withCompilerFlags(backendType, compilerFlags));
}

long getTotalBytesTransferred() {
return immutableTaskGraphList.stream().mapToLong(ImmutableTaskGraph::getTotalBytesTransferred).sum();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import uk.ac.manchester.tornado.api.common.TaskPackage;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.enums.ProfilerMode;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.memory.TaskMetaDataInterface;
import uk.ac.manchester.tornado.api.profiler.ProfilerInterface;
import uk.ac.manchester.tornado.api.runtime.ExecutorFrame;
Expand Down Expand Up @@ -122,4 +123,6 @@ public interface TornadoTaskGraphInterface extends ProfilerInterface {
void withGridScheduler(GridScheduler gridScheduler);

long getCurrentDeviceMemoryUsage();

void withCompilerFlags(TornadoVMBackendType backendType, String compilerFlags);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2013-2023, APT Group, Department of Computer Science,
* Copyright (c) 2013-2024, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -25,7 +25,7 @@ public enum TornadoVMBackendType {
JAVA("Java"), //
VIRTUAL("Virtual");

String backendName;
final String backendName;

TornadoVMBackendType(String backendName) {
this.backendName = backendName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@

import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.common.TornadoEvents;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;

public interface TaskMetaDataInterface {

List<TornadoEvents> getProfiles(long executionPlanId);

String getCompilerFlags();

void setCompilerFlags(String flags);
void setCompilerFlags(TornadoVMBackendType backendType, String flags);

void setGlobalWork(long[] global);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
import java.util.StringTokenizer;
import java.util.concurrent.ConcurrentHashMap;

import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException;
import uk.ac.manchester.tornado.api.exceptions.TornadoCompilationException;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.drivers.opencl.enums.OCLBuildStatus;
import uk.ac.manchester.tornado.drivers.opencl.enums.OCLDeviceType;
Expand Down Expand Up @@ -74,6 +76,7 @@ public class OCLCodeCache {
private final String FPGA_CONFIGURATION_FILE = getProperty("tornado.fpga.conf.file", null);
private static final String FPGA_CLEANUP_SCRIPT = System.getenv("TORNADO_SDK") + "/bin/cleanFpga.sh";
private static final String FPGA_AWS_AFI_SCRIPT = System.getenv("TORNADO_SDK") + "/bin/aws_post_processing.sh";

/**
* OpenCL Binary Options: -Dtornado.precompiled.binary=<path/to/binary,task>
*
Expand Down Expand Up @@ -181,17 +184,13 @@ private void parseFPGAConfigurationFile() {
}

switch (token) {
case "DEVICE_NAME":
fpgaName = tokenizer.nextToken(" =");
break;
case "COMPILER":
fpgaCompiler = tokenizer.nextToken(" =");
break;
case "DIRECTORY_BITSTREAM":
case "DEVICE_NAME" -> fpgaName = tokenizer.nextToken(" =");
case "COMPILER" -> fpgaCompiler = tokenizer.nextToken(" =");
case "DIRECTORY_BITSTREAM" -> {
directoryBitstream = resolveAbsoluteDirectory(tokenizer.nextToken(" ="));
fpgaSourceDir = directoryBitstream;
break;
case "FLAGS":
}
case "FLAGS" -> {
StringBuilder buildFlags = new StringBuilder();

// Iterate over tokens that correspond to multiple flags
Expand All @@ -210,12 +209,10 @@ private void parseFPGAConfigurationFile() {
}
}
}
break;
case "AWS_ENV":
isFPGAInAWS = tokenizer.nextToken(" =").toLowerCase().equals("yes");
break;
default:
break;
}
case "AWS_ENV" -> isFPGAInAWS = tokenizer.nextToken(" =").toLowerCase().equals("yes");
default -> {
}
}
break;
}
Expand Down Expand Up @@ -281,12 +278,12 @@ private String[] processPrecompiledBinariesFromFile(String fileName) {
} catch (FileNotFoundException e) {
throw new RuntimeException("File: " + fileName + " not found");
} catch (IOException e) {
e.printStackTrace();
throw new TornadoCompilationException(e.getMessage());
} finally {
try {
fileContent.close();
} catch (IOException e) {
e.printStackTrace();
throw new TornadoCompilationException(e.getMessage());
}
}
return listBinaries.toString().split(",");
Expand Down Expand Up @@ -565,10 +562,8 @@ private void dumpKernelSource(String id, String entryPoint, String log, byte[] s

}

private void installCodeInCodeCache(OCLProgram program, TaskMetaData meta, String id, String entryPoint, OCLInstalledCode code) {

private void installCodeInCodeCache(OCLProgram program, String id, String entryPoint, OCLInstalledCode code) {
cache.put(id + "-" + entryPoint, code);

// BUG Apple does not seem to like implementing the OpenCL spec
// properly, this causes a SIGFAULT.
if ((OPENCL_CACHE_ENABLE || OPENCL_DUMP_BINS) && !deviceContext.getPlatformContext().getPlatform().getVendor().equalsIgnoreCase("Apple")) {
Expand Down Expand Up @@ -607,10 +602,7 @@ public OCLInstalledCode installSource(TaskMetaData meta, String id, String entry
RuntimeUtilities.dumpKernel(source);
}

final long t0 = System.nanoTime();
program.build(meta.getCompilerFlags());
final long t1 = System.nanoTime();

program.build(meta.getCompilerFlags(TornadoVMBackendType.OPENCL));
final OCLBuildStatus status = program.getStatus(deviceContext.getDeviceId());
logger.debug("\tOpenCL compilation status = %s", status.toString());

Expand All @@ -630,7 +622,7 @@ public OCLInstalledCode installSource(TaskMetaData meta, String id, String entry
final OCLInstalledCode code = new OCLInstalledCode(entryPoint, source, (OCLDeviceContext) deviceContext, program, kernel, isSPIRVBinary);
if (status == CL_BUILD_SUCCESS) {
logger.debug("\tOpenCL Kernel id = 0x%x", kernel.getOclKernelID());
installCodeInCodeCache(program, meta, id, entryPoint, code);
installCodeInCodeCache(program, id, entryPoint, code);
} else {
logger.warn("\tunable to compile %s", entryPoint);
code.invalidate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ private void setKernelArgs(final OCLKernelStackFrame kernelArgs, final XPUBuffer
}
}

private void printDebugLaunchInfo(final TaskMetaData meta) {
System.out.println("Running on: ");
System.out.println("\tPlatform: " + meta.getXPUDevice().getPlatformName());
if (meta.getXPUDevice() instanceof OCLTornadoDevice) {
System.out.println("\tDevice : " + ((OCLTornadoDevice) meta.getXPUDevice()).getPhysicalDevice().getDeviceName());
}
}

public int submitWithEvents(long executionPlanId, final OCLKernelStackFrame kernelArgs, final XPUBuffer atomicSpace, final TaskMetaData meta, final int[] events, long batchThreads) {
guarantee(kernel != null, "kernel is null");

Expand Down Expand Up @@ -229,25 +237,15 @@ public int submitWithEvents(long executionPlanId, final OCLKernelStackFrame kern
}
} else {
if (meta.isDebug()) {
System.out.println("Running on: ");
System.out.println("\tPlatform: " + meta.getLogicDevice().getPlatformName());
if (meta.getLogicDevice() instanceof OCLTornadoDevice) {
System.out.println("\tDevice : " + ((OCLTornadoDevice) meta.getLogicDevice()).getPhysicalDevice().getDeviceName());
}
printDebugLaunchInfo(meta);
}
if (meta.getGlobalWork() == null) {
task = deviceContext.enqueueNDRangeKernel(executionPlanId, kernel, 1, null, singleThreadGlobalWorkSize, singleThreadLocalWorkSize, waitEvents);
} else {
task = deviceContext.enqueueNDRangeKernel(executionPlanId, kernel, 1, null, meta.getGlobalWork(), meta.getLocalWork(), waitEvents);
}
}

if (meta.shouldDumpProfiles()) {
deviceContext.retainEvent(executionPlanId, task);
meta.addProfile(task);
}
}

return task;
}

Expand Down Expand Up @@ -296,16 +294,10 @@ private int submitParallel(long executionPlanId, final TaskMetaData meta, long b
}

private void launchKernel(long executionPlanId, final OCLKernelStackFrame callWrapper, final TaskMetaData meta, long batchThreads) {
final int task;
if (meta.isParallel() || meta.isWorkerGridAvailable()) {
task = submitParallel(executionPlanId, meta, batchThreads);
submitParallel(executionPlanId, meta, batchThreads);
} else {
task = submitSequential(executionPlanId, meta);
}

if (meta.shouldDumpProfiles()) {
deviceContext.retainEvent(executionPlanId, task);
meta.addProfile(task);
submitSequential(executionPlanId, meta);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public int[] calculateBlockDimension(PTXModule module, TaskMetaData taskMeta) {
return Arrays.stream(taskMeta.getLocalWork()).mapToInt(l -> (int) l).toArray();
}

long maxThreadsPerBlock = taskMeta.getLogicDevice().getPhysicalDevice().getMaxThreadsPerBlock();
long maxThreadsPerBlock = taskMeta.getXPUDevice().getPhysicalDevice().getMaxThreadsPerBlock();
if (taskMeta.getDims() > 1) {
maxThreadsPerBlock = module.getPotentialBlockSizeMaxOccupancy();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler;
import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions;
import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVFileReader;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException;
import uk.ac.manchester.tornado.drivers.common.logging.Logger;
import uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVInstalledCode;
Expand Down Expand Up @@ -60,7 +61,7 @@ public synchronized SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, Str
ZeModuleDescriptor moduleDesc = new ZeModuleDescriptor();
ZeBuildLogHandle buildLog = new ZeBuildLogHandle();
moduleDesc.setFormat(ZeModuleFormat.ZE_MODULE_FORMAT_IL_SPIRV);
moduleDesc.setBuildFlags("-ze-opt-level 2 -ze-opt-large-register-file");
moduleDesc.setBuildFlags(meta.getCompilerFlags(TornadoVMBackendType.SPIRV));

checkBinaryFileExists(pathToFile);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler;
import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions;
import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVFileReader;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.drivers.opencl.OCLErrorCode;
Expand Down Expand Up @@ -99,7 +100,8 @@ public SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, String id, Strin
}

OCLTargetDevice oclDevice = (OCLTargetDevice) deviceContext.getDevice().getDeviceRuntime();
int status = spirvoclNativeCompiler.clBuildProgram(programPointer, 1, new long[] { oclDevice.getDevicePointer() }, "");
String compilerFlags = meta.getCompilerFlags(TornadoVMBackendType.OPENCL);
int status = spirvoclNativeCompiler.clBuildProgram(programPointer, 1, new long[] { oclDevice.getDevicePointer() }, compilerFlags);
if (status != OCLErrorCode.CL_SUCCESS) {
String log = spirvoclNativeCompiler.clGetProgramBuildInfo(programPointer, oclDevice.getDevicePointer());
System.out.println(log);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public TaskMetaData getMeta() {
}

public TornadoXPUDevice getDeviceMapping() {
return meta.getLogicDevice();
return meta.getXPUDevice();
}

public boolean hasMeta() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ public Deque<Integer> getActiveDeviceIndexes() {

/**
* It retrieves a list of tasks for a specific device and driver. Both
* deviceContext and driverIndex are checked to ensure the correct task
* deviceContext and backendIndex are checked to ensure the correct task
* assignment.
*
* @param deviceContext
Expand Down Expand Up @@ -493,7 +493,7 @@ public List<SchedulableTask> getTasksForDevice(TornadoDeviceContext deviceContex
*/
@Deprecated
public TornadoXPUDevice getDefaultDevice() {
return meta.getLogicDevice();
return meta.getXPUDevice();
}

public SchedulableTask getTask(String id) {
Expand Down Expand Up @@ -534,13 +534,13 @@ public void sync() {
Object object = objects.get(i);
if (object != null) {
final LocalObjectState localState = objectState.get(i);
Event event = localState.sync(executionPlanId, object, meta().getLogicDevice());
Event event = localState.sync(executionPlanId, object, meta().getXPUDevice());

if (TornadoOptions.isProfilerEnabled() && event != null) {
long value = profiler.getTimer(ProfilerType.COPY_OUT_TIME_SYNC);
value += event.getElapsedTime();
profiler.setTimer(ProfilerType.COPY_OUT_TIME_SYNC, value);
XPUDeviceBufferState deviceObjectState = localState.getDataObjectState().getDeviceBufferState(meta().getLogicDevice());
XPUDeviceBufferState deviceObjectState = localState.getDataObjectState().getDeviceBufferState(meta().getXPUDevice());
profiler.addValueToMetric(ProfilerType.COPY_OUT_SIZE_BYTES_SYNC, TimeProfiler.NO_TASK_NAME, deviceObjectState.getXPUBuffer().size());
}
}
Expand Down
Loading

0 comments on commit ffddb98

Please sign in to comment.