diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 4dc9510d14181..e001a0823562f 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -217,7 +217,7 @@ class HeterComm { #endif } - void create_storage(int start_index, int end_index, int keylen, int vallen); + void create_storage(int start_index, int end_index, size_t keylen, size_t vallen); void destroy_storage(int start_index, int end_index); void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key, GradType* src_val); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index edd82a0fd3d61..61f877b5aaea4 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -128,8 +128,8 @@ void HeterComm::memory_copy( template void HeterComm::create_storage(int start_index, int end_index, - int keylen, - int vallen) { + size_t keylen, + size_t vallen) { #if defined(PADDLE_WITH_CUDA) auto& allocator = allocators_[start_index]; auto& nodes = path_[start_index][end_index].nodes_;