Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] loopIndexInWrite implementation moved to its root interface #474

Merged
merged 4 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static uk.ac.manchester.tornado.runtime.common.TornadoOptions.DEVICE_AVAILABLE_MEMORY;

import java.util.ArrayList;
import java.util.List;

import uk.ac.manchester.tornado.api.TornadoDeviceContext;
import uk.ac.manchester.tornado.api.TornadoTargetDevice;
Expand All @@ -43,17 +44,14 @@
public abstract class TornadoBufferProvider {

protected final TornadoDeviceContext deviceContext;
protected final ArrayList<BufferContainer> freeBuffers;
protected final ArrayList<BufferContainer> usedBuffers;
protected final List<BufferContainer> freeBuffers;
protected final List<BufferContainer> usedBuffers;
protected long currentMemoryAvailable;

protected TornadoBufferProvider(TornadoDeviceContext deviceContext) {
this.deviceContext = deviceContext;
this.usedBuffers = new ArrayList<>();
this.freeBuffers = new ArrayList<>();

// There is no way of querying the available memory on the device.
// Instead, use a flag similar to -Xmx.
currentMemoryAvailable = TornadoOptions.DEVICE_AVAILABLE_MEMORY;
}

Expand Down Expand Up @@ -191,8 +189,15 @@ public synchronized void markBufferReleased(long buffer) {
}
}

public boolean checkBufferAvailability(int numBuffersRequired) {
return freeBuffers.size() >= numBuffersRequired;
/**
* Function that returns true if the there are, at least numBuffers available in the free list.
*
* @param numBuffers
* Number of free buffers.
* @return boolean.
*/
public boolean isNumFreeBuffersAvailable(int numBuffers) {
return freeBuffers.size() >= numBuffers;
}

public synchronized void resetBuffers() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,18 +470,6 @@ public TornadoInstalledCode installCode(SchedulableTask task) {
return loadPreCompiledBinaryForTask(task);
}

@Override
public boolean loopIndexInWrite(SchedulableTask task) {
if (task instanceof CompilableTask) {
final CompilableTask executable = (CompilableTask) task;
final ResolvedJavaMethod resolvedMethod = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(executable.getMethod());
final Sketch sketch = TornadoSketcher.lookup(resolvedMethod, task.meta().getBackendIndex(), task.meta().getDeviceIndex());
return sketch.getBatchWriteThreadIndex();
} else {
return false;
}
}

private XPUBuffer createArrayWrapper(Class<?> type, OCLDeviceContext device, long batchSize) {
XPUBuffer result = null;
if (type == float[].class) {
Expand Down Expand Up @@ -575,7 +563,7 @@ private XPUBuffer createDeviceBuffer(Class<?> type, Object object, OCLDeviceCont
@Override
public synchronized long allocateObjects(Object[] objects, long batchSize, DeviceBufferState[] states) {
TornadoBufferProvider bufferProvider = getDeviceContext().getBufferProvider();
if (!bufferProvider.checkBufferAvailability(objects.length)) {
if (!bufferProvider.isNumFreeBuffersAvailable(objects.length)) {
bufferProvider.resetBuffers();
}
long allocatedSpace = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ public String toString() {
@Override
public TornadoSchedulingStrategy getPreferredSchedule() {
switch (Objects.requireNonNull(device.getDeviceType())) {
case CL_DEVICE_TYPE_GPU, CL_DEVICE_TYPE_ACCELERATOR, CL_DEVICE_TYPE_CUSTOM, CL_DEVICE_TYPE_ALL -> {
case CL_DEVICE_TYPE_GPU, //
CL_DEVICE_TYPE_ACCELERATOR,//
CL_DEVICE_TYPE_CUSTOM,//
CL_DEVICE_TYPE_ALL -> {//
return TornadoSchedulingStrategy.PER_ACCELERATOR_ITERATION;
}
case CL_DEVICE_TYPE_CPU -> {
Expand Down Expand Up @@ -449,18 +452,6 @@ public void setAtomicRegion(XPUBuffer bufferAtomics) {

}

@Override
public boolean loopIndexInWrite(SchedulableTask task) {
if (task instanceof CompilableTask) {
final CompilableTask executable = (CompilableTask) task;
final ResolvedJavaMethod resolvedMethod = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(executable.getMethod());
final Sketch sketch = TornadoSketcher.lookup(resolvedMethod, task.meta().getBackendIndex(), task.meta().getDeviceIndex());
return sketch.getBatchWriteThreadIndex();
} else {
return false;
}
}

@Override
public int getAvailableProcessors() {
return ((VirtualOCLDevice) device).getAvailableProcessors();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ private XPUBuffer createDeviceBuffer(Class<?> type, Object object, long batchSiz
@Override
public synchronized long allocateObjects(Object[] objects, long batchSize, DeviceBufferState[] states) {
TornadoBufferProvider bufferProvider = getDeviceContext().getBufferProvider();
if (!bufferProvider.checkBufferAvailability(objects.length)) {
if (!bufferProvider.isNumFreeBuffersAvailable(objects.length)) {
bufferProvider.resetBuffers();
}
long allocatedSpace = 0;
Expand Down Expand Up @@ -674,18 +674,6 @@ public void setAtomicRegion(XPUBuffer bufferAtomics) {

}

@Override
public boolean loopIndexInWrite(SchedulableTask task) {
if (task instanceof CompilableTask) {
final CompilableTask executable = (CompilableTask) task;
final ResolvedJavaMethod resolvedMethod = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(executable.getMethod());
final Sketch sketch = TornadoSketcher.lookup(resolvedMethod, task.meta().getBackendIndex(), task.meta().getDeviceIndex());
return sketch.getBatchWriteThreadIndex();
} else {
return false;
}
}

@Override
public String toString() {
return STR."\{getPlatformName()} -- \{device.getDeviceName()}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,6 @@ public void setAtomicRegion(XPUBuffer bufferAtomics) {
throw new RuntimeException("Unsupported");
}

@Override
public boolean loopIndexInWrite(SchedulableTask task) {
if (task instanceof CompilableTask) {
final CompilableTask executable = (CompilableTask) task;
final ResolvedJavaMethod resolvedMethod = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(executable.getMethod());
final Sketch sketch = TornadoSketcher.lookup(resolvedMethod, task.meta().getBackendIndex(), task.meta().getDeviceIndex());
return sketch.getBatchWriteThreadIndex();
} else {
return false;
}
}

private XPUBuffer createArrayWrapper(Class<?> klass, SPIRVDeviceContext device, long batchSize) {
if (klass == int[].class) {
return new SPIRVIntArrayWrapper(device, batchSize);
Expand Down Expand Up @@ -332,7 +320,7 @@ private XPUBuffer createDeviceBuffer(Class<?> type, Object object, SPIRVDeviceCo
@Override
public synchronized long allocateObjects(Object[] objects, long batchSize, DeviceBufferState[] states) {
TornadoBufferProvider bufferProvider = getDeviceContext().getBufferProvider();
if (!bufferProvider.checkBufferAvailability(objects.length)) {
if (!bufferProvider.isNumFreeBuffersAvailable(objects.length)) {
bufferProvider.resetBuffers();
}
long allocatedSpace = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

import jdk.vm.ci.meta.ResolvedJavaMethod;
import uk.ac.manchester.tornado.api.TornadoDeviceContext;
import uk.ac.manchester.tornado.api.TornadoTargetDevice;
import uk.ac.manchester.tornado.api.common.Event;
Expand All @@ -42,9 +41,6 @@
import uk.ac.manchester.tornado.runtime.common.TornadoSchedulingStrategy;
import uk.ac.manchester.tornado.runtime.common.TornadoXPUDevice;
import uk.ac.manchester.tornado.runtime.common.XPUDeviceBufferState;
import uk.ac.manchester.tornado.runtime.sketcher.Sketch;
import uk.ac.manchester.tornado.runtime.sketcher.TornadoSketcher;
import uk.ac.manchester.tornado.runtime.tasks.CompilableTask;

public class JVMMapping implements TornadoXPUDevice {

Expand Down Expand Up @@ -249,18 +245,6 @@ public void setAtomicRegion(XPUBuffer bufferAtomics) {

}

@Override
public boolean loopIndexInWrite(SchedulableTask task) {
if (task instanceof CompilableTask) {
final CompilableTask executable = (CompilableTask) task;
final ResolvedJavaMethod resolvedMethod = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(executable.getMethod());
final Sketch sketch = TornadoSketcher.lookup(resolvedMethod, executable.meta().getBackendIndex(), executable.meta().getDeviceIndex());
return sketch.getBatchWriteThreadIndex();
} else {
return false;
}
}

@Override
public long getMaxAllocMemory() {
return Runtime.getRuntime().maxMemory();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@
*/
package uk.ac.manchester.tornado.runtime.common;

import jdk.vm.ci.meta.ResolvedJavaMethod;
import uk.ac.manchester.tornado.api.common.SchedulableTask;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.memory.XPUBuffer;
import uk.ac.manchester.tornado.runtime.TornadoCoreRuntime;
import uk.ac.manchester.tornado.runtime.sketcher.Sketch;
import uk.ac.manchester.tornado.runtime.sketcher.TornadoSketcher;
import uk.ac.manchester.tornado.runtime.tasks.CompilableTask;

/**
* A Tornado accelerator device extending the {@link TornadoDevice} interface.
Expand Down Expand Up @@ -177,6 +182,14 @@ public interface TornadoXPUDevice extends TornadoDevice {
* @param task
* @return
*/
boolean loopIndexInWrite(SchedulableTask task);
default boolean loopIndexInWrite(SchedulableTask task) {
if (task instanceof CompilableTask executable) {
final ResolvedJavaMethod resolvedMethod = TornadoCoreRuntime.getTornadoRuntime().resolveMethod(executable.getMethod());
final Sketch sketch = TornadoSketcher.lookup(resolvedMethod, task.meta().getBackendIndex(), task.meta().getDeviceIndex());
return sketch.getBatchWriteThreadIndex();
} else {
return false;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,7 @@ public void testBatchThreadIndex() throws TornadoExecutionPlanException {

private long checkMaxHeapAllocationOnDevice(int size, MemoryUnit memoryUnit) throws UnsupportedConfigurationException {
long maxAllocMemory = getTornadoRuntime().getDefaultDevice().getDeviceContext().getMemoryManager().getHeapSize();

long memThreshold = switch (memoryUnit) {
case GB -> (long) size * 1024 * 1024 * 1024;
case MB -> (long) size * 1024 * 1024;
Expand Down