diff --git a/patrickstar/core/chunk_data.py b/patrickstar/core/chunk_data.py index f0e594f2a..9b0a452a8 100644 --- a/patrickstar/core/chunk_data.py +++ b/patrickstar/core/chunk_data.py @@ -31,7 +31,6 @@ import torch from patrickstar.core.memtracer import RuntimeMemTracer -from patrickstar.manager.cuda_context import CUDAContext from patrickstar.profiler import profiler from patrickstar.utils import logger, getsizeof import patrickstar.utils.global_timer as global_timer @@ -245,7 +244,6 @@ def move(self, target_device: torch.device): f"used mem {self.memory_tracer.used_chunk_mem(target_device.type) / 1e6} MB" ) - cuda_ctx = CUDAContext() # TODO(jiaruifang) asyc copy. if target_device.type == "cpu": pinned_payload_cpu = torch.empty( @@ -254,13 +252,11 @@ def move(self, target_device: torch.device): device="cpu:0", pin_memory=True, ) - with torch.cuda.stream(cuda_ctx.copy_stream): - pinned_payload_cpu.copy_(self.payload) + pinned_payload_cpu.copy_(self.payload) self.payload = pinned_payload_cpu elif target_device.type == "cuda": self.payload = self.payload.pin_memory() - with torch.cuda.stream(cuda_ctx.copy_stream): - self.payload = self.payload.to(target_device) + self.payload = self.payload.to(target_device) self.memory_tracer.delete(src_device.type, self.get_payload_space()) self.memory_tracer.add(target_device.type, self.get_payload_space())