Skip to content

Commit

Permalink
initial device context at args.device
Browse files Browse the repository at this point in the history
if User A use gpu6,7 and User B use gpu0. Then UserB kills all process at gpu0 but User A's training also stopped. because `torch.cuda.empty_cache()` default initialize at rank0. 
Reference: pytorch/pytorch#25752 (comment)
  • Loading branch information
TITC authored May 21, 2022
1 parent 0938894 commit dbf75d9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pix2tex/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def gpu_memory_check(model, args):
except RuntimeError:
raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
model.zero_grad()
torch.cuda.empty_cache()
with torch.cuda.device(args.device):torch.cuda.empty_cache()
del im, seq


Expand Down

0 comments on commit dbf75d9

Please sign in to comment.