From ffddb981441189b0124ab517b2c8f4e6d50b105b Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Fri, 23 Aug 2024 12:21:43 +0200 Subject: [PATCH] [wip][api] New API calls to pass compiler options to the corresponding driver/s --- .../tornado/api/ImmutableTaskGraph.java | 5 + .../ac/manchester/tornado/api/TaskGraph.java | 5 + .../tornado/api/TornadoExecutionPlan.java | 10 + .../api/TornadoTaskGraphInterface.java | 3 + .../api/enums/TornadoVMBackendType.java | 4 +- .../api/memory/TaskMetaDataInterface.java | 5 +- .../tornado/drivers/opencl/OCLCodeCache.java | 42 ++- .../opencl/graal/OCLInstalledCode.java | 30 +- .../tornado/drivers/ptx/PTXScheduler.java | 2 +- .../spirv/SPIRVLevelZeroCodeCache.java | 3 +- .../drivers/spirv/SPIRVOCLCodeCache.java | 4 +- .../graal/phases/TornadoHighTierContext.java | 2 +- .../graph/TornadoExecutionContext.java | 8 +- .../interpreter/TornadoVMInterpreter.java | 21 +- .../tornado/runtime/tasks/CompilableTask.java | 2 +- .../tornado/runtime/tasks/PrebuiltTask.java | 2 +- .../runtime/tasks/ReduceTaskGraph.java | 34 +-- .../runtime/tasks/TornadoTaskGraph.java | 259 +++++++++--------- .../runtime/tasks/meta/AbstractMetaData.java | 77 ++---- .../runtime/tasks/meta/MetaDataUtils.java | 7 +- .../runtime/tasks/meta/TaskMetaData.java | 24 +- 21 files changed, 262 insertions(+), 287 deletions(-) diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/ImmutableTaskGraph.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/ImmutableTaskGraph.java index 72562b0a70..875d4699f1 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/ImmutableTaskGraph.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/ImmutableTaskGraph.java @@ -21,6 +21,7 @@ import uk.ac.manchester.tornado.api.common.TornadoDevice; import uk.ac.manchester.tornado.api.enums.ProfilerMode; +import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType; import uk.ac.manchester.tornado.api.runtime.ExecutorFrame; /** @@ -197,6 +198,10 @@ void withoutPrintKernel() { taskGraph.withoutPrintKernel(); } + void withCompilerFlags(TornadoVMBackendType backendType, String compilerFlags) { + taskGraph.withCompilerFlags(backendType, compilerFlags); + } + void withGridScheduler(GridScheduler gridScheduler) { taskGraph.withGridScheduler(gridScheduler); } diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java index 95bc98af58..34bc553085 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java @@ -41,6 +41,7 @@ import uk.ac.manchester.tornado.api.common.TornadoFunctions.Task8; import uk.ac.manchester.tornado.api.common.TornadoFunctions.Task9; import uk.ac.manchester.tornado.api.enums.ProfilerMode; +import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType; import uk.ac.manchester.tornado.api.exceptions.TornadoTaskRuntimeException; import uk.ac.manchester.tornado.api.runtime.ExecutorFrame; import uk.ac.manchester.tornado.api.runtime.TornadoAPIProvider; @@ -889,6 +890,10 @@ void withoutPrintKernel() { taskGraphImpl.withoutPrintKernel(); } + void withCompilerFlags(TornadoVMBackendType backendType, String compilerFlags) { + taskGraphImpl.withCompilerFlags(backendType, compilerFlags); + } + void withGridScheduler(GridScheduler gridScheduler) { taskGraphImpl.withGridScheduler(gridScheduler); } diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoExecutionPlan.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoExecutionPlan.java index 25f8e811d4..32c0b1f8f2 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoExecutionPlan.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoExecutionPlan.java @@ -25,6 +25,7 @@ import uk.ac.manchester.tornado.api.common.TornadoDevice; import uk.ac.manchester.tornado.api.enums.ProfilerMode; +import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType; import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException; import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; import uk.ac.manchester.tornado.api.runtime.ExecutorFrame; @@ -382,6 +383,11 @@ public TornadoExecutionPlan withoutPrintKernel() { return this; } + public TornadoExecutionPlan withCompilerFlags(TornadoVMBackendType backend, String compilerFlags) { + tornadoExecutor.withCompilerFlags(backend, compilerFlags); + return this; + } + @Override public void close() throws TornadoExecutionPlanException { tornadoExecutor.freeDeviceMemory(); @@ -579,6 +585,10 @@ void withoutPrintKernel() { immutableTaskGraphList.forEach(ImmutableTaskGraph::withoutPrintKernel); } + void withCompilerFlags(TornadoVMBackendType backendType, String compilerFlags) { + immutableTaskGraphList.forEach(immutableTaskGraph -> immutableTaskGraph.withCompilerFlags(backendType, compilerFlags)); + } + long getTotalBytesTransferred() { return immutableTaskGraphList.stream().mapToLong(ImmutableTaskGraph::getTotalBytesTransferred).sum(); } diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoTaskGraphInterface.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoTaskGraphInterface.java index edb508d628..5851f9b6f7 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoTaskGraphInterface.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TornadoTaskGraphInterface.java @@ -25,6 +25,7 @@ import uk.ac.manchester.tornado.api.common.TaskPackage; import uk.ac.manchester.tornado.api.common.TornadoDevice; import uk.ac.manchester.tornado.api.enums.ProfilerMode; +import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType; import uk.ac.manchester.tornado.api.memory.TaskMetaDataInterface; import uk.ac.manchester.tornado.api.profiler.ProfilerInterface; import uk.ac.manchester.tornado.api.runtime.ExecutorFrame; @@ -122,4 +123,6 @@ public interface TornadoTaskGraphInterface extends ProfilerInterface { void withGridScheduler(GridScheduler gridScheduler); long getCurrentDeviceMemoryUsage(); + + void withCompilerFlags(TornadoVMBackendType backendType, String compilerFlags); } diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/enums/TornadoVMBackendType.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/enums/TornadoVMBackendType.java index 402e7bcd40..dec8e7bf2b 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/enums/TornadoVMBackendType.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/enums/TornadoVMBackendType.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013-2023, APT Group, Department of Computer Science, + * Copyright (c) 2013-2024, APT Group, Department of Computer Science, * The University of Manchester. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +25,7 @@ public enum TornadoVMBackendType { JAVA("Java"), // VIRTUAL("Virtual"); - String backendName; + final String backendName; TornadoVMBackendType(String backendName) { this.backendName = backendName; diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/memory/TaskMetaDataInterface.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/memory/TaskMetaDataInterface.java index 368b33ab93..619d32a5a0 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/memory/TaskMetaDataInterface.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/memory/TaskMetaDataInterface.java @@ -21,14 +21,13 @@ import uk.ac.manchester.tornado.api.common.TornadoDevice; import uk.ac.manchester.tornado.api.common.TornadoEvents; +import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType; public interface TaskMetaDataInterface { List getProfiles(long executionPlanId); - String getCompilerFlags(); - - void setCompilerFlags(String flags); + void setCompilerFlags(TornadoVMBackendType backendType, String flags); void setGlobalWork(long[] global); diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLCodeCache.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLCodeCache.java index a0c41e8e8b..ef3abcdf4d 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLCodeCache.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/OCLCodeCache.java @@ -44,7 +44,9 @@ import java.util.StringTokenizer; import java.util.concurrent.ConcurrentHashMap; +import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType; import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException; +import uk.ac.manchester.tornado.api.exceptions.TornadoCompilationException; import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; import uk.ac.manchester.tornado.drivers.opencl.enums.OCLBuildStatus; import uk.ac.manchester.tornado.drivers.opencl.enums.OCLDeviceType; @@ -74,6 +76,7 @@ public class OCLCodeCache { private final String FPGA_CONFIGURATION_FILE = getProperty("tornado.fpga.conf.file", null); private static final String FPGA_CLEANUP_SCRIPT = System.getenv("TORNADO_SDK") + "/bin/cleanFpga.sh"; private static final String FPGA_AWS_AFI_SCRIPT = System.getenv("TORNADO_SDK") + "/bin/aws_post_processing.sh"; + /** * OpenCL Binary Options: -Dtornado.precompiled.binary= * @@ -181,17 +184,13 @@ private void parseFPGAConfigurationFile() { } switch (token) { - case "DEVICE_NAME": - fpgaName = tokenizer.nextToken(" ="); - break; - case "COMPILER": - fpgaCompiler = tokenizer.nextToken(" ="); - break; - case "DIRECTORY_BITSTREAM": + case "DEVICE_NAME" -> fpgaName = tokenizer.nextToken(" ="); + case "COMPILER" -> fpgaCompiler = tokenizer.nextToken(" ="); + case "DIRECTORY_BITSTREAM" -> { directoryBitstream = resolveAbsoluteDirectory(tokenizer.nextToken(" =")); fpgaSourceDir = directoryBitstream; - break; - case "FLAGS": + } + case "FLAGS" -> { StringBuilder buildFlags = new StringBuilder(); // Iterate over tokens that correspond to multiple flags @@ -210,12 +209,10 @@ private void parseFPGAConfigurationFile() { } } } - break; - case "AWS_ENV": - isFPGAInAWS = tokenizer.nextToken(" =").toLowerCase().equals("yes"); - break; - default: - break; + } + case "AWS_ENV" -> isFPGAInAWS = tokenizer.nextToken(" =").toLowerCase().equals("yes"); + default -> { + } } break; } @@ -281,12 +278,12 @@ private String[] processPrecompiledBinariesFromFile(String fileName) { } catch (FileNotFoundException e) { throw new RuntimeException("File: " + fileName + " not found"); } catch (IOException e) { - e.printStackTrace(); + throw new TornadoCompilationException(e.getMessage()); } finally { try { fileContent.close(); } catch (IOException e) { - e.printStackTrace(); + throw new TornadoCompilationException(e.getMessage()); } } return listBinaries.toString().split(","); @@ -565,10 +562,8 @@ private void dumpKernelSource(String id, String entryPoint, String log, byte[] s } - private void installCodeInCodeCache(OCLProgram program, TaskMetaData meta, String id, String entryPoint, OCLInstalledCode code) { - + private void installCodeInCodeCache(OCLProgram program, String id, String entryPoint, OCLInstalledCode code) { cache.put(id + "-" + entryPoint, code); - // BUG Apple does not seem to like implementing the OpenCL spec // properly, this causes a SIGFAULT. if ((OPENCL_CACHE_ENABLE || OPENCL_DUMP_BINS) && !deviceContext.getPlatformContext().getPlatform().getVendor().equalsIgnoreCase("Apple")) { @@ -607,10 +602,7 @@ public OCLInstalledCode installSource(TaskMetaData meta, String id, String entry RuntimeUtilities.dumpKernel(source); } - final long t0 = System.nanoTime(); - program.build(meta.getCompilerFlags()); - final long t1 = System.nanoTime(); - + program.build(meta.getCompilerFlags(TornadoVMBackendType.OPENCL)); final OCLBuildStatus status = program.getStatus(deviceContext.getDeviceId()); logger.debug("\tOpenCL compilation status = %s", status.toString()); @@ -630,7 +622,7 @@ public OCLInstalledCode installSource(TaskMetaData meta, String id, String entry final OCLInstalledCode code = new OCLInstalledCode(entryPoint, source, (OCLDeviceContext) deviceContext, program, kernel, isSPIRVBinary); if (status == CL_BUILD_SUCCESS) { logger.debug("\tOpenCL Kernel id = 0x%x", kernel.getOclKernelID()); - installCodeInCodeCache(program, meta, id, entryPoint, code); + installCodeInCodeCache(program, id, entryPoint, code); } else { logger.warn("\tunable to compile %s", entryPoint); code.invalidate(); diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLInstalledCode.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLInstalledCode.java index 21e61d4576..4248e6d2fb 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLInstalledCode.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/OCLInstalledCode.java @@ -201,6 +201,14 @@ private void setKernelArgs(final OCLKernelStackFrame kernelArgs, final XPUBuffer } } + private void printDebugLaunchInfo(final TaskMetaData meta) { + System.out.println("Running on: "); + System.out.println("\tPlatform: " + meta.getXPUDevice().getPlatformName()); + if (meta.getXPUDevice() instanceof OCLTornadoDevice) { + System.out.println("\tDevice : " + ((OCLTornadoDevice) meta.getXPUDevice()).getPhysicalDevice().getDeviceName()); + } + } + public int submitWithEvents(long executionPlanId, final OCLKernelStackFrame kernelArgs, final XPUBuffer atomicSpace, final TaskMetaData meta, final int[] events, long batchThreads) { guarantee(kernel != null, "kernel is null"); @@ -229,11 +237,7 @@ public int submitWithEvents(long executionPlanId, final OCLKernelStackFrame kern } } else { if (meta.isDebug()) { - System.out.println("Running on: "); - System.out.println("\tPlatform: " + meta.getLogicDevice().getPlatformName()); - if (meta.getLogicDevice() instanceof OCLTornadoDevice) { - System.out.println("\tDevice : " + ((OCLTornadoDevice) meta.getLogicDevice()).getPhysicalDevice().getDeviceName()); - } + printDebugLaunchInfo(meta); } if (meta.getGlobalWork() == null) { task = deviceContext.enqueueNDRangeKernel(executionPlanId, kernel, 1, null, singleThreadGlobalWorkSize, singleThreadLocalWorkSize, waitEvents); @@ -241,13 +245,7 @@ public int submitWithEvents(long executionPlanId, final OCLKernelStackFrame kern task = deviceContext.enqueueNDRangeKernel(executionPlanId, kernel, 1, null, meta.getGlobalWork(), meta.getLocalWork(), waitEvents); } } - - if (meta.shouldDumpProfiles()) { - deviceContext.retainEvent(executionPlanId, task); - meta.addProfile(task); - } } - return task; } @@ -296,16 +294,10 @@ private int submitParallel(long executionPlanId, final TaskMetaData meta, long b } private void launchKernel(long executionPlanId, final OCLKernelStackFrame callWrapper, final TaskMetaData meta, long batchThreads) { - final int task; if (meta.isParallel() || meta.isWorkerGridAvailable()) { - task = submitParallel(executionPlanId, meta, batchThreads); + submitParallel(executionPlanId, meta, batchThreads); } else { - task = submitSequential(executionPlanId, meta); - } - - if (meta.shouldDumpProfiles()) { - deviceContext.retainEvent(executionPlanId, task); - meta.addProfile(task); + submitSequential(executionPlanId, meta); } } diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXScheduler.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXScheduler.java index f96276bd9c..b0d1a17562 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXScheduler.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXScheduler.java @@ -60,7 +60,7 @@ public int[] calculateBlockDimension(PTXModule module, TaskMetaData taskMeta) { return Arrays.stream(taskMeta.getLocalWork()).mapToInt(l -> (int) l).toArray(); } - long maxThreadsPerBlock = taskMeta.getLogicDevice().getPhysicalDevice().getMaxThreadsPerBlock(); + long maxThreadsPerBlock = taskMeta.getXPUDevice().getPhysicalDevice().getMaxThreadsPerBlock(); if (taskMeta.getDims() > 1) { maxThreadsPerBlock = module.getPotentialBlockSizeMaxOccupancy(); } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCodeCache.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCodeCache.java index 1363263cc1..315c164df3 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCodeCache.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCodeCache.java @@ -29,6 +29,7 @@ import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler; import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions; import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVFileReader; +import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType; import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException; import uk.ac.manchester.tornado.drivers.common.logging.Logger; import uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVInstalledCode; @@ -60,7 +61,7 @@ public synchronized SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, Str ZeModuleDescriptor moduleDesc = new ZeModuleDescriptor(); ZeBuildLogHandle buildLog = new ZeBuildLogHandle(); moduleDesc.setFormat(ZeModuleFormat.ZE_MODULE_FORMAT_IL_SPIRV); - moduleDesc.setBuildFlags("-ze-opt-level 2 -ze-opt-large-register-file"); + moduleDesc.setBuildFlags(meta.getCompilerFlags(TornadoVMBackendType.SPIRV)); checkBinaryFileExists(pathToFile); diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCodeCache.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCodeCache.java index f75ef402e2..1a3ca77c70 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCodeCache.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCodeCache.java @@ -32,6 +32,7 @@ import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler; import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions; import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVFileReader; +import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType; import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException; import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; import uk.ac.manchester.tornado.drivers.opencl.OCLErrorCode; @@ -99,7 +100,8 @@ public SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, String id, Strin } OCLTargetDevice oclDevice = (OCLTargetDevice) deviceContext.getDevice().getDeviceRuntime(); - int status = spirvoclNativeCompiler.clBuildProgram(programPointer, 1, new long[] { oclDevice.getDevicePointer() }, ""); + String compilerFlags = meta.getCompilerFlags(TornadoVMBackendType.OPENCL); + int status = spirvoclNativeCompiler.clBuildProgram(programPointer, 1, new long[] { oclDevice.getDevicePointer() }, compilerFlags); if (status != OCLErrorCode.CL_SUCCESS) { String log = spirvoclNativeCompiler.clGetProgramBuildInfo(programPointer, oclDevice.getDevicePointer()); System.out.println(log); diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHighTierContext.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHighTierContext.java index d430dd98d6..3cd2c78b17 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHighTierContext.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graal/phases/TornadoHighTierContext.java @@ -74,7 +74,7 @@ public TaskMetaData getMeta() { } public TornadoXPUDevice getDeviceMapping() { - return meta.getLogicDevice(); + return meta.getXPUDevice(); } public boolean hasMeta() { diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graph/TornadoExecutionContext.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graph/TornadoExecutionContext.java index 2c2777abdd..f874d7805e 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graph/TornadoExecutionContext.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/graph/TornadoExecutionContext.java @@ -465,7 +465,7 @@ public Deque getActiveDeviceIndexes() { /** * It retrieves a list of tasks for a specific device and driver. Both - * deviceContext and driverIndex are checked to ensure the correct task + * deviceContext and backendIndex are checked to ensure the correct task * assignment. * * @param deviceContext @@ -493,7 +493,7 @@ public List getTasksForDevice(TornadoDeviceContext deviceContex */ @Deprecated public TornadoXPUDevice getDefaultDevice() { - return meta.getLogicDevice(); + return meta.getXPUDevice(); } public SchedulableTask getTask(String id) { @@ -534,13 +534,13 @@ public void sync() { Object object = objects.get(i); if (object != null) { final LocalObjectState localState = objectState.get(i); - Event event = localState.sync(executionPlanId, object, meta().getLogicDevice()); + Event event = localState.sync(executionPlanId, object, meta().getXPUDevice()); if (TornadoOptions.isProfilerEnabled() && event != null) { long value = profiler.getTimer(ProfilerType.COPY_OUT_TIME_SYNC); value += event.getElapsedTime(); profiler.setTimer(ProfilerType.COPY_OUT_TIME_SYNC, value); - XPUDeviceBufferState deviceObjectState = localState.getDataObjectState().getDeviceBufferState(meta().getLogicDevice()); + XPUDeviceBufferState deviceObjectState = localState.getDataObjectState().getDeviceBufferState(meta().getXPUDevice()); profiler.addValueToMetric(ProfilerType.COPY_OUT_SIZE_BYTES_SYNC, TimeProfiler.NO_TASK_NAME, deviceObjectState.getXPUBuffer().size()); } } diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java index c4eaa27f05..fc3100c139 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/interpreter/TornadoVMInterpreter.java @@ -210,11 +210,6 @@ public void dumpEvents() { } public void dumpProfiles() { - if (!executionContext.meta().shouldDumpProfiles()) { - logger.info("profiling is not enabled"); - return; - } - for (final SchedulableTask task : tasks) { final TaskMetaData meta = (TaskMetaData) task.meta(); for (final TornadoEvents eventSet : meta.getProfiles(executionContext.getExecutionPlanId())) { @@ -409,7 +404,7 @@ private int executeAlloc(StringBuilder tornadoVMBytecodeList, int[] args, long s } } - long allocationsTotalSize = interpreterDevice.allocateObjects(objects, sizeBatch, objectStates); + long allocationsTotalSize = interpreterDevice.allocateObjects(objects, sizeBatch, objectStates); executionContext.setCurrentDeviceMemoryUsage(allocationsTotalSize); @@ -419,7 +414,7 @@ private int executeAlloc(StringBuilder tornadoVMBytecodeList, int[] args, long s timeProfiler.addValueToMetric(ProfilerType.ALLOCATION_BYTES, TimeProfiler.NO_TASK_NAME, objectState.getXPUBuffer().size()); } } - + return -1; } @@ -433,7 +428,7 @@ private int executeDeAlloc(StringBuilder tornadoVMBytecodeList, final int object } final XPUDeviceBufferState objectState = resolveObjectState(objectIndex); - long spaceDeallocated = interpreterDevice.deallocate(objectState); + long spaceDeallocated = interpreterDevice.deallocate(objectState); // Update current device area use executionContext.setCurrentDeviceMemoryUsage(executionContext.getCurrentDeviceMemoryUsage() - spaceDeallocated); return -1; @@ -880,8 +875,8 @@ private int globalToLocalTaskIndex(int taskIndex) { private void profilerUpdateForPreCompiledTask(SchedulableTask task) { if (task instanceof PrebuiltTask prebuiltTask && timeProfiler instanceof TimeProfiler) { - timeProfiler.registerDeviceID(task.getId(), prebuiltTask.meta().getLogicDevice().getDriverIndex() + ":" + prebuiltTask.meta().getDeviceIndex()); - timeProfiler.registerDeviceName(task.getId(), prebuiltTask.meta().getLogicDevice().getPhysicalDevice().getDeviceName()); + timeProfiler.registerDeviceID(task.getId(), prebuiltTask.meta().getXPUDevice().getDriverIndex() + ":" + prebuiltTask.meta().getDeviceIndex()); + timeProfiler.registerDeviceName(task.getId(), prebuiltTask.meta().getXPUDevice().getPhysicalDevice().getDeviceName()); } } @@ -935,10 +930,8 @@ static void logTransferToDeviceOnce(List allEvents, Object object, Torn tornadoVMBytecodeList.append(verbose).append("\n"); } - static void logTransferToDeviceAlways(Object object, TornadoXPUDevice deviceForInterpreter, long sizeBatch, long offset, final int eventList, - StringBuilder tornadoVMBytecodeList) { - String verbose = String.format("bc: %s [0x%x] %s on %s, size=%d, offset=%d [event list=%d]", - InterpreterUtilities.debugHighLightBC("TRANSFER_HOST_TO_DEVICE_ALWAYS"), // + static void logTransferToDeviceAlways(Object object, TornadoXPUDevice deviceForInterpreter, long sizeBatch, long offset, final int eventList, StringBuilder tornadoVMBytecodeList) { + String verbose = String.format("bc: %s [0x%x] %s on %s, size=%d, offset=%d [event list=%d]", InterpreterUtilities.debugHighLightBC("TRANSFER_HOST_TO_DEVICE_ALWAYS"), // object.hashCode(), // object, // InterpreterUtilities.debugDeviceBC(deviceForInterpreter), // diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/CompilableTask.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/CompilableTask.java index 0fb82b5fbe..116bb2a24e 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/CompilableTask.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/CompilableTask.java @@ -81,7 +81,7 @@ public Access[] getArgumentsAccess() { @Override public TornadoXPUDevice getDevice() { - return meta.getLogicDevice(); + return meta.getXPUDevice(); } @Override diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java index 5f9f04da48..f18e2d2eac 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java @@ -124,7 +124,7 @@ public SchedulableTask mapTo(TornadoDevice mapping) { @Override public TornadoXPUDevice getDevice() { - return meta.getLogicDevice(); + return meta.getXPUDevice(); } @Override diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/ReduceTaskGraph.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/ReduceTaskGraph.java index 889c8d2432..cb4f38ce77 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/ReduceTaskGraph.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/ReduceTaskGraph.java @@ -65,6 +65,7 @@ import uk.ac.manchester.tornado.runtime.analyzer.ReduceCodeAnalysis.REDUCE_OPERATION; import uk.ac.manchester.tornado.runtime.common.TornadoOptions; import uk.ac.manchester.tornado.runtime.tasks.meta.MetaDataUtils; +import uk.ac.manchester.tornado.runtime.tasks.meta.MetaDataUtils.BackendSelectionContainer; import uk.ac.manchester.tornado.runtime.tasks.meta.TaskMetaData; class ReduceTaskGraph { @@ -75,7 +76,7 @@ class ReduceTaskGraph { private static final String TASK_GRAPH_PREFIX = "XXX__GENERATED_REDUCE"; private static final int DEFAULT_GPU_WORK_GROUP = 256; - private static final int DEFAULT_DRIVER_INDEX = 0; + private static final int DEFAULT_BACKEND_INDEX = 0; private static final int DEFAULT_DEVICE_INDEX = 0; private static AtomicInteger counterName = new AtomicInteger(0); private static AtomicInteger counterSeqName = new AtomicInteger(0); @@ -189,8 +190,8 @@ private static void inspectBinariesFPGA(String taskScheduleName, String graphNam for (int i = 0; i < binaries.length; i += 2) { String givenTaskName = binaries[i + 1].split(".device")[0]; if (givenTaskName.equals(idTaskGraph)) { - int[] info = MetaDataUtils.resolveDriverDeviceIndexes(MetaDataUtils.getProperty(idTaskGraph + ".device")); - int deviceNumber = info[1]; + BackendSelectionContainer info = MetaDataUtils.resolveDriverDeviceIndexes(MetaDataUtils.getProperty(idTaskGraph + ".device")); + int deviceNumber = info.deviceIndex(); if (!sequential) { originalBinaries.append("," + binaries[i] + "," + taskScheduleName + "." + taskName + ".device=0:" + deviceNumber); @@ -207,15 +208,15 @@ private boolean isAheadOfTime() { return TornadoOptions.FPGA_BINARIES != null; } - private int[] changeDriverAndDeviceIfNeeded(String taskScheduleName, String graphName, String taskName) { + private BackendSelectionContainer changeDriverAndDeviceIfNeeded(String taskScheduleName, String graphName, String taskName) { String idTaskGraph = graphName + "." + taskName; boolean isDeviceDefined = MetaDataUtils.getProperty(idTaskGraph + ".device") != null; if (isDeviceDefined) { - int[] info = MetaDataUtils.resolveDriverDeviceIndexes(MetaDataUtils.getProperty(idTaskGraph + ".device")); - int driverNumber = info[0]; - int deviceNumber = info[1]; - TornadoRuntimeProvider.setProperty(taskScheduleName + "." + taskName + ".device", driverNumber + ":" + deviceNumber); + BackendSelectionContainer info = MetaDataUtils.resolveDriverDeviceIndexes(MetaDataUtils.getProperty(idTaskGraph + ".device")); + int backendIndex = info.backendIndex(); + int deviceNumber = info.deviceIndex(); + TornadoRuntimeProvider.setProperty(taskScheduleName + "." + taskName + ".device", backendIndex + ":" + deviceNumber); return info; } return null; @@ -420,7 +421,7 @@ TaskGraph scheduleWithReduction(MetaReduceCodeAnalysis metaReduceTable) { reduceOperandTable = new HashMap<>(); } - int driverToRun = DEFAULT_DRIVER_INDEX; + int backendToRun = DEFAULT_BACKEND_INDEX; int deviceToRun = DEFAULT_DEVICE_INDEX; // Create new buffer variables and update the corresponding streamIn and @@ -432,10 +433,10 @@ TaskGraph scheduleWithReduction(MetaReduceCodeAnalysis metaReduceTable) { List streamReduceList = new ArrayList<>(); - int[] driverAndDevice = changeDriverAndDeviceIfNeeded(taskScheduleReduceName, graphName, taskPackage.getId()); - if (driverAndDevice != null) { - driverToRun = driverAndDevice[0]; - deviceToRun = driverAndDevice[1]; + BackendSelectionContainer selectionContainer = changeDriverAndDeviceIfNeeded(taskScheduleReduceName, graphName, taskPackage.getId()); + if (selectionContainer != null) { + backendToRun = selectionContainer.backendIndex(); + deviceToRun = selectionContainer.deviceIndex(); } inspectBinariesFPGA(taskScheduleReduceName, graphName, taskPackage.getId(), false); @@ -474,7 +475,7 @@ TaskGraph scheduleWithReduction(MetaReduceCodeAnalysis metaReduceTable) { } // Set the new array size - int sizeReductionArray = obtainSizeArrayResult(driverToRun, deviceToRun, inputSize); + int sizeReductionArray = obtainSizeArrayResult(backendToRun, deviceToRun, inputSize); Object newDeviceArray = createNewReduceArray(originalReduceArray, sizeReductionArray); Object neutralElement = getNeutralElement(originalReduceArray); fillOutputArrayWithNeutral(newDeviceArray, neutralElement); @@ -570,7 +571,7 @@ TaskGraph scheduleWithReduction(MetaReduceCodeAnalysis metaReduceTable) { for (REDUCE_OPERATION operation : operations) { final String newTaskSequentialName = SEQUENTIAL_TASK_REDUCE_NAME + counterSeqName.get(); String fullName = rewrittenTaskGraph.getTaskGraphName() + "." + newTaskSequentialName; - TornadoRuntimeProvider.setProperty(fullName + ".device", driverToRun + ":" + deviceToRun); + TornadoRuntimeProvider.setProperty(fullName + ".device", backendToRun + ":" + deviceToRun); inspectBinariesFPGA(taskScheduleReduceName, graphName, taskPackage.getId(), true); switch (operation) { @@ -615,7 +616,8 @@ private boolean checkAllArgumentsPerTask() { continue; } if (!rewrittenTaskGraph.getArgumentsLookup().contains(parameter)) { - throw new TornadoTaskRuntimeException("Parameter #" + i + " <" + parameter + "> from task <" + task.getId() + "> not specified either in `transferToDevice` or `transferToHost` functions"); + throw new TornadoTaskRuntimeException("Parameter #" + i + " <" + parameter + "> from task <" + task + .getId() + "> not specified either in `transferToDevice` or `transferToHost` functions"); } } } diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java index 762ca3aed1..636442f094 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/TornadoTaskGraph.java @@ -23,9 +23,38 @@ */ package uk.ac.manchester.tornado.runtime.tasks; -import jdk.vm.ci.meta.ResolvedJavaMethod; +import static uk.ac.manchester.tornado.api.profiler.ProfilerType.ALLOCATION_BYTES; +import static uk.ac.manchester.tornado.api.profiler.ProfilerType.TOTAL_COPY_IN_SIZE_BYTES; +import static uk.ac.manchester.tornado.api.profiler.ProfilerType.TOTAL_COPY_OUT_SIZE_BYTES; +import static uk.ac.manchester.tornado.api.profiler.ProfilerType.TOTAL_KERNEL_TIME; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.reflect.Array; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + import org.graalvm.compiler.graph.Graph; import org.graalvm.compiler.phases.util.Providers; + +import jdk.vm.ci.meta.ResolvedJavaMethod; import uk.ac.manchester.tornado.api.DRMode; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; @@ -33,7 +62,6 @@ import uk.ac.manchester.tornado.api.Policy; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.TornadoBackend; -import uk.ac.manchester.tornado.api.TornadoDeviceContext; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.TornadoRuntime; import uk.ac.manchester.tornado.api.TornadoTaskGraphInterface; @@ -61,6 +89,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; import uk.ac.manchester.tornado.api.enums.ProfilerMode; import uk.ac.manchester.tornado.api.enums.TornadoDeviceType; +import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType; import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException; import uk.ac.manchester.tornado.api.exceptions.TornadoDynamicReconfigurationException; import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; @@ -99,34 +128,6 @@ import uk.ac.manchester.tornado.runtime.tasks.meta.ScheduleMetaData; import uk.ac.manchester.tornado.runtime.tasks.meta.TaskMetaData; -import java.io.IOException; -import java.lang.foreign.MemorySegment; -import java.lang.reflect.Array; -import java.lang.reflect.Method; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.TreeMap; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Consumer; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import static uk.ac.manchester.tornado.api.profiler.ProfilerType.ALLOCATION_BYTES; -import static uk.ac.manchester.tornado.api.profiler.ProfilerType.TOTAL_COPY_IN_SIZE_BYTES; -import static uk.ac.manchester.tornado.api.profiler.ProfilerType.TOTAL_COPY_OUT_SIZE_BYTES; -import static uk.ac.manchester.tornado.api.profiler.ProfilerType.TOTAL_KERNEL_TIME; - /** * Implementation of the Tornado API for running on heterogeneous devices. */ @@ -206,7 +207,7 @@ public class TornadoTaskGraph implements TornadoTaskGraphInterface { * Task Schedule implementation that uses GPU/FPGA and multicore backends. This constructor must be public. It is invoked using the reflection API. * * @param taskScheduleName - * Task-Schedule name + * Task-Schedule name */ public TornadoTaskGraph(String taskScheduleName) { executionContext = new TornadoExecutionContext(taskScheduleName); @@ -253,41 +254,41 @@ static void performStreamInObject(TaskGraph task, List inputObjects, fin task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5)); break; case 7: - task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), - inputObjects.get(6)); + task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), inputObjects.get( + 6)); break; case 8: - task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), - inputObjects.get(6), inputObjects.get(7)); + task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), inputObjects.get( + 6), inputObjects.get(7)); break; case 9: - task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), - inputObjects.get(6), inputObjects.get(7), inputObjects.get(8)); + task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), inputObjects.get( + 6), inputObjects.get(7), inputObjects.get(8)); break; case 10: - task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), - inputObjects.get(6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9)); + task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), inputObjects.get( + 6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9)); break; case 11: - task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), - inputObjects.get(6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9), inputObjects.get(10)); + task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), inputObjects.get( + 6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9), inputObjects.get(10)); break; case 12: - task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), - inputObjects.get(6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9), inputObjects.get(10), inputObjects.get(11)); + task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), inputObjects.get( + 6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9), inputObjects.get(10), inputObjects.get(11)); break; case 13: - task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), - inputObjects.get(6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9), inputObjects.get(10), inputObjects.get(11), inputObjects.get(12)); + task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), inputObjects.get( + 6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9), inputObjects.get(10), inputObjects.get(11), inputObjects.get(12)); break; case 14: - task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), - inputObjects.get(6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9), inputObjects.get(10), inputObjects.get(11), inputObjects.get(12), inputObjects.get(13)); + task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), inputObjects.get( + 6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9), inputObjects.get(10), inputObjects.get(11), inputObjects.get(12), inputObjects.get(13)); break; case 15: - task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), - inputObjects.get(6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9), inputObjects.get(10), inputObjects.get(11), inputObjects.get(12), inputObjects.get(13), - inputObjects.get(14)); + task.transferToDevice(dataTransferMode, inputObjects.get(0), inputObjects.get(1), inputObjects.get(2), inputObjects.get(3), inputObjects.get(4), inputObjects.get(5), inputObjects.get( + 6), inputObjects.get(7), inputObjects.get(8), inputObjects.get(9), inputObjects.get(10), inputObjects.get(11), inputObjects.get(12), inputObjects.get(13), inputObjects.get( + 14)); break; default: System.out.println("COPY-IN Not supported yet: " + numObjectsCopyIn); @@ -490,6 +491,11 @@ public long getCurrentDeviceMemoryUsage() { return executionContext.getCurrentDeviceMemoryUsage(); } + @Override + public void withCompilerFlags(TornadoVMBackendType backendType, String compilerFlags) { + executionContext.meta().setCompilerFlags(backendType, compilerFlags); + } + @Override public long getTotalBytesTransferred() { return getProfilerValue(ProfilerType.TOTAL_COPY_IN_SIZE_BYTES) + getProfilerValue(TOTAL_COPY_OUT_SIZE_BYTES); @@ -519,7 +525,7 @@ public TornadoDevice getDevice() { @Override public void setDevice(TornadoDevice device) { - TornadoDevice oldDevice = meta().getLogicDevice(); + TornadoDevice oldDevice = meta().getXPUDevice(); // prevent to set again the same device as it invalidates its state if (oldDevice.equals(device)) { @@ -535,7 +541,7 @@ public void setDevice(TornadoDevice device) { task.meta().setDevice(device); if (task instanceof CompilableTask compilableTask) { ResolvedJavaMethod method = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(compilableTask.getMethod()); - if (!meta().getLogicDevice().getDeviceContext().isCached(method.getName(), compilableTask)) { + if (!meta().getXPUDevice().getDeviceContext().isCached(method.getName(), compilableTask)) { updateInner(i, executionContext.getTask(i)); } } @@ -568,7 +574,7 @@ private void reuseDeviceBuffersForSameDevice(TornadoDevice device) { @Override public void setDevice(String taskName, TornadoDevice device) { - TornadoDevice oldDevice = meta().getLogicDevice(); + TornadoDevice oldDevice = meta().getXPUDevice(); // Make sure that a sketch is available for the device. for (int i = 0; i < executionContext.getTaskCount(); i++) { @@ -680,14 +686,14 @@ private void logTaskMethodHandle(SchedulableTask task) { } private void updateDeviceContext() { - executionContext.setDevice(meta().getLogicDevice()); + executionContext.setDevice(meta().getXPUDevice()); } /** * Compile a {@link TaskGraph} into TornadoVM byte-code. * * @param setNewDevice: - * boolean that specifies if set a new device or not. + * boolean that specifies if set a new device or not. */ private TornadoVM compileGraphAndBuildVM(boolean setNewDevice) { final ByteBuffer buffer = ByteBuffer.wrap(highLevelCode); @@ -737,7 +743,7 @@ private CompileInfo extractCompileInfo() { return COMPILE_ONLY; } - if (tornadoVMBytecodeBuilder != null && !isLastDeviceListEmpty() && !(compareDevices(executionContext.getLastDevices(), meta().getLogicDevice()))) { + if (tornadoVMBytecodeBuilder != null && !isLastDeviceListEmpty() && !(compareDevices(executionContext.getLastDevices(), meta().getXPUDevice()))) { return COMPILE_AND_UPDATE; } @@ -745,7 +751,7 @@ private CompileInfo extractCompileInfo() { return COMPILE_ONLY; } - if (!compareDevices(executionContext.getLastDevices(), meta().getLogicDevice())) { + if (!compareDevices(executionContext.getLastDevices(), meta().getXPUDevice())) { return COMPILE_AND_UPDATE; } @@ -775,12 +781,12 @@ private boolean compileComputeGraphToTornadoVMBytecode() { timeProfiler.start(ProfilerType.TOTAL_BYTE_CODE_GENERATION); executionContext.scheduleTaskToDevices(); TornadoVM tornadoVM = compileGraphAndBuildVM(compileInfo.updateDevice); - vmTable.put(meta().getLogicDevice(), tornadoVM); + vmTable.put(meta().getXPUDevice(), tornadoVM); timeProfiler.stop(ProfilerType.TOTAL_BYTE_CODE_GENERATION); } - executionContext.addLastDevice(meta().getLogicDevice()); + executionContext.addLastDevice(meta().getXPUDevice()); - vm = vmTable.get(meta().getLogicDevice()); + vm = vmTable.get(meta().getXPUDevice()); /* * Set the grid scheduler outside the constructor of the {@link @@ -1012,8 +1018,8 @@ public void transferToHost(final int mode, Object... objects) { @Override public void dump() { final int width = 16; - System.out.printf("code : capacity = %s, in use = %s %n", RuntimeUtilities.humanReadableByteCount(hlBuffer.capacity(), true), - RuntimeUtilities.humanReadableByteCount(hlBuffer.position(), true)); + System.out.printf("code : capacity = %s, in use = %s %n", RuntimeUtilities.humanReadableByteCount(hlBuffer.capacity(), true), RuntimeUtilities.humanReadableByteCount(hlBuffer.position(), + true)); for (int i = 0; i < hlBuffer.position(); i += width) { System.out.printf("[0x%04x]: ", i); for (int j = 0; j < Math.min(hlBuffer.capacity() - i, width); j++) { @@ -1079,12 +1085,12 @@ private void free() { } inputModesObjects.forEach(inputStreamObject -> freeDeviceMemoryObject(inputStreamObject.getObject())); outputModeObjects.forEach(outputStreamObject -> freeDeviceMemoryObject(outputStreamObject.getObject())); - meta().getLogicDevice().getDeviceContext().reset(executionPlanId); + meta().getXPUDevice().getDeviceContext().reset(executionPlanId); } private void freeDeviceMemoryObject(Object object) { final LocalObjectState localState = executionContext.getLocalStateObject(object); - releaseObjectFromDeviceMemory(localState, meta().getLogicDevice()); + releaseObjectFromDeviceMemory(localState, meta().getXPUDevice()); } private void releaseObjectFromDeviceMemory(final LocalObjectState localState, final TornadoDevice device) { @@ -1109,7 +1115,7 @@ private void syncField(Object object) { private Event syncObjectInner(Object object) { final LocalObjectState localState = executionContext.getLocalStateObject(object); final DataObjectState dataObjectState = localState.getDataObjectState(); - final TornadoXPUDevice device = meta().getLogicDevice(); + final TornadoXPUDevice device = meta().getXPUDevice(); final XPUDeviceBufferState deviceState = dataObjectState.getDeviceBufferState(device); if (deviceState.isLockedBuffer()) { return device.resolveEvent(executionPlanId, device.streamOutBlocking(executionPlanId, object, 0, deviceState, null)); @@ -1120,7 +1126,7 @@ private Event syncObjectInner(Object object) { private Event syncObjectInner(Object object, long offset, long partialCopySize) { final LocalObjectState localState = executionContext.getLocalStateObject(object); final DataObjectState dataObjectState = localState.getDataObjectState(); - final TornadoXPUDevice device = meta().getLogicDevice(); + final TornadoXPUDevice device = meta().getXPUDevice(); final XPUDeviceBufferState deviceState = dataObjectState.getDeviceBufferState(device); deviceState.setPartialCopySize(partialCopySize); if (deviceState.isLockedBuffer()) { @@ -1132,7 +1138,7 @@ private Event syncObjectInner(Object object, long offset, long partialCopySize) private Event syncObjectInnerLazy(Object object, long hostOffset, long bufferSize) { final LocalObjectState localState = executionContext.getLocalStateObject(object); final DataObjectState dataObjectState = localState.getDataObjectState(); - final TornadoXPUDevice device = meta().getLogicDevice(); + final TornadoXPUDevice device = meta().getXPUDevice(); final XPUDeviceBufferState deviceBufferState = dataObjectState.getDeviceBufferState(device); if (deviceBufferState.isLockedBuffer()) { deviceBufferState.getXPUBuffer().setSizeSubRegion(bufferSize); @@ -1206,7 +1212,7 @@ public void syncRuntimeTransferToHost(Object... objects) { value += eventParameter.getElapsedTime(); timeProfiler.setTimer(ProfilerType.COPY_OUT_TIME_SYNC, value); LocalObjectState localState = executionContext.getLocalStateObject(objects[i]); - XPUDeviceBufferState deviceObjectState = localState.getDataObjectState().getDeviceBufferState(meta().getLogicDevice()); + XPUDeviceBufferState deviceObjectState = localState.getDataObjectState().getDeviceBufferState(meta().getXPUDevice()); timeProfiler.addValueToMetric(ProfilerType.COPY_OUT_SIZE_BYTES_SYNC, TimeProfiler.NO_TASK_NAME, deviceObjectState.getXPUBuffer().size()); } updateProfiler(); @@ -1237,7 +1243,7 @@ public void syncRuntimeTransferToHost(Object object, long offset, long partialCo value += event.getElapsedTime(); timeProfiler.setTimer(ProfilerType.COPY_OUT_TIME_SYNC, value); LocalObjectState localState = executionContext.getLocalStateObject(object); - XPUDeviceBufferState deviceObjectState = localState.getDataObjectState().getDeviceBufferState(meta().getLogicDevice()); + XPUDeviceBufferState deviceObjectState = localState.getDataObjectState().getDeviceBufferState(meta().getXPUDevice()); timeProfiler.addValueToMetric(ProfilerType.COPY_OUT_SIZE_BYTES_SYNC, TimeProfiler.NO_TASK_NAME, deviceObjectState.getXPUBuffer().size()); updateProfiler(); } @@ -1311,7 +1317,8 @@ private boolean checkAllArgumentsPerTask() { continue; } if (!argumentsLookUp.contains(parameter)) { - throw new TornadoTaskRuntimeException("Parameter #" + i + " <" + parameter + "> from task <" + task.getId() + "> not specified either in transferToDevice or transferToHost functions"); + throw new TornadoTaskRuntimeException("Parameter #" + i + " <" + parameter + "> from task <" + task + .getId() + "> not specified either in transferToDevice or transferToHost functions"); } } } @@ -1468,68 +1475,66 @@ private void runSequentialCodeInThread(TaskPackage taskPackage) { break; case 5: @SuppressWarnings("rawtypes") Task5 task5 = (Task5) taskPackage.getTaskParameters()[0]; - task5.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5]); + task5.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5]); break; case 6: @SuppressWarnings("rawtypes") Task6 task6 = (Task6) taskPackage.getTaskParameters()[0]; - task6.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5], taskPackage.getTaskParameters()[6]); + task6.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5], taskPackage.getTaskParameters()[6]); break; case 7: @SuppressWarnings("rawtypes") Task7 task7 = (Task7) taskPackage.getTaskParameters()[0]; - task7.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7]); + task7.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7]); break; case 8: @SuppressWarnings("rawtypes") Task8 task8 = (Task8) taskPackage.getTaskParameters()[0]; - task8.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8]); + task8.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8]); break; case 9: @SuppressWarnings("rawtypes") Task9 task9 = (Task9) taskPackage.getTaskParameters()[0]; - task9.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], - taskPackage.getTaskParameters()[9]); + task9.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], taskPackage.getTaskParameters()[9]); break; case 10: @SuppressWarnings("rawtypes") Task10 task10 = (Task10) taskPackage.getTaskParameters()[0]; - task10.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], - taskPackage.getTaskParameters()[9], taskPackage.getTaskParameters()[10]); + task10.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], taskPackage.getTaskParameters()[9], + taskPackage.getTaskParameters()[10]); break; case 11: @SuppressWarnings("rawtypes") Task11 task11 = (Task11) taskPackage.getTaskParameters()[0]; - task11.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], - taskPackage.getTaskParameters()[9], taskPackage.getTaskParameters()[10], taskPackage.getTaskParameters()[11]); + task11.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], taskPackage.getTaskParameters()[9], + taskPackage.getTaskParameters()[10], taskPackage.getTaskParameters()[11]); break; case 12: @SuppressWarnings("rawtypes") Task12 task12 = (Task12) taskPackage.getTaskParameters()[0]; - task12.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], - taskPackage.getTaskParameters()[9], taskPackage.getTaskParameters()[10], taskPackage.getTaskParameters()[11], taskPackage.getTaskParameters()[12]); + task12.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], taskPackage.getTaskParameters()[9], + taskPackage.getTaskParameters()[10], taskPackage.getTaskParameters()[11], taskPackage.getTaskParameters()[12]); break; case 13: @SuppressWarnings("rawtypes") Task13 task13 = (Task13) taskPackage.getTaskParameters()[0]; - task13.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], - taskPackage.getTaskParameters()[9], taskPackage.getTaskParameters()[10], taskPackage.getTaskParameters()[11], taskPackage.getTaskParameters()[12], - taskPackage.getTaskParameters()[13]); + task13.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], taskPackage.getTaskParameters()[9], + taskPackage.getTaskParameters()[10], taskPackage.getTaskParameters()[11], taskPackage.getTaskParameters()[12], taskPackage.getTaskParameters()[13]); break; case 14: @SuppressWarnings("rawtypes") Task14 task14 = (Task14) taskPackage.getTaskParameters()[0]; - task14.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], - taskPackage.getTaskParameters()[9], taskPackage.getTaskParameters()[10], taskPackage.getTaskParameters()[11], taskPackage.getTaskParameters()[12], - taskPackage.getTaskParameters()[13], taskPackage.getTaskParameters()[14]); + task14.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], taskPackage.getTaskParameters()[9], + taskPackage.getTaskParameters()[10], taskPackage.getTaskParameters()[11], taskPackage.getTaskParameters()[12], taskPackage.getTaskParameters()[13], taskPackage + .getTaskParameters()[14]); break; case 15: @SuppressWarnings("rawtypes") Task15 task15 = (Task15) taskPackage.getTaskParameters()[0]; - task15.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], - taskPackage.getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], - taskPackage.getTaskParameters()[9], taskPackage.getTaskParameters()[10], taskPackage.getTaskParameters()[11], taskPackage.getTaskParameters()[12], - taskPackage.getTaskParameters()[13], taskPackage.getTaskParameters()[14], taskPackage.getTaskParameters()[15]); + task15.apply(taskPackage.getTaskParameters()[1], taskPackage.getTaskParameters()[2], taskPackage.getTaskParameters()[3], taskPackage.getTaskParameters()[4], taskPackage + .getTaskParameters()[5], taskPackage.getTaskParameters()[6], taskPackage.getTaskParameters()[7], taskPackage.getTaskParameters()[8], taskPackage.getTaskParameters()[9], + taskPackage.getTaskParameters()[10], taskPackage.getTaskParameters()[11], taskPackage.getTaskParameters()[12], taskPackage.getTaskParameters()[13], taskPackage + .getTaskParameters()[14], taskPackage.getTaskParameters()[15]); break; default: throw new TornadoRuntimeException("Sequential Runner not supported yet. Number of parameters: " + type); @@ -1991,9 +1996,9 @@ private void runWithSequentialProfiler(Policy policy) { * Experimental method to sync all objects when making a clone copy for all output objects per device. * * @param policy - * input policy + * input policy * @param numDevices - * number of devices + * number of devices */ private void restoreVarsIntoJavaHeap(Policy policy, int numDevices) { if (policyTimeTable.get(policy) < numDevices) { @@ -2086,55 +2091,47 @@ private void addInner(int index, int type, Method method, ScheduleMetaData meta, updateInner(index, TaskUtils.createTask(method, meta, id, (Task6) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6])); break; case 7: - updateInner(index, - TaskUtils.createTask(method, meta, id, (Task7) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7])); + updateInner(index, TaskUtils.createTask(method, meta, id, (Task7) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], + parameters[7])); break; case 8: - updateInner(index, - TaskUtils.createTask(method, meta, id, (Task8) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], - parameters[8])); + updateInner(index, TaskUtils.createTask(method, meta, id, (Task8) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], + parameters[7], parameters[8])); break; case 9: - updateInner(index, - TaskUtils.createTask(method, meta, id, (Task9) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], - parameters[8], parameters[9])); + updateInner(index, TaskUtils.createTask(method, meta, id, (Task9) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], + parameters[7], parameters[8], parameters[9])); break; case 10: - updateInner(index, - TaskUtils.createTask(method, meta, id, (Task10) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], - parameters[8], parameters[9], parameters[10])); + updateInner(index, TaskUtils.createTask(method, meta, id, (Task10) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], + parameters[7], parameters[8], parameters[9], parameters[10])); break; case 11: - updateInner(index, - TaskUtils.createTask(method, meta, id, (Task11) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], - parameters[8], parameters[9], parameters[10], parameters[11])); + updateInner(index, TaskUtils.createTask(method, meta, id, (Task11) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], + parameters[7], parameters[8], parameters[9], parameters[10], parameters[11])); break; case 12: - updateInner(index, - TaskUtils.createTask(method, meta, id, (Task12) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], - parameters[8], parameters[9], parameters[10], parameters[11], parameters[12])); + updateInner(index, TaskUtils.createTask(method, meta, id, (Task12) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], + parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12])); break; case 13: - updateInner(index, - TaskUtils.createTask(method, meta, id, (Task13) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], - parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13])); + updateInner(index, TaskUtils.createTask(method, meta, id, (Task13) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], + parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13])); break; case 14: - updateInner(index, - TaskUtils.createTask(method, meta, id, (Task14) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], - parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14])); + updateInner(index, TaskUtils.createTask(method, meta, id, (Task14) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], + parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14])); break; case 15: - updateInner(index, - TaskUtils.createTask(method, meta, id, (Task15) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], - parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15])); + updateInner(index, TaskUtils.createTask(method, meta, id, (Task15) parameters[0], parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], + parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15])); break; default: throw new TornadoRuntimeException("Task not supported yet. Type: " + type); } } - @SuppressWarnings({"rawtypes", "unchecked"}) + @SuppressWarnings({ "rawtypes", "unchecked" }) private void addInner(int type, Method method, ScheduleMetaData meta, String id, Object[] parameters) { switch (type) { case 0: diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/AbstractMetaData.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/AbstractMetaData.java index c6a4e48917..32da7f2787 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/AbstractMetaData.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/AbstractMetaData.java @@ -27,14 +27,15 @@ import static uk.ac.manchester.tornado.runtime.tasks.meta.MetaDataUtils.resolveDevice; import java.util.Arrays; -import java.util.HashSet; import java.util.List; +import java.util.concurrent.ConcurrentHashMap; import jdk.vm.ci.meta.ResolvedJavaMethod; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.common.TornadoDevice; import uk.ac.manchester.tornado.api.common.TornadoEvents; +import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType; import uk.ac.manchester.tornado.api.memory.TaskMetaDataInterface; import uk.ac.manchester.tornado.api.profiler.TornadoProfiler; import uk.ac.manchester.tornado.runtime.TornadoAcceleratorBackend; @@ -42,32 +43,22 @@ import uk.ac.manchester.tornado.runtime.common.Tornado; import uk.ac.manchester.tornado.runtime.common.TornadoOptions; import uk.ac.manchester.tornado.runtime.common.TornadoXPUDevice; +import uk.ac.manchester.tornado.runtime.tasks.meta.MetaDataUtils.BackendSelectionContainer; public abstract class AbstractMetaData implements TaskMetaDataInterface { private static final long[] SEQUENTIAL_GLOBAL_WORK_GROUP = { 1, 1, 1 }; private static final String TRUE = "True"; private static final String FALSE = "False"; - private final boolean isDeviceDefined; + private static final String DEFAULT_OPENCL_COMPILER_FLAGS = "-cl-mad-enable -cl-fast-relaxed-math -w"; + private static final String DEFAULT_PTX_COMPILER_FLAGS = " "; + private static final String DEFAULT_SPIRV_LEVEL_ZERO_COMPILER_FLAGS = "-ze-opt-level 2 -ze-opt-large-register-file"; - private final HashSet openCLBuiltOptions = new HashSet<>(Arrays.asList(// - "-cl-single-precision-constant", // - "-cl-denorms-are-zero", // - "-cl-opt-disable", // - "-cl-strict-aliasing", // - "-cl-mad-enable", // - "-cl-no-signed-zeros", // - "-cl-unsafe-math-optimizations", // - "-cl-finite-math-only", // - "-cl-fast-relaxed-math", // - "-w", // - "-cl-std=CL2.0" // - )); + private final boolean isDeviceDefined; private boolean threadInfoEnabled; - private final boolean debugMode; private final boolean dumpEvents; - private final boolean dumpProfiles; + private final boolean isOpenclGpuBlockXDefined; private final int openclGpuBlockX; private final boolean isOpenclGpuBlock2DXDefined; @@ -80,7 +71,7 @@ public abstract class AbstractMetaData implements TaskMetaDataInterface { private final boolean isEnableParallelizationDefined; private final boolean isCpuConfigDefined; private final String cpuConfig; - private String id; + private final String id; private TornadoXPUDevice device; private int backendIndex; private int deviceIndex; @@ -93,7 +84,7 @@ public abstract class AbstractMetaData implements TaskMetaDataInterface { private ResolvedJavaMethod graph; private boolean useGridScheduler; private boolean isOpenclCompilerFlagsDefined; - private String openclCompilerOptions; + private ConcurrentHashMap compilerOptionsPerBackend; private boolean openclUseDriverScheduling; private boolean printKernel; @@ -104,9 +95,9 @@ public abstract class AbstractMetaData implements TaskMetaDataInterface { isDeviceDefined = getProperty(id + ".device") != null; if (isDeviceDefined) { - int[] a = MetaDataUtils.resolveDriverDeviceIndexes(getProperty(id + ".device")); - backendIndex = a[0]; - deviceIndex = a[1]; + BackendSelectionContainer backendSelection = MetaDataUtils.resolveDriverDeviceIndexes(getProperty(id + ".device")); + backendIndex = backendSelection.backendIndex(); + deviceIndex = backendSelection.deviceIndex(); } else if (null != parent) { backendIndex = parent.getBackendIndex(); deviceIndex = parent.getDeviceIndex(); @@ -120,13 +111,16 @@ public abstract class AbstractMetaData implements TaskMetaDataInterface { threadInfoEnabled = TornadoOptions.THREAD_INFO; printKernel = TornadoOptions.PRINT_KERNEL_SOURCE; - debugMode = parseBoolean(getDefault("debug", id, FALSE)); + dumpEvents = parseBoolean(getDefault("events.dump", id, TRUE)); - dumpProfiles = parseBoolean(getDefault("profiles.print", id, FALSE)); dumpTaskGraph = Boolean.parseBoolean(System.getProperty("dump.taskgraph", FALSE)); // Compilation flags - > only for OpenCL - openclCompilerOptions = (getProperty("tornado.opencl.compiler.options") == null) ? "-w" : getProperty("tornado.opencl.compiler.options"); + compilerOptionsPerBackend = new ConcurrentHashMap<>(); + compilerOptionsPerBackend.put(TornadoVMBackendType.OPENCL, DEFAULT_OPENCL_COMPILER_FLAGS); + compilerOptionsPerBackend.put(TornadoVMBackendType.PTX, DEFAULT_PTX_COMPILER_FLAGS); + compilerOptionsPerBackend.put(TornadoVMBackendType.SPIRV, DEFAULT_SPIRV_LEVEL_ZERO_COMPILER_FLAGS); + isOpenclCompilerFlagsDefined = getProperty("tornado.opencl.compiler.options") != null; // Thread Configurations @@ -153,8 +147,11 @@ protected static String getDefault(String keySuffix, String id, String defaultVa return (propertyValue != null) ? propertyValue : Tornado.getProperty("tornado" + "." + keySuffix, defaultValue); } - public TornadoXPUDevice getLogicDevice() { - return device != null ? device : (device = resolveDevice(Tornado.getProperty(id + ".device", backendIndex + ":" + deviceIndex))); + public TornadoXPUDevice getXPUDevice() { + if (device == null) { + device = resolveDevice(Tornado.getProperty(id + ".device", backendIndex + ":" + deviceIndex)); + } + return device; } private int getDeviceIndex(int driverIndex, TornadoDevice device) { @@ -212,29 +209,24 @@ public boolean isThreadInfoEnabled() { } public boolean isDebug() { - return debugMode; + return TornadoOptions.DEBUG; } public boolean shouldDumpEvents() { return dumpEvents; } - public boolean shouldDumpProfiles() { - return dumpProfiles; - } - public boolean shouldDumpTaskGraph() { return dumpTaskGraph; } - public String getCompilerFlags() { - return composeBuiltOptions(openclCompilerOptions); + public String getCompilerFlags(TornadoVMBackendType backendType) { + return compilerOptionsPerBackend.get(backendType); } @Override - public void setCompilerFlags(String value) { - openclCompilerOptions = value; - isOpenclCompilerFlagsDefined = true; + public void setCompilerFlags(TornadoVMBackendType backendType, String compilerFlags) { + compilerOptionsPerBackend.put(backendType, compilerFlags); } public int getOpenCLGpuBlockX() { @@ -273,17 +265,6 @@ public boolean isOpenclCompilerFlagsDefined() { return isOpenclCompilerFlagsDefined; } - public String composeBuiltOptions(String rawFlags) { - rawFlags = rawFlags.replace(",", " "); - for (String str : rawFlags.split(" ")) { - if (!openCLBuiltOptions.contains(str)) { - rawFlags = " "; - break; - } - } - return rawFlags; - } - public boolean isOpenclGpuBlockXDefined() { return isOpenclGpuBlockXDefined; } diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/MetaDataUtils.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/MetaDataUtils.java index 541268f984..ca06962c44 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/MetaDataUtils.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/MetaDataUtils.java @@ -42,9 +42,12 @@ public static TornadoXPUDevice resolveDevice(String device) { return (TornadoXPUDevice) driver.getDevice(Integer.parseInt(ids[1])); } - public static int[] resolveDriverDeviceIndexes(String device) { + public record BackendSelectionContainer(int backendIndex, int deviceIndex) { + } + + public static BackendSelectionContainer resolveDriverDeviceIndexes(String device) { final String[] ids = device.split(":"); - return new int[] { Integer.parseInt(ids[0]), Integer.parseInt(ids[1]) }; + return new BackendSelectionContainer(Integer.parseInt(ids[0]), Integer.parseInt(ids[1])); } public static String[] processPrecompiledBinariesFromFile(String fileName) { diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/TaskMetaData.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/TaskMetaData.java index cf65d87327..4763c9d810 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/TaskMetaData.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/meta/TaskMetaData.java @@ -143,7 +143,7 @@ public long[] initLocalWork() { } public void addProfile(int id) { - final TornadoXPUDevice device = getLogicDevice(); + final TornadoXPUDevice device = getXPUDevice(); BitSet events; profiles.computeIfAbsent(device, k -> new BitSet(EVENT_WINDOW)); events = profiles.get(device); @@ -179,11 +179,11 @@ public int getConstantSize() { } @Override - public TornadoXPUDevice getLogicDevice() { + public TornadoXPUDevice getXPUDevice() { if (scheduleMetaData.isDeviceManuallySet() || (scheduleMetaData.isDeviceDefined() && !isDeviceDefined())) { - return scheduleMetaData.getLogicDevice(); + return scheduleMetaData.getXPUDevice(); } - return super.getLogicDevice(); + return super.getXPUDevice(); } public int getDims() { @@ -251,11 +251,6 @@ public void setLocalWork(long[] values) { localWorkDefined = true; } - @Override - public String getCompilerFlags() { - return isOpenclCompilerFlagsDefined() ? super.getCompilerFlags() : scheduleMetaData.getCompilerFlags(); - } - @Override public int getOpenCLGpuBlock2DX() { return isOpenclGpuBlock2DXDefined() ? super.getOpenCLGpuBlock2DX() : scheduleMetaData.getOpenCLGpuBlock2DX(); @@ -305,10 +300,10 @@ private long[] calculateNumberOfWorkgroupsFromDomain(DomainTree domain) { public void printThreadDims() { StringBuilder deviceDebug = new StringBuilder(); - boolean deviceBelongsToPTX = isPTXDevice(getLogicDevice()); + boolean deviceBelongsToPTX = isPTXDevice(getXPUDevice()); deviceDebug.append("Task info: " + getId() + "\n"); - deviceDebug.append("\tBackend : " + getLogicDevice().getTornadoVMBackend().name() + "\n"); - deviceDebug.append("\tDevice : " + getLogicDevice().getDescription() + "\n"); + deviceDebug.append("\tBackend : " + getXPUDevice().getTornadoVMBackend().name() + "\n"); + deviceDebug.append("\tDevice : " + getXPUDevice().getDescription() + "\n"); deviceDebug.append("\tDims : " + (this.isWorkerGridAvailable() ? getWorkerGrid(getId()).dimension() : (hasDomain() ? domain.getDepth() : 0)) + "\n"); if (!deviceBelongsToPTX) { @@ -336,11 +331,6 @@ public boolean isPTXDevice(TornadoXPUDevice device) { return device.getTornadoVMBackend().equals(TornadoVMBackendType.PTX); } - @Override - public boolean shouldDumpProfiles() { - return super.shouldDumpProfiles() || scheduleMetaData.shouldDumpProfiles(); - } - @Override public boolean shouldDumpEvents() { return super.shouldDumpEvents() || scheduleMetaData.shouldDumpEvents();