Skip to content

Commit

Permalink
hotfix a bug caused by copy_ instream
Browse files Browse the repository at this point in the history
  • Loading branch information
feifeibear committed Nov 24, 2021
1 parent 2e58787 commit 2fb2cdf
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions patrickstar/core/chunk_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import torch

from patrickstar.core.memtracer import RuntimeMemTracer
from patrickstar.manager.cuda_context import CUDAContext
from patrickstar.profiler import profiler
from patrickstar.utils import logger, getsizeof
import patrickstar.utils.global_timer as global_timer
Expand Down Expand Up @@ -245,7 +244,6 @@ def move(self, target_device: torch.device):
f"used mem {self.memory_tracer.used_chunk_mem(target_device.type) / 1e6} MB"
)

cuda_ctx = CUDAContext()
# TODO(jiaruifang) asyc copy.
if target_device.type == "cpu":
pinned_payload_cpu = torch.empty(
Expand All @@ -254,13 +252,11 @@ def move(self, target_device: torch.device):
device="cpu:0",
pin_memory=True,
)
with torch.cuda.stream(cuda_ctx.copy_stream):
pinned_payload_cpu.copy_(self.payload)
pinned_payload_cpu.copy_(self.payload)
self.payload = pinned_payload_cpu
elif target_device.type == "cuda":
self.payload = self.payload.pin_memory()
with torch.cuda.stream(cuda_ctx.copy_stream):
self.payload = self.payload.to(target_device)
self.payload = self.payload.to(target_device)

self.memory_tracer.delete(src_device.type, self.get_payload_space())
self.memory_tracer.add(target_device.type, self.get_payload_space())
Expand Down

0 comments on commit 2fb2cdf

Please sign in to comment.