from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler


-def data_parallel(module, x: torch.Tensor, device_ids, output_device=None, **kwargs):
-    if not device_ids or len(device_ids) == 1:
-        return module(x, **kwargs)
-    if output_device is None:
-        output_device = device_ids[0]
-    replicas = nn.parallel.replicate(module, device_ids)
-    inputs = nn.parallel.scatter(x, device_ids)  # Slices tensors into approximately equal chunks and distributes them across given GPUs.
-    kwargs = nn.parallel.scatter(kwargs, device_ids)  # Duplicates references to objects that are not tensors.
-    replicas = replicas[:len(inputs)]
-    kwargs = kwargs[:len(inputs)]
-    outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
-    return nn.parallel.gather(outputs, output_device)
-
-
def gpu_memory_check(model, args):
    # check if largest batch can be handled by system
    try:
        batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
        for _ in range(5):
            im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
            seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
-            # model.decoder(seq, context=model.encoder(im)).sum().backward()
-            # encoded = data_parallel(model.encoder, inputs=im, device_ids=args.gpu_devices)
-            # loss = data_parallel(model.decoder, inputs=seq, device_ids=args.gpu_devices, context=encoded)
-            loss = data_parallel(model, im, device_ids=args.gpu_devices, tgt_seq=seq)
+            loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
            loss.sum().backward()
    except RuntimeError:
        raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
@@ -60,7 +43,6 @@ def train(args):
    gpu_memory_check(model, args)
    if args.load_chkpt is not None:
        model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
-    encoder, decoder = model.encoder, model.decoder
    max_bleu, max_token_acc = 0, 0
    out_path = os.path.join(args.model_path, args.name)
    os.makedirs(out_path, exist_ok=True)
@@ -86,14 +68,9 @@ def save_models(e, step=0):
                total_loss = 0
                for j in range(0, len(im), microbatch):
                    tgt_seq, tgt_mask = seq['input_ids'][j:j+microbatch].to(device), seq['attention_mask'][j:j+microbatch].bool().to(device)
-                    # encoded = encoder(im[j:j+microbatch].to(device))
-                    # encoded = data_parallel(encoder, inputs=im[j:j+microbatch].to(device), device_ids=args.gpu_devices)
-                    # loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
-                    # loss = data_parallel(module=decoder, inputs=tgt_seq, device_ids=args.gpu_devices, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
-                    # loss.backward()
-                    loss = data_parallel(model, im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
-                    loss.mean().backward()  # data parallelism loss is a vector
-                    total_loss += loss.mean().item()
+                    loss = model.data_parallel(im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
+                    loss.backward()  # the data-parallel loss is gathered and reduced inside model.data_parallel
+                    total_loss += loss.item()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                opt.step()
                scheduler.step()
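
Note: the diff drops the module-level data_parallel helper and calls model.data_parallel(...) instead, but the method itself is defined outside this file and does not appear in any hunk here. The following is only a minimal sketch of what such a method could look like, assuming the removed helper was moved onto the model class and now reduces the gathered per-device losses to a scalar, which would explain why the training loop can call loss.backward() and loss.item() without an explicit .mean(). The class and method body below are illustrative, not the file contents of this PR.

import torch
import torch.nn as nn

class Model(nn.Module):
    # Hypothetical sketch; the real encoder/decoder wiring lives elsewhere in pix2tex and is omitted here.

    def data_parallel(self, x: torch.Tensor, device_ids, output_device=None, **kwargs):
        if not device_ids or len(device_ids) == 1:
            return self(x, **kwargs)  # single device: plain forward pass
        if output_device is None:
            output_device = device_ids[0]
        replicas = nn.parallel.replicate(self, device_ids)
        inputs = nn.parallel.scatter(x, device_ids)       # split the batch across the given GPUs
        kwargs = nn.parallel.scatter(kwargs, device_ids)  # scatter tensor kwargs, reference everything else
        replicas = replicas[:len(inputs)]
        kwargs = kwargs[:len(inputs)]
        outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
        # Assumption: gather the per-device losses and average them so the caller gets a scalar.
        return nn.parallel.gather(outputs, output_device).mean()

If the method instead returned the raw gathered vector, the training loop would still need the old loss.mean() calls before backward() and item().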