Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions quadrants/python/dlpack_funcs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,16 @@ pybind11::capsule field_to_dlpack(Program *program, SNode *snode, int element_nd
"'to_dlpack'.")
}

int field_in_tree_offset = program->get_field_in_tree_offset(tree_id, snode);
// A SNode tree packs every field of a program stage, so in-tree byte offsets routinely exceed 2^31; the offset must
// stay 64-bit end to end (DLTensor::byte_offset is uint64_t).
std::size_t field_in_tree_offset = program->get_field_in_tree_offset(tree_id, snode);

void *raw_ptr = nullptr;
DLDeviceType device_type = DLDeviceType::kDLCPU;
std::tie(raw_ptr, device_type) = get_raw_ptr(arch, program, tree_device_ptr);

int byte_offset = 0;
if (field_in_tree_offset >= 0) {
uint64_t byte_offset = 0;
if (field_in_tree_offset > 0) {
if (torch_supports_byte_offset()) {
byte_offset = field_in_tree_offset;
} else {
Expand Down
2 changes: 1 addition & 1 deletion quadrants/rhi/common/unified_allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void *UnifiedAllocator::allocate(std::size_t size, std::size_t alignment, bool e
void *ptr = HostMemoryPool::get_instance().allocate_raw_memory(allocation_size);
chunk.data = ptr;
chunk.head = (void *)((std::size_t)chunk.data + size);
chunk.tail = (void *)((std::size_t)chunk.head + allocation_size);
chunk.tail = (void *)((std::size_t)chunk.data + allocation_size);
chunk.is_exclusive = exclusive;

QD_ASSERT(chunk.data != nullptr);
Expand Down
3 changes: 2 additions & 1 deletion quadrants/rhi/llvm/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ uint64_t *CachingAllocator::allocate(LlvmDevice *device, const LlvmDevice::LlvmR
ptr_map_.insert(std::make_pair(remaining_head, remaining_sz));
}
ret = reinterpret_cast<uint64_t *>(it_blk->second);
mem_blocks_.erase(it_blk);
// Erase from ptr_map_ first: mem_blocks_.erase(it_blk) invalidates it_blk.
ptr_map_.erase(it_blk->second);
mem_blocks_.erase(it_blk);

} else {
ret = reinterpret_cast<uint64_t *>(device->allocate_llvm_runtime_memory_jit(params));
Expand Down
16 changes: 16 additions & 0 deletions tests/cpp/rhi/common/host_memory_pool_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,20 @@ TEST(HostMemoryPool, AllocateMemory) {
HostMemoryPoolTestHelper::setDefaultAllocatorSize(oldAllocatorSize);
}

TEST(HostMemoryPool, ChunkTailMatchesAllocationSize) {
  // Run with a 100KB default allocator size; the previous setting is put back before the test returns.
  const auto saved_allocator_size = HostMemoryPoolTestHelper::getDefaultAllocatorSize();
  HostMemoryPoolTestHelper::setDefaultAllocatorSize(102400);  // 100KB

  HostMemoryPool pool;

  // Request 60KB, then 80KB. After the first request only 40KB of a 100KB chunk is left, so the second request is
  // expected to be served from a fresh chunk rather than from memory past the real end of the first one.
  void *first_block = pool.allocate(61440, 16);
  void *second_block = pool.allocate(81920, 16);

  // If the second block began exactly where the first one ends, it would have been carved out of the first chunk.
  const std::size_t first_block_end = reinterpret_cast<std::size_t>(first_block) + 61440;
  const std::size_t second_block_begin = reinterpret_cast<std::size_t>(second_block);
  EXPECT_NE(second_block_begin, first_block_end);

  HostMemoryPoolTestHelper::setDefaultAllocatorSize(saved_allocator_size);
}

} // namespace quadrants::lang
22 changes: 22 additions & 0 deletions tests/python/quadrants/lang/test_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,25 @@ def test_dlpack_field_memory_allocation_before_to_dlpack():
assert (
second_time_tc == second_time.to_torch(device=second_time_tc.device)
).all(), f"{second_time_tc} != {second_time.to_torch(device=second_time_tc.device)}"


@test_utils.test(arch=[qd.cpu])
@pytest.mark.run_in_serial
@pytest.mark.slow
def test_dlpack_field_offset_past_2gib():
    """Export a field whose in-tree byte offset lies beyond 2^31 and check that its contents survive.

    Fields share one SNode tree, so a large field declared first pushes the second field's in-tree byte offset
    past 2^31. The DLPack export must carry that offset without truncation instead of aliasing the tree base.
    """
    # Declaration order matters: the large padding field has to come first so the small one lands after it.
    padding = qd.field(dtype=qd.f32, shape=(560_000_000,))
    target = qd.field(dtype=qd.f32, shape=(1000,))

    @qd.kernel
    def populate():
        for idx in padding:
            padding[idx] = 1.0
        for idx in target:
            target[idx] = qd.cast(idx, qd.f32)

    # The small field is filled with 0..999, while the padding is filled with a constant 1.0.
    expected = np.arange(1000, dtype=np.float32)

    populate()
    qd.sync()

    exported = torch.utils.dlpack.from_dlpack(target.to_dlpack())
    np.testing.assert_array_equal(exported.cpu().numpy(), expected)
Loading