2727
2828
2929class LMOrderedIterator (object ):
# NOTE(review): this span is a rendered diff, not plain Python. Every line is
# prefixed with fused old/new line numbers; "-" marks a line the change
# removes, "+" marks a line it adds, and unmarked lines are unchanged context.
#
# Constructor: stores the batching parameters, reshapes `data` into bsz
# columns, optionally prepends warmup rows (added version only), keeps this
# rank's share of the columns, and derives the number of mini-batches.
#
# Parameters, as used in the body below:
#   data    -- tensor to iterate over; trimmed to a multiple of bsz elements.
#   bsz     -- number of columns the data is divided into.
#   bptt    -- stored on self; step used when counting mini-batches.
#   device  -- stored on self.device.
#   mem_len -- (added) stored on self; together with `warmup` it enables the
#              warmup prefix.
#   ext_len -- stored on self, with None replaced by 0.
#   warmup  -- (added) stored on self; see mem_len.
#
# NOTE(review): mem_len is inserted *before* ext_len, so a caller that passed
# ext_len positionally (sixth argument) would now be setting mem_len instead.
# Verify that callers pass ext_len by keyword.
30- def __init__ (self , data , bsz , bptt , device = 'cpu' , ext_len = None ):
30+ def __init__ (self , data , bsz , bptt , device = 'cpu' , mem_len = None , ext_len = None , warmup = True ):
3131 """
3232 data -- LongTensor -- the LongTensor is strictly ordered
3333 """
3434 self .bsz = bsz
3535 self .bptt = bptt
3636 self .ext_len = ext_len if ext_len is not None else 0
37+ self .mem_len = mem_len
38+ self .warmup = warmup
3739
3840 self .device = device
3941
4042 # Work out how cleanly we can divide the dataset into bsz parts.
# The removed line kept n_step as an attribute; the added line keeps it as a
# local only, so self.n_step is no longer set by this constructor.
# NOTE(review): confirm nothing outside this method still reads self.n_step.
41- self . n_step = data .size (0 ) // bsz
43+ n_step = data .size (0 ) // bsz
4244
4345 # Trim off any extra elements that wouldn't cleanly fit (remainders).
# Both forms keep the first n_step * bsz entries along dim 0.
44- data = data . narrow ( 0 , 0 , self . n_step * bsz )
46+ data = data [: n_step * bsz ]
4547
4648 # Evenly divide the data across the bsz batches.
# view(bsz, -1) followed by t() leaves bsz as the size of dim 1; assuming
# `data` is 1-D (TODO confirm), the result has n_step rows and bsz columns.
4749 self .data = data .view (bsz , - 1 ).t ().contiguous ()
4850
# Added warmup prefix. warmup_batches is mem_len / bptt rounded up, and
# warmup_elems is that many whole bptt-sized steps. These two attributes are
# only set when both mem_len and warmup are truthy.
51+ if mem_len and warmup :
52+ self .warmup_batches = (mem_len + bptt - 1 ) // bptt
53+ self .warmup_elems = self .warmup_batches * bptt
54+
# roll shifts by warmup_elems along dim 0 and by 1 along dim 1; the first
# warmup_elems rows of the rolled tensor are then prepended to self.data.
# With torch.roll's wrap-around semantics those rows are the last
# warmup_elems rows of self.data, each column taken from the preceding column
# (the first column wraps to the last).
55+ warmup_data = self .data .roll ((self .warmup_elems , 1 ), (0 , 1 ))[:self .warmup_elems ]
56+ self .data = torch .cat ((warmup_data , self .data ))
57+
4958 # Partition data for DistributedDataParallel
5059 world_size = utils .distributed .get_world_size ()
5160 rank = utils .distributed .get_rank ()
# The removed line moved this rank's chunk to `device` here; the added line
# leaves the chunk where it is. In the added version of get_batch later in
# this diff, the slices are moved with .to(self.device) instead.
52- self .data = self .data .chunk (world_size , dim = 1 )[rank ]. to ( device )
61+ self .data = self .data .chunk (world_size , dim = 1 )[rank ]
5362
5463 # Number of mini-batches
# Rounded-up division by bptt. The added line counts rows of self.data, which
# includes any warmup rows prepended above; the removed line counted n_step.
55- self .n_batch = (self .n_step + self .bptt - 1 ) // self .bptt
64+ self .n_batch = (self .data . size ( 0 ) + self .bptt - 1 ) // self .bptt
5665
5766 def roll (self ):
5867 for i in range (self .data .size (1 )):
@@ -70,10 +79,15 @@ def get_batch(self, i, bptt=None):
7079 end_idx = i + seq_len
7180 beg_idx = max (0 , i - self .ext_len )
7281
73- data = self .data [beg_idx :end_idx ]
74- target = self .data [i + 1 :i + 1 + seq_len ]
82+ data = self .data [beg_idx :end_idx ].to (self .device )
83+ target = self .data [i + 1 :i + 1 + seq_len ].to (self .device )
84+
85+ if self .mem_len and self .warmup :
86+ warm = i >= self .warmup_elems
87+ else :
88+ warm = True
7589
76- return data , target , seq_len
90+ return data , target , seq_len , warm
7791
7892 def get_fixlen_iter (self , start = 0 ):
7993 for i in range (start , self .data .size (0 ) - 1 , self .bptt ):
0 commit comments