
Commit a38deff

[Transformer-XL/PyT] Large model support; multi-node training; inference with TorchScript
1 parent: 1def26d

30 files changed: +2777 -989 lines
Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,7 @@
 **/.DS_Store
 __pycache__/
 data/
+results/
+*.out
+*.log
+*.json

PyTorch/LanguageModeling/Transformer-XL/README.md

Lines changed: 775 additions & 216 deletions
Large diffs are not rendered by default.
Lines changed: 3 additions & 0 deletions
@@ -1,2 +1,5 @@
 LM-TFM*
 internal/result*
+*.out
+*.log
+*.json

PyTorch/LanguageModeling/Transformer-XL/pytorch/Dockerfile

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:19.09-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:19.11-py3
 FROM ${FROM_IMAGE_NAME}
 
 ENV LANG C.UTF-8
@@ -26,5 +26,6 @@ WORKDIR /workspace/transformer-xl/pytorch
 
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install --no-cache-dir git+https://github.com/NVIDIA/dllogger.git#egg=dllogger
 
 ADD . /workspace/transformer-xl/pytorch
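
The new RUN line installs NVIDIA's dllogger package inside the container; the commit itself does not show how the training scripts call it, though the *.json entries added to the ignore files suggest JSON log output. As a purely illustrative sketch (the log file name and logged fields are assumptions, not taken from this commit), typical dllogger usage looks like:

    import dllogger
    from dllogger import JSONStreamBackend, StdOutBackend, Verbosity

    # Illustrative only: 'train_log.json' and the logged keys are assumptions.
    dllogger.init(backends=[
        JSONStreamBackend(Verbosity.VERBOSE, 'train_log.json'),  # machine-readable JSON lines
        StdOutBackend(Verbosity.DEFAULT),                        # human-readable console output
    ])
    dllogger.log(step=(1, 100), data={'train_loss': 3.21, 'lr': 0.01})
    dllogger.flush()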

PyTorch/LanguageModeling/Transformer-XL/pytorch/data_utils.py

Lines changed: 22 additions & 8 deletions
@@ -27,32 +27,41 @@
 
 
 class LMOrderedIterator(object):
-    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
+    def __init__(self, data, bsz, bptt, device='cpu', mem_len=None, ext_len=None, warmup=True):
         """
             data -- LongTensor -- the LongTensor is strictly ordered
         """
         self.bsz = bsz
         self.bptt = bptt
         self.ext_len = ext_len if ext_len is not None else 0
+        self.mem_len = mem_len
+        self.warmup = warmup
 
         self.device = device
 
         # Work out how cleanly we can divide the dataset into bsz parts.
-        self.n_step = data.size(0) // bsz
+        n_step = data.size(0) // bsz
 
         # Trim off any extra elements that wouldn't cleanly fit (remainders).
-        data = data.narrow(0, 0, self.n_step * bsz)
+        data = data[:n_step * bsz]
 
         # Evenly divide the data across the bsz batches.
         self.data = data.view(bsz, -1).t().contiguous()
 
+        if mem_len and warmup:
+            self.warmup_batches = (mem_len + bptt - 1) // bptt
+            self.warmup_elems = self.warmup_batches * bptt
+
+            warmup_data = self.data.roll((self.warmup_elems, 1), (0, 1))[:self.warmup_elems]
+            self.data = torch.cat((warmup_data, self.data))
+
         # Partition data for DistributedDataParallel
         world_size = utils.distributed.get_world_size()
         rank = utils.distributed.get_rank()
-        self.data = self.data.chunk(world_size, dim=1)[rank].to(device)
+        self.data = self.data.chunk(world_size, dim=1)[rank]
 
         # Number of mini-batches
-        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
+        self.n_batch = (self.data.size(0) + self.bptt - 1) // self.bptt
 
     def roll(self):
         for i in range(self.data.size(1)):
@@ -70,10 +79,15 @@ def get_batch(self, i, bptt=None):
         end_idx = i + seq_len
         beg_idx = max(0, i - self.ext_len)
 
-        data = self.data[beg_idx:end_idx]
-        target = self.data[i+1:i+1+seq_len]
+        data = self.data[beg_idx:end_idx].to(self.device)
+        target = self.data[i+1:i+1+seq_len].to(self.device)
+
+        if self.mem_len and self.warmup:
+            warm = i >= self.warmup_elems
+        else:
+            warm = True
 
-        return data, target, seq_len
+        return data, target, seq_len, warm
 
     def get_fixlen_iter(self, start=0):
         for i in range(start, self.data.size(0) - 1, self.bptt):
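
To make the new mem_len/warmup logic easier to follow, here is a minimal standalone sketch (not part of the commit; the toy sizes and the arange stand-in corpus are assumptions for illustration). It reproduces the prefix construction from __init__ and the warm flag from get_batch:

    import torch

    # Toy sizes, for illustration only.
    bsz, bptt, mem_len = 2, 4, 6
    stream = torch.arange(40)                # stand-in for an ordered token stream

    # Shape the stream into bsz parallel columns, time along dim 0.
    n_step = stream.size(0) // bsz
    data = stream[:n_step * bsz].view(bsz, -1).t().contiguous()   # (n_step, bsz)

    # Prepend ceil(mem_len / bptt) extra batches so the recurrence memory is
    # already full when the "real" data starts. After the roll, the prefix of
    # each column is the tail of the previous column, i.e. the tokens that
    # immediately precede that column's data in the ordered corpus (the first
    # column wraps around to the end of the corpus).
    warmup_batches = (mem_len + bptt - 1) // bptt
    warmup_elems = warmup_batches * bptt
    warmup_data = data.roll((warmup_elems, 1), (0, 1))[:warmup_elems]
    data = torch.cat((warmup_data, data))

    # Batches that start inside the prefix report warm=False, so a consumer
    # (e.g. an evaluation loop) can exclude them from the reported loss.
    for i in range(0, data.size(0) - 1, bptt):
        warm = i >= warmup_elems
        print(i, warm)

Note also that after this change the per-rank shard produced by chunk(world_size, dim=1)[rank] stays on the host, and each slice is moved to self.device only inside get_batch, so the whole corpus no longer has to fit in GPU memory.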
