-
Notifications
You must be signed in to change notification settings - Fork 114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prebuilt API test using a multi-backend setup and selecting specific device #549
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
import java.util.List; | ||
import java.util.stream.IntStream; | ||
|
||
import org.junit.BeforeClass; | ||
|
@@ -29,6 +30,8 @@ | |
import uk.ac.manchester.tornado.api.ImmutableTaskGraph; | ||
import uk.ac.manchester.tornado.api.KernelContext; | ||
import uk.ac.manchester.tornado.api.TaskGraph; | ||
import uk.ac.manchester.tornado.api.TornadoBackend; | ||
import uk.ac.manchester.tornado.api.TornadoDeviceMap; | ||
import uk.ac.manchester.tornado.api.TornadoExecutionPlan; | ||
import uk.ac.manchester.tornado.api.WorkerGrid; | ||
import uk.ac.manchester.tornado.api.WorkerGrid1D; | ||
|
@@ -41,16 +44,18 @@ | |
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider; | ||
import uk.ac.manchester.tornado.api.types.arrays.IntArray; | ||
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; | ||
import uk.ac.manchester.tornado.unittests.common.TornadoVMMultiDeviceNotSupported; | ||
import uk.ac.manchester.tornado.unittests.common.TornadoVMPTXNotSupported; | ||
|
||
/** | ||
* <p> | ||
* How to run? | ||
* </p> | ||
* <code> | ||
* tornado-test -V uk.ac.manchester.tornado.unittests.prebuilt.PrebuiltTest | ||
* tornado-test -V uk.ac.manchester.tornado.unittests.prebuilt.PrebuiltTests | ||
* </code> | ||
*/ | ||
public class PrebuiltTest extends TornadoTestBase { | ||
public class PrebuiltTests extends TornadoTestBase { | ||
private static final String TORNADO_SDK = "TORNADO_SDK"; | ||
private static TornadoDevice defaultDevice; | ||
private static String FILE_PATH; | ||
|
@@ -122,7 +127,7 @@ public void testPrebuilt01() throws TornadoExecutionPlanException { | |
} | ||
|
||
@Test | ||
public void testPrebuilt01Multi() throws TornadoExecutionPlanException { | ||
public void testPrebuilt01MultiIterations() throws TornadoExecutionPlanException { | ||
|
||
final int numElements = 8; | ||
IntArray a = new IntArray(numElements); | ||
|
@@ -234,7 +239,7 @@ public void testPrebuilt02SPIRV() throws TornadoExecutionPlanException { | |
} | ||
|
||
/** | ||
* This test case verifies that the {@link PrebuiltTest#testPrebuilt03SPIRV} runs correctly though a | ||
* This test case verifies that the {@link PrebuiltTests#testPrebuilt03SPIRV} runs correctly though a | ||
* SPIR-V or OpenCL runtime if the device supports SPIR-V. | ||
* | ||
* <p>Expected outcome: - If the current backend type is PTX, the test should have | ||
|
@@ -342,4 +347,83 @@ public void testPrebuilt04SPIRVThroughOpenCLRuntime() throws TornadoExecutionPla | |
assertEquals(512, finalSum, 0.0f); | ||
|
||
} | ||
|
||
/** | ||
* This test is intended to be passed with multiple backends (e.g., OpenCL and PTX). | ||
* The PTX backend needs to be installed. Otherwise, an exception is thrown. | ||
* | ||
* <p> How to run? | ||
* <code> | ||
* tornado-test -V uk.ac.manchester.tornado.unittests.prebuilt.PrebuiltTests#testPrebuiltMutiBackend | ||
* </code> | ||
* </p> | ||
* | ||
* @throws TornadoExecutionPlanException | ||
*/ | ||
@Test | ||
public void testPrebuiltMutiBackend() throws TornadoExecutionPlanException { | ||
|
||
final int numElements = 8; | ||
IntArray a = new IntArray(numElements); | ||
IntArray b = new IntArray(numElements); | ||
IntArray c = new IntArray(numElements); | ||
|
||
a.init(1); | ||
b.init(2); | ||
|
||
// Force to use the PTX Backend. | ||
FILE_PATH += "add.ptx"; | ||
|
||
TornadoDeviceMap tornadoDeviceMap = TornadoExecutionPlan.getTornadoDeviceMap(); | ||
if (tornadoDeviceMap.getNumBackends() < 2) { | ||
throw new TornadoVMMultiDeviceNotSupported("Test designed to run with multiple backends"); | ||
} | ||
|
||
List<TornadoBackend> ptxBackend = tornadoDeviceMap.getBackendsWithPredicate(backend -> backend.getBackendType() == TornadoVMBackendType.PTX); | ||
|
||
if (ptxBackend == null || ptxBackend.isEmpty()) { | ||
throw new TornadoVMPTXNotSupported("Test designed to run with multiple backends, including a PTX backend"); | ||
} | ||
Comment on lines
+377
to
+386
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not very keen to have the check in the test body. Shall we change to something like the following condition assertAvailableDrivers(2); that exists in The condition should assert if (OpenCL, PTX) or (SPIR-V, PTX) backends have been built, otherwise result in UNSUPPORTED result. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to check that at least two drivers are installed, and one of them must be the PTX to pass this particular test. The combinations are also valid for any two backend combination in which the selected device is not the default device. I think the throw exception here is fine. |
||
|
||
// Access the first device within the NVIDIA PTX Backend | ||
TornadoDevice device = ptxBackend.getFirst().getDevice(0); | ||
|
||
// Define accessors for each parameter | ||
AccessorParameters accessorParameters = new AccessorParameters(3); | ||
accessorParameters.set(0, a, Access.READ_WRITE); | ||
accessorParameters.set(1, b, Access.READ_WRITE); | ||
accessorParameters.set(2, c, Access.WRITE_ONLY); | ||
|
||
// Define the Task-Graph | ||
TaskGraph taskGraph = new TaskGraph("s0") // | ||
.transferToDevice(DataTransferMode.EVERY_EXECUTION, a, b) // | ||
.prebuiltTask("t0", //task name | ||
"add", // name of the low-level kernel to invoke | ||
FILE_PATH, // file name | ||
accessorParameters) // accessors | ||
.transferToHost(DataTransferMode.EVERY_EXECUTION, c); | ||
|
||
// When using the prebuilt API, we need to define the WorkerGrid, otherwise it will launch 1 thread | ||
// on the target device | ||
WorkerGrid workerGrid = new WorkerGrid1D(numElements); | ||
GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid); | ||
|
||
// Launch the application on the target device | ||
try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot())) { | ||
|
||
executionPlan.withGridScheduler(gridScheduler) // | ||
.withDevice(device) // | ||
.execute(); | ||
|
||
// Run task multiple times | ||
for (int i = 0; i < 10; i++) { | ||
executionPlan.execute(); | ||
for (int j = 0; j < c.getSize(); j++) { | ||
assertEquals(a.get(j) + b.get(j), c.get(j)); | ||
} | ||
IntStream.range(0, numElements).forEach(k -> a.set(k, c.get(k))); | ||
} | ||
} | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be
1.0.8-dev
, right?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you are right. I will fix this