zero out all the allocated shm buffer (#18)
zheng-ningxin authored Jul 9, 2024
1 parent c866c43 commit 9eb16a9
Showing 1 changed file with 6 additions and 1 deletion.
src/ths_op/flux_shm.cc: 6 additions & 1 deletion
@@ -53,8 +53,11 @@ nvshmem_create_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype)
   auto size = torch::elementSize(dtype) *
               std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
   FLUX_CHECK(size != 0);
+  void *ptr = nvshmem_malloc(size);
+  FLUX_CHECK(ptr != nullptr);
+  CUDA_CHECK(cudaMemset(ptr, 0, size));  // memset the allocated buffer
   return at::from_blob(
-      nvshmem_malloc(size), shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
+      ptr, shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
 }
 
 std::vector<torch::Tensor>
@@ -71,6 +74,7 @@ nvshmem_create_tensor_list(const std::vector<int64_t> &shape, c10::ScalarType dtype)
   std::vector<torch::Tensor> tensors;
   tensors.reserve(local_world_size);
   void *ptr = nvshmem_malloc(size);
+  CUDA_CHECK(cudaMemset(ptr, 0, size));  // memset the allocated buffer
   FLUX_CHECK(ptr != nullptr);
   int rank_offset = rank - local_rank;
   for (int i = 0; i < local_world_size; i++) {
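
For readability, here is roughly how nvshmem_create_tensor reads after this commit, reassembled from the first hunk above. The return type and the option_gpu tensor options are not visible in the diff context and are assumed here, so treat this as a sketch rather than the exact file contents.

    // Sketch of nvshmem_create_tensor after the change. The return type and
    // option_gpu are assumptions based on surrounding code not shown in the diff.
    torch::Tensor
    nvshmem_create_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype) {
      auto option_gpu = at::TensorOptions().dtype(dtype).device(at::kCUDA);  // assumed
      auto size = torch::elementSize(dtype) *
                  std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
      FLUX_CHECK(size != 0);
      // Allocate from the NVSHMEM symmetric heap and zero it right away,
      // so no rank ever observes stale contents from a previous allocation.
      void *ptr = nvshmem_malloc(size);
      FLUX_CHECK(ptr != nullptr);
      CUDA_CHECK(cudaMemset(ptr, 0, size));
      // Wrap the raw buffer in a tensor; nvshmem_free runs when it is released.
      return at::from_blob(
          ptr, shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
    }
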
@@ -107,6 +111,7 @@ cudaipc_create_tensor_list(
   FLUX_CHECK(size != 0);
   void *ptr = nullptr;
   CUDA_CHECK(cudaMalloc(&ptr, size));
+  CUDA_CHECK(cudaMemset(ptr, 0, size));  // memset the allocated buffer
   cudaIpcMemHandle_t handle;
   CUDA_CHECK(cudaIpcGetMemHandle(&handle, ptr));
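
The last hunk sits in the CUDA IPC path, where the buffer allocated here is exported to peer processes through an IPC memory handle. As a standalone illustration of that pattern, below is a minimal producer-side sketch using only the plain CUDA runtime API: allocate, zero, then export a handle. The handle transport (MPI, a pipe, etc.) and the consumer-side cudaIpcOpenMemHandle call are only hinted at in comments, and the CUDA_CHECK macro is a local stand-in for the project's own.

    // Minimal producer-side sketch of the allocate -> zero -> export-IPC-handle
    // pattern touched by the last hunk. Standalone CUDA; not Flux code.
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    // Local stand-in for the project's CUDA_CHECK macro.
    #define CUDA_CHECK(expr)                                              \
      do {                                                                \
        cudaError_t err_ = (expr);                                        \
        if (err_ != cudaSuccess) {                                        \
          std::fprintf(stderr, "%s failed: %s\n", #expr,                  \
                       cudaGetErrorString(err_));                         \
          std::exit(EXIT_FAILURE);                                        \
        }                                                                 \
      } while (0)

    int main() {
      const size_t size = 1 << 20;  // 1 MiB buffer to be shared across processes
      void *ptr = nullptr;
      CUDA_CHECK(cudaMalloc(&ptr, size));
      // Zero the buffer before exporting it, so peers never read stale data.
      CUDA_CHECK(cudaMemset(ptr, 0, size));

      // Export a handle that a peer process could map with
      // cudaIpcOpenMemHandle(&peer_ptr, handle, cudaIpcMemLazyEnablePeerAccess).
      cudaIpcMemHandle_t handle;
      CUDA_CHECK(cudaIpcGetMemHandle(&handle, ptr));
      // ... send `handle` to peer ranks over an out-of-band channel ...

      CUDA_CHECK(cudaFree(ptr));
      return 0;
    }
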

