
Commit 21e259d

Fix training speed regression introduced by "optimize VRAM for calculating pos_bias in LayoutLM v2, v3 (#26139)" (#30988)
* Revert "optimize VRAM for calculating pos_bias in LayoutLM v2, v3 (#26139)"

  This reverts commit a7e0ed8.

* Instead of reverting the commit, wrap the indexing in a torch.no_grad context

* Apply the same wrapping in LayoutLMv2

* Add comments explaining the reason for no_grad

* Fix code format

---------

Co-authored-by: Kevin Koehncke <kevin.koehncke@uipath.com>
1 parent 7f6e874 commit 21e259d
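The change keeps the memory-saving direct indexing introduced in #26139 but tells autograd not to record the gather. A minimal sketch of the effect (the table size and shapes here are illustrative, not the model's actual configuration):

import torch
from torch import nn

# Stand-in for a relative-position bias table: one logit per (bucket, head) pair.
bias_table = nn.Linear(32, 12, bias=False)
buckets = torch.randint(0, 32, (1, 128, 128))  # bucketed relative positions

tracked = bias_table.weight.t()[buckets]        # autograd records this gather
with torch.no_grad():
    untracked = bias_table.weight.t()[buckets]  # gather stays out of the graph

print(tracked.grad_fn is None)    # False: an IndexBackward node was recorded
print(untracked.grad_fn is None)  # True: the result is detached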

File tree

2 files changed (+26, -6 lines)


src/transformers/models/layoutlmv2/modeling_layoutlmv2.py

Lines changed: 13 additions & 3 deletions
@@ -383,7 +383,12 @@ def _calculate_1d_position_embeddings(self, position_ids):
             num_buckets=self.rel_pos_bins,
             max_distance=self.max_rel_pos,
         )
-        rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
+        # Since this is a simple indexing operation that is independent of the input,
+        # no need to track gradients for this operation
+        #
+        # Without this no_grad context, training speed slows down significantly
+        with torch.no_grad():
+            rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
         rel_pos = rel_pos.contiguous()
         return rel_pos

@@ -402,8 +407,13 @@ def _calculate_2d_position_embeddings(self, bbox):
             num_buckets=self.rel_2d_pos_bins,
             max_distance=self.max_rel_2d_pos,
         )
-        rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
-        rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
+        # Since this is a simple indexing operation that is independent of the input,
+        # no need to track gradients for this operation
+        #
+        # Without this no_grad context, training speed slows down significantly
+        with torch.no_grad():
+            rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
+            rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
         rel_pos_x = rel_pos_x.contiguous()
         rel_pos_y = rel_pos_y.contiguous()
         rel_2d_pos = rel_pos_x + rel_pos_y
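For context, #26139 arrived at this direct indexing by replacing the earlier one-hot matmul through the bias layer; the two forms are numerically equivalent, which a small self-contained sketch can confirm (bucket and head counts are illustrative):

import torch
import torch.nn.functional as F
from torch import nn

num_buckets, num_heads = 32, 12
rel_pos_bias = nn.Linear(num_buckets, num_heads, bias=False)
rel_pos = torch.randint(0, num_buckets, (1, 128, 128))

# Pre-#26139: materialize a float one-hot tensor and push it through the layer.
via_one_hot = rel_pos_bias(F.one_hot(rel_pos, num_classes=num_buckets).float())
# Post-#26139: index the weight table directly, skipping the one-hot intermediate.
via_indexing = rel_pos_bias.weight.t()[rel_pos]

assert torch.allclose(via_one_hot, via_indexing)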

src/transformers/models/layoutlmv3/modeling_layoutlmv3.py

Lines changed: 13 additions & 3 deletions
@@ -600,7 +600,12 @@ def _cal_1d_pos_emb(self, position_ids):
             num_buckets=self.rel_pos_bins,
             max_distance=self.max_rel_pos,
         )
-        rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
+        # Since this is a simple indexing operation that is independent of the input,
+        # no need to track gradients for this operation
+        #
+        # Without this no_grad context, training speed slows down significantly
+        with torch.no_grad():
+            rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
         rel_pos = rel_pos.contiguous()
         return rel_pos

@@ -619,8 +624,13 @@ def _cal_2d_pos_emb(self, bbox):
             num_buckets=self.rel_2d_pos_bins,
             max_distance=self.max_rel_2d_pos,
         )
-        rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
-        rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
+        # Since this is a simple indexing operation that is independent of the input,
+        # no need to track gradients for this operation
+        #
+        # Without this no_grad context, training speed slows down significantly
+        with torch.no_grad():
+            rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
+            rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
         rel_pos_x = rel_pos_x.contiguous()
         rel_pos_y = rel_pos_y.contiguous()
         rel_2d_pos = rel_pos_x + rel_pos_y
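A quick way to sanity-check the patch (a hedged sketch, assuming a transformers checkout that includes this commit; the tiny inputs and default config are illustrative): after the change, the bias tensor returned by the patched method should no longer require grad.

import torch
from transformers import LayoutLMv2Config
from transformers.models.layoutlmv2.modeling_layoutlmv2 import LayoutLMv2Encoder

config = LayoutLMv2Config(has_relative_attention_bias=True)
encoder = LayoutLMv2Encoder(config)

position_ids = torch.arange(16).unsqueeze(0)
bias = encoder._calculate_1d_position_embeddings(position_ids)
print(bias.requires_grad)  # False with the no_grad wrapping; True without it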
