Skip to content

Commit

Permalink
[fix] Data Access for PrebuiltTaskGraph fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
jjfumero committed Aug 27, 2024
1 parent d1f870e commit fa705ba
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package uk.ac.manchester.tornado.api.common;

import java.util.stream.IntStream;

import uk.ac.manchester.tornado.api.AccessorParameters;

public class PrebuiltTaskPackage extends TaskPackage {
Expand All @@ -32,13 +34,11 @@ public class PrebuiltTaskPackage extends TaskPackage {
this.entryPoint = entryPoint;
this.filename = fileName;
this.args = new Object[accessorParameters.numAccessors()];
for (int i = 0; i < accessorParameters.numAccessors(); i++) {
this.args[i] = accessorParameters.getAccessor(i).object();
}
this.accesses = new Access[accessorParameters.numAccessors()];
for (int i = 0; i < accessorParameters.numAccessors(); i++) {
IntStream.range(0, accessorParameters.numAccessors()).forEach(i -> {
this.args[i] = accessorParameters.getAccessor(i).object();
this.accesses[i] = accessorParameters.getAccessor(i).access();
}
});
}

public PrebuiltTaskPackage withAtomics(int[] atomics) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,15 @@ private static DomainTree buildDomainTree(int[] dims) {

}

/**
* Marshal object from {@link PrebuiltTaskPackage} to {@link PrebuiltTask}.
*
* @param meta
* {@link ScheduleContext}
* @param taskPackage
* {@link PrebuiltTaskPackage}
* @return {@link PrebuiltTask}
*/
public static PrebuiltTask createTask(ScheduleContext meta, PrebuiltTaskPackage taskPackage) {
PrebuiltTask prebuiltTask = new PrebuiltTask(meta, //
taskPackage.getId(), //
Expand All @@ -319,7 +328,7 @@ public static PrebuiltTask createTask(ScheduleContext meta, PrebuiltTaskPackage
taskPackage.getArgs(), //
taskPackage.getAccesses());
if (taskPackage.getAtomics() != null) {
prebuiltTask.withAtomics(taskPackage.getAtomics());
prebuiltTask.setAtomics(taskPackage.getAtomics());
}
return prebuiltTask;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.profiler.TornadoProfiler;
import uk.ac.manchester.tornado.runtime.common.TornadoXPUDevice;
import uk.ac.manchester.tornado.runtime.domain.DomainTree;
import uk.ac.manchester.tornado.runtime.tasks.meta.ScheduleContext;
import uk.ac.manchester.tornado.runtime.tasks.meta.TaskDataContext;

Expand All @@ -50,41 +49,17 @@ public class PrebuiltTask implements SchedulableTask {
private boolean forceCompiler;
private int[] atomics;

public PrebuiltTask(ScheduleContext scheduleMeta, String id, String entryPoint, String filename, Object[] args, Access[] access, TornadoDevice device, DomainTree domain) {
this.entryPoint = entryPoint;
this.filename = filename;
this.args = args;
this.argumentsAccess = access;
meta = new TaskDataContext(scheduleMeta, id, access.length);
for (int i = 0; i < access.length; i++) {
meta.getArgumentsAccess()[i] = access[i];
}
meta.setDevice(device);
meta.setDomain(domain);

final long[] values = new long[domain.getDepth()];
for (int i = 0; i < domain.getDepth(); i++) {
values[i] = domain.get(i).cardinality();
}
meta.setGlobalWork(values);

}

public PrebuiltTask(ScheduleContext scheduleMeta, String id, String entryPoint, String filename, Object[] args, Access[] access) {
this.entryPoint = entryPoint;
this.filename = filename;
this.args = args;
this.argumentsAccess = access;
meta = new TaskDataContext(scheduleMeta, id, access.length);
for (int i = 0; i < access.length; i++) {
meta.getArgumentsAccess()[i] = access[i];
}

meta.setArgumentsAccess(access);
}

public PrebuiltTask withAtomics(int[] atomics) {
public void setAtomics(int[] atomics) {
this.atomics = atomics;
return this;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
package uk.ac.manchester.tornado.runtime.tasks.meta;

import static uk.ac.manchester.tornado.api.exceptions.TornadoInternalError.guarantee;
import static uk.ac.manchester.tornado.runtime.common.TornadoOptions.EVENT_WINDOW;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
Expand Down Expand Up @@ -142,12 +141,8 @@ public long[] initLocalWork() {
return localWork;
}

public void addProfile(int id) {
final TornadoXPUDevice device = getXPUDevice();
BitSet events;
profiles.computeIfAbsent(device, k -> new BitSet(EVENT_WINDOW));
events = profiles.get(device);
events.set(id);
public void setArgumentsAccess(Access[] access) {
this.argumentsAccess = access;
}

public Access[] getArgumentsAccess() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,74 @@ public void testPrebuilt01() throws TornadoExecutionPlanException {
.withDevice(defaultDevice) //
.execute();
}
for (int j = 0; j < c.getSize(); j++) {
assertEquals(a.get(j) + b.get(j), c.get(j));
}

}

@Test
public void testPrebuilt01Multi() throws TornadoExecutionPlanException {

final int numElements = 8;
IntArray a = new IntArray(numElements);
IntArray b = new IntArray(numElements);
IntArray c = new IntArray(numElements);

a.init(1);
b.init(2);

switch (backendType) {
case PTX:
FILE_PATH += "add.ptx";
break;
case OPENCL:
FILE_PATH += "add.cl";
break;
case SPIRV:
FILE_PATH += "add.spv";
break;
default:
throw new TornadoRuntimeException("Backend not supported");
}

// Define accessors for each parameter
AccessorParameters accessorParameters = new AccessorParameters(3);
accessorParameters.set(0, a, Access.READ_WRITE);
accessorParameters.set(1, b, Access.READ_WRITE);
accessorParameters.set(2, c, Access.WRITE_ONLY);

// Define the Task-Graph
TaskGraph taskGraph = new TaskGraph("s0") //
.transferToDevice(DataTransferMode.EVERY_EXECUTION, a, b) //
.prebuiltTask("t0", //task name
"add", // name of the low-level kernel to invoke
FILE_PATH, // file name
accessorParameters) // accessors
.transferToHost(DataTransferMode.EVERY_EXECUTION, c);

for (int i = 0; i < c.getSize(); i++) {
assertEquals(a.get(i) + b.get(i), c.get(i));
// When using the prebuilt API, we need to define the WorkerGrid, otherwise it will launch 1 thread
// on the target device
WorkerGrid workerGrid = new WorkerGrid1D(numElements);
GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid);

// Launch the application on the target device
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot())) {

executionPlan.withGridScheduler(gridScheduler) //
.withDevice(defaultDevice) //
.execute();

// Run task multiple times
for (int i = 0; i < 10; i++) {
executionPlan.execute();
for (int j = 0; j < c.getSize(); j++) {
assertEquals(a.get(j) + b.get(j), c.get(j));
}
IntStream.range(0, numElements).forEach(k -> a.set(k, c.get(k)));
}
}

}

@Test
Expand Down

0 comments on commit fa705ba

Please sign in to comment.