Addressed comments
ajakovljevicTT committed Feb 7, 2025
1 parent 555d06a commit 20ff4d5
Showing 6 changed files with 81 additions and 46 deletions.
23 changes: 14 additions & 9 deletions inc/common/pjrt_implementation/buffer_instance.h
@@ -23,10 +23,13 @@ class DeviceInstance;

class BufferInstance {
public:
BufferInstance(DeviceInstance &device,
std::unique_ptr<tt::runtime::Tensor> &tensor,
std::vector<std::uint32_t> shape,
std::vector<std::uint32_t> stride,
BufferInstance(DeviceInstance &device, tt::runtime::Tensor &tensor,
const std::vector<std::uint32_t> &shape,
const std::vector<std::uint32_t> &stride);

BufferInstance(DeviceInstance &device, tt::runtime::Tensor &tensor,
const std::vector<std::uint32_t> &shape,
const std::vector<std::uint32_t> &stride,
std::shared_ptr<void> host_buffer_ptr);
BufferInstance(DeviceInstance &device);
~BufferInstance();
@@ -46,7 +49,7 @@ class BufferInstance {
// the hook to get an unsafe pointer (avoids a copy).
return false;
}
tt::runtime::Tensor tensor() { return *tensor_; }
const tt::runtime::Tensor &getTensor() const { return tensor_; }

PJRT_Error *GetMemoryLayout(PJRT_Buffer_GetMemoryLayout_Args *args);
// Gets the required host size in bytes to copy to host.
@@ -76,7 +79,7 @@ class BufferInstance {
// API elements that must have the same lifetime as BufferInstance.
std::vector<int64_t> dims_;
std::vector<std::uint32_t> stride_;
std::unique_ptr<tt::runtime::Tensor> tensor_;
tt::runtime::Tensor tensor_;

std::vector<int64_t> minor_to_major_;
std::vector<int64_t> tile_dims_;
@@ -86,11 +89,13 @@
std::optional<PJRT_Buffer_Type> DataType;

// OnReady event - currently not used.
std::shared_ptr<EventInstance> on_ready_event_;
EventInstance *on_ready_event_;

// Pointer to the host memory used to create this buffer; if the buffer is created
// on device, the value of this pointer is nullptr.
std::shared_ptr<void> host_buffer_ptr_;
// on device, the value of this pointer is nullptr. It is necessary to keep
// track of this memory since the runtime will not clean it up, and we need
// to pass the shared pointer to the runtime.
std::shared_ptr<void> host_buffer_ptr_ = nullptr;
};

} // namespace tt::pjrt
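The header change above replaces the `std::unique_ptr<tt::runtime::Tensor>` member with a tensor held by value and adds `host_buffer_ptr_` so the backing host allocation cannot be freed while the buffer still references it. Below is a minimal sketch of that keep-alive pattern; `TensorHandle` and `Buffer` are illustrative stand-ins, not the real types:

```cpp
#include <memory>
#include <utility>

// Illustrative stand-in for tt::runtime::Tensor: a small, copyable handle
// that references memory it does not own.
struct TensorHandle {
  void *data = nullptr;
};

class Buffer {
public:
  // Device-resident buffer: there is no host allocation to keep alive.
  explicit Buffer(TensorHandle tensor) : tensor_(tensor) {}

  // Host-backed buffer: the shared_ptr travels with the Buffer, so the host
  // allocation lives at least as long as this object (and the runtime, if
  // it also holds a reference).
  Buffer(TensorHandle tensor, std::shared_ptr<void> host_storage)
      : tensor_(tensor), host_storage_(std::move(host_storage)) {}

  const TensorHandle &tensor() const { return tensor_; }

private:
  TensorHandle tensor_;                 // held by value; no unique_ptr needed
  std::shared_ptr<void> host_storage_;  // nullptr for device-only buffers
};
```

Copying the `shared_ptr` rather than a raw pointer is the whole point: whichever side drops its reference last (plugin or runtime) triggers the deallocation.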
18 changes: 11 additions & 7 deletions inc/common/pjrt_implementation/device_instance.h
@@ -69,13 +69,17 @@ class DeviceInstance {
private:
tt_pjrt_status OpenDevice();

static size_t getTensorSize(const std::vector<std::uint32_t> &shape, size_t element_size);

BufferInstance *MakeDeviceBuffer(const void *data_ptr,
std::vector<std::uint32_t> &shape,
std::vector<std::uint32_t> &strides,
size_t element_size,
tt::target::DataType element_type);
static size_t getTensorSize(const std::vector<std::uint32_t> &shape,
size_t element_size);

// Creates a buffer instance from a host data pointer by copying the data
// into newly allocated memory. This is necessary because we have no
// ownership of the passed pointer, and it might be deallocated before the
// buffer is used.
std::unique_ptr<BufferInstance>
MakeDeviceBuffer(const void *data_ptr, std::vector<std::uint32_t> &shape,
std::vector<std::uint32_t> &strides, size_t element_size,
tt::target::DataType element_type);

ClientInstance &client_;
uint64_t last_transfer_timepoint_ = 0;
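The comment on `MakeDeviceBuffer` describes a copy-on-create contract: the plugin never owns the caller's pointer, so it must duplicate the bytes before the call returns. A self-contained sketch of that contract, with hypothetical `OwnedBuffer`/`makeDeviceBuffer` names:

```cpp
#include <cstddef>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in mirroring MakeDeviceBuffer's contract: the callee
// copies, because the caller's pointer may be freed right after the call.
struct OwnedBuffer {
  std::shared_ptr<void> storage;  // owned copy of the caller's bytes
  std::size_t size_bytes;
};

OwnedBuffer makeDeviceBuffer(const void *data, std::size_t size_bytes) {
  // Type-erased shared_ptr<void> with a matching typed deleter.
  std::shared_ptr<void> copy(new std::byte[size_bytes], [](void *ptr) {
    delete[] static_cast<std::byte *>(ptr);
  });
  std::memcpy(copy.get(), data, size_bytes);
  return OwnedBuffer{std::move(copy), size_bytes};
}

OwnedBuffer upload() {
  std::vector<float> host(1024, 1.0f);
  return makeDeviceBuffer(host.data(), host.size() * sizeof(float));
}  // `host` is destroyed here; the buffer survives because it owns a copy.
```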
31 changes: 21 additions & 10 deletions src/common/pjrt_implementation/buffer_instance.cc
@@ -19,13 +19,11 @@ int BufferInstance::id_counter_ = 0;
BufferInstance::~BufferInstance() = default;

BufferInstance::BufferInstance(DeviceInstance &device,
std::unique_ptr<tt::runtime::Tensor> &tensor,
std::vector<std::uint32_t> shape,
std::vector<std::uint32_t> stride,
std::shared_ptr<void> host_buffer_ptr)
: device_(device), host_buffer_ptr_(host_buffer_ptr) {
tt::runtime::Tensor &tensor,
const std::vector<std::uint32_t> &shape,
const std::vector<std::uint32_t> &stride)
: device_(device), tensor_(tensor) {
DLOG_F(LOG_DEBUG, "BufferInstance::BufferInstance");
tensor_ = std::move(tensor);
dims_.resize(shape.size());
for (int i = 0; i < shape.size(); i++) {
dims_[i] = shape[i];
@@ -34,6 +32,15 @@ BufferInstance::BufferInstance(DeviceInstance &device,
unique_id_ = id_counter_++;
}

BufferInstance::BufferInstance(DeviceInstance &device,
tt::runtime::Tensor &tensor,
const std::vector<std::uint32_t> &shape,
const std::vector<std::uint32_t> &stride,
std::shared_ptr<void> host_buffer_ptr)
: BufferInstance(device, tensor, shape, stride) {
host_buffer_ptr_ = host_buffer_ptr;
}

void BufferInstance::ComputeLayout() {
DLOG_F(LOG_DEBUG, "BufferInstance::ComputeLayout");
}
@@ -133,8 +140,12 @@ void BufferInstance::BindApi(PJRT_Api *api) {
+[](PJRT_Buffer_ReadyEvent_Args *args) -> PJRT_Error * {
DLOG_F(LOG_DEBUG, "BufferInstance::PJRT_Buffer_ReadyEvent");
BufferInstance *buffer = BufferInstance::Unwrap(args->buffer);
buffer->on_ready_event_ = std::make_shared<EventInstance>();
args->event = *buffer->on_ready_event_;
std::unique_ptr<EventInstance> onReadyEvent =
std::make_unique<EventInstance>();
buffer->on_ready_event_ = onReadyEvent.get();
// Release ownership to the PJRT API caller, since the caller is
// responsible for calling PJRT_Event_Destroy on the event.
args->event = *onReadyEvent.release();
return nullptr;
};
// TODO: Rework the API to be Aliases(b1, b2) to let the plugin explicitly
@@ -209,7 +220,7 @@ tt_pjrt_status BufferInstance::CopyToHost(void *dst, size_t dst_size,
};

DLOG_F(INFO, "Copy to host id: %d", unique_id());
tt::runtime::memcpy(dst, tensor());
tt::runtime::memcpy(dst, getTensor());

EventInstance *copy_done_event = new EventInstance();
copy_done_event->OnReady(copy_done_callback, nullptr);
@@ -220,7 +231,7 @@

PJRT_Buffer_Type BufferInstance::getRuntimeType() {
DLOG_F(LOG_DEBUG, "BufferInstance::element_type");
tt::target::DataType Type = tt::runtime::getTensorDataType(tensor());
tt::target::DataType Type = tt::runtime::getTensorDataType(getTensor());
return tt::pjrt::utils::convertElementTypeToBufferType(Type);
}

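The reworked `PJRT_Buffer_ReadyEvent` handler allocates the event in a `std::unique_ptr`, stores a raw observer in `on_ready_event_`, and then releases ownership to the API caller. A reduced sketch of that release pattern, assuming a hypothetical `Event`/`event_destroy` pair in place of the PJRT types:

```cpp
#include <memory>

// Hypothetical event type; the API caller destroys events explicitly.
struct Event {};

// Stand-in for PJRT_Event_Destroy.
void event_destroy(Event *event) { delete event; }

Event *make_ready_event(Event **observer_slot) {
  auto event = std::make_unique<Event>();
  // Keep a non-owning observer (like on_ready_event_) before releasing.
  *observer_slot = event.get();
  // Hand ownership to the caller, who must call event_destroy() later.
  return event.release();
}

int main() {
  Event *observed = nullptr;
  Event *owned = make_ready_event(&observed);
  event_destroy(owned);  // caller's side of the contract; `observed` dangles
}
```

Note that the raw observer dangles once the caller destroys the event; that is tolerable here only because, per the header comment, `on_ready_event_` is currently unused.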
35 changes: 20 additions & 15 deletions src/common/pjrt_implementation/device_instance.cc
@@ -8,6 +8,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// https://llvm.org/LICENSE.txt

#include <numeric>

#include "common/pjrt_implementation/device_instance.h"

#include "common/pjrt_implementation/buffer_instance.h"
@@ -80,38 +82,41 @@ tt_pjrt_status DeviceInstance::HostBufferToDevice(
shape.push_back(dims[i]);
strides.push_back(byte_strides[i] / element_size);
}
BufferInstance *buffer_instance =
std::unique_ptr<BufferInstance> buffer_instance =
MakeDeviceBuffer(data, shape, strides, element_size, element_type);
DLOG_F(INFO, "Buffer created with id: %d", buffer_instance->unique_id());
buffer_instance->setType(type);
*out_buffer = buffer_instance;
*out_buffer = buffer_instance.release();
EventInstance *event_instance = new EventInstance();
*out_done_with_host_buffer_event = event_instance;
return tt_pjrt_status::kSuccess;
}

size_t DeviceInstance::getTensorSize(const std::vector<std::uint32_t> &shape,
size_t element_size) {
size_t size = 1;
for (auto dim : shape) {
size *= dim;
}
return size * element_size;
size_t element_size) {
std::uint32_t elementsCount = std::accumulate(
shape.begin(), shape.end(), 1, std::multiplies<std::uint32_t>());

return static_cast<size_t>(elementsCount) * element_size;
}

BufferInstance *DeviceInstance::MakeDeviceBuffer(
std::unique_ptr<BufferInstance> DeviceInstance::MakeDeviceBuffer(
const void *data, std::vector<std::uint32_t> &shape,
std::vector<std::uint32_t> &strides, size_t element_size,
tt::target::DataType element_type) {
size_t tensor_size = getTensorSize(shape, element_size);
std::shared_ptr<void> new_memory(new char[tensor_size], [](void *ptr) {
delete[] static_cast<char *>(ptr);

std::shared_ptr<void> new_memory(new std::byte[tensor_size], [](void *ptr) {
delete[] static_cast<std::byte *>(ptr);
});

std::memcpy(new_memory.get(), data, tensor_size);
std::unique_ptr<tt::runtime::Tensor> device_tensor =
std::make_unique<tt::runtime::Tensor>(tt::runtime::createTensor(
new_memory, shape, strides, element_size, element_type));
return new BufferInstance(*this, device_tensor, shape, strides, new_memory);

tt::runtime::Tensor device_tensor = tt::runtime::createTensor(
new_memory, shape, strides, element_size, element_type);

return std::make_unique<BufferInstance>(*this, device_tensor, shape, strides,
new_memory);
}

} // namespace tt::pjrt
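The rewritten `getTensorSize` folds the shape into an element count with `std::accumulate` and scales by the element size; an empty shape (a scalar) correctly yields one element. A standalone equivalent, using an unsigned initial value to keep the fold in `std::uint32_t`:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

std::size_t tensorSizeBytes(const std::vector<std::uint32_t> &shape,
                            std::size_t element_size) {
  // Multiply all dimensions together; an empty range returns the initial
  // value 1u, i.e. a scalar occupies exactly one element.
  std::uint32_t elements = std::accumulate(
      shape.begin(), shape.end(), 1u, std::multiplies<std::uint32_t>());
  return static_cast<std::size_t>(elements) * element_size;
}

int main() {
  assert(tensorSizeBytes({2, 3, 4}, sizeof(float)) == 96);  // 24 elements * 4
  assert(tensorSizeBytes({}, sizeof(float)) == 4);          // scalar
}
```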
8 changes: 3 additions & 5 deletions src/common/pjrt_implementation/loaded_executable_instance.cc
@@ -97,7 +97,7 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
for (size_t i = 0; i < args->num_args; ++i) {
BufferInstance *buffer =
BufferInstance::Unwrap(args->argument_lists[dev_index][i]);
rt_inputs.emplace_back(buffer->tensor());
rt_inputs.emplace_back(buffer->getTensor());
int64_t buffer_device_id =
buffer->device().device_description()->getDeviceId();
device_ids.insert(chip_ids[buffer_device_id]);
@@ -129,11 +129,9 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) {
// PJRT expects an empty shape for scalars.
std::vector<std::uint32_t> output_shape =
is_scalar ? std::vector<std::uint32_t>() : output_specs[i].shape;
std::unique_ptr<tt::runtime::Tensor> tensor_ptr =
std::make_unique<tt::runtime::Tensor>(rt_outputs[i]);
auto result_buffer = std::make_unique<BufferInstance>(
*this->addressable_devices_[dev_index], tensor_ptr, output_shape,
output_specs[i].stride, nullptr);
*this->addressable_devices_[dev_index], rt_outputs[i], output_shape,
output_specs[i].stride);
result_buffer->setType(tt::pjrt::utils::convertElementTypeToBufferType(
output_specs[i].dataType));
DLOG_F(INFO, "Runtime output id: %d", result_buffer->unique_id());
12 changes: 12 additions & 0 deletions tests/jax/ops/test_convert.py
@@ -25,6 +25,18 @@ def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike):
Extracted here in order not to pollute the test function.
"""
# ---------- Atol comparison failed ----------
# When no conversion is required, a no-op MLIR graph is created.
# However, due to input tensor ownership issues, the output tensor
# returned by the MLIR runtime will reference the same data as the input.
# If the input tensor is deallocated, the output tensor will lose access
# to valid data and may contain garbage.
# See issue #244 for more details.
if from_dtype == to_dtype or (from_dtype == jnp.uint32 and to_dtype == jnp.uint64):
pytest.xfail(
runtime_fail(
"Atol comparison failed. Calculated: atol=65535.0. Required: atol=0.16."
)
)

if from_dtype == jnp.uint32 and to_dtype in [jnp.uint16, jnp.int16]:
pytest.xfail(
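The xfail comment above describes an aliasing bug: a no-op graph returns an output tensor that merely views the input's storage, so deallocating the input leaves the output reading freed memory. A deliberately broken C++ sketch of that failure mode (all names illustrative; the real report is issue #244):

```cpp
#include <iostream>

// Minimal model of the hazard: a tensor that only views storage it does
// not own, which is effectively what the no-op conversion path returns.
struct View {
  const float *data;
};

View noop_convert(const View &input) {
  return View{input.data};  // no copy: output aliases input storage
}

int main() {
  float *storage = new float[4]{1.0f, 2.0f, 3.0f, 4.0f};
  View input{storage};
  View output = noop_convert(input);
  delete[] storage;                     // input's backing memory is freed...
  std::cout << output.data[0] << "\n";  // ...so this read is UB (garbage)
}
```

Shared ownership of the backing storage, as the `host_buffer_ptr_` change on the C++ side establishes, is the usual way out of this failure mode.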
