Skip to content

Commit

Permalink
Optional gpu_devices setting
Browse files Browse the repository at this point in the history
  • Loading branch information
lukas-blecher committed May 19, 2022
1 parent b14d91c commit 426e594
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions pix2tex/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def data_parallel(module, inputs, device_ids, output_device=None, **kwargs):
if output_device is None:
output_device = device_ids[0]
replicas = nn.parallel.replicate(module, device_ids)
inputs = nn.parallel.scatter(inputs, device_ids) #Slices tensors into approximately equal chunks and distributes them across given GPUs.
kwargs = nn.parallel.scatter(kwargs, device_ids) # Duplicates references to objects that are not tensors.
inputs = nn.parallel.scatter(inputs, device_ids) # Slices tensors into approximately equal chunks and distributes them across given GPUs.
kwargs = nn.parallel.scatter(kwargs, device_ids) # Duplicates references to objects that are not tensors.
replicas = replicas[:len(inputs)]
kwargs = kwargs[:len(inputs)]
outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
Expand All @@ -41,7 +41,7 @@ def gpu_memory_check(model, args):
loss = data_parallel(model.decoder, inputs=seq, device_ids=args.gpu_devices, context=encoded)
loss.sum().backward()
except RuntimeError:
raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize."%(batchsize, args.max_height, args.max_width))
raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
model.zero_grad()
torch.cuda.empty_cache()
del im, seq
Expand Down Expand Up @@ -74,7 +74,7 @@ def save_models(e, step=0):
microbatch = args.get('micro_batchsize', -1)
if microbatch == -1:
microbatch = args.batchsize

try:
for e in range(args.epoch, args.epochs):
args.epoch = e
Expand All @@ -88,9 +88,9 @@ def save_models(e, step=0):
# encoded = encoder(im[j:j+microbatch].to(device))
encoded = data_parallel(encoder, inputs=im[j:j+microbatch].to(device), device_ids=args.gpu_devices)
# loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
loss = data_parallel(module=decoder, inputs=tgt_seq, device_ids=args.gpu_devices, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
loss = data_parallel(module=decoder, inputs=tgt_seq, device_ids=args.gpu_devices, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
# loss.backward()
loss.mean().backward()# data parallism loss is a vector
loss.mean().backward() # data parallism loss is a vector
total_loss += loss.mean().item()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
opt.step()
Expand Down
4 changes: 2 additions & 2 deletions pix2tex/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def parse_args(args, **kwargs) -> Munch:
def get_device(args, kwargs):
    """Select the torch device string and normalise ``args.gpu_devices``.

    Args:
        args: Mapping-like config object that supports ``.get(key, default)``
            as well as attribute read/write. ``gpu_devices`` is optional;
            when it is missing or falsy it is replaced in place by
            ``range(torch.cuda.device_count())``.
        kwargs: Object exposing a boolean ``no_cuda`` attribute. When it is
            truthy the CPU is used even if GPUs are available.

    Returns:
        ``'cpu'`` when no GPU is available or ``kwargs.no_cuda`` is set,
        otherwise ``'cuda:<id>'`` where ``<id>`` is the first entry of
        ``args.gpu_devices``.

    Raises:
        AssertionError: If more devices are requested than are available, or
            if a requested device index is not below the available count.
    """
    device = 'cpu'
    available_gpus = torch.cuda.device_count()
    # `gpu_devices` is an optional setting: use .get() so a config without the
    # key falls back to "all visible GPUs" instead of failing on the lookup.
    args.gpu_devices = args.gpu_devices if args.get('gpu_devices', False) else range(available_gpus)
    if available_gpus > 0 and not kwargs.no_cuda:
        # Inside this branch gpu_devices is always non-empty (either the
        # user's truthy value or range(n) with n > 0), so index it directly.
        device = 'cuda:%d' % args.gpu_devices[0]
        assert available_gpus >= len(args.gpu_devices), "Available %d gpu, but specified gpu %s." % (available_gpus, ','.join(map(str, args.gpu_devices)))
        assert max(args.gpu_devices) < available_gpus, "legal gpu_devices should in [%s], received [%s]" % (','.join(map(str, range(available_gpus))),','.join(map(str, args.gpu_devices)))
    return device

Expand Down

0 comments on commit 426e594

Please sign in to comment.