import os
from torchtext import data

PAD, BOS, EOS = 1, 2, 3
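
# Index convention: torchtext's legacy Field reserves <unk>=0 and <pad>=1
# when building a vocabulary; init_token and eos_token, when set on a field,
# take the next indices, which is what the PAD/BOS/EOS constants above assume.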

class DataLoader:

    def __init__(
        self,
        train_fn=None,
        valid_fn=None,
        exts=None,
        batch_size=64,
        device=-1,              # GPU index; any negative value means CPU
        max_vocab=9999999,
        max_length=255,
        fix_length=None,
        use_bos=True,
        use_eos=True,
        shuffle=True,
    ):

        self.src = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            include_lengths=True,
            fix_length=fix_length,
            # The source side is fed to the encoder as-is, without BOS/EOS.
            init_token=None,
            eos_token=None,
        )
        self.tgt = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            include_lengths=True,
            fix_length=fix_length,
            # Only the target side gets BOS/EOS tokens, so the decoder can
            # learn where sentences start and end.
            init_token='<BOS>' if use_bos else None,
            eos_token='<EOS>' if use_eos else None,
        )

        if train_fn is not None and valid_fn is not None and exts is not None:
            train = TranslationDataset(
                path=train_fn,
                exts=exts,
                fields=[('src', self.src), ('tgt', self.tgt)],
                max_length=max_length,
            )
            valid = TranslationDataset(
                path=valid_fn,
                exts=exts,
                fields=[('src', self.src), ('tgt', self.tgt)],
                max_length=max_length,
            )

            self.train_iter = data.BucketIterator(
                train,
                batch_size=batch_size,
                device='cuda:%d' % device if device >= 0 else 'cpu',
                shuffle=shuffle,
                # Sort so that mini-batches are built from sequences of
                # similar lengths.
                sort_key=lambda x: len(x.tgt) + (max_length * len(x.src)),
                sort_within_batch=True,
            )
            self.valid_iter = data.BucketIterator(
                valid,
                batch_size=batch_size,
                device='cuda:%d' % device if device >= 0 else 'cpu',
                shuffle=False,
                # Sort so that mini-batches are built from sequences of
                # similar lengths.
                sort_key=lambda x: len(x.tgt) + (max_length * len(x.src)),
                sort_within_batch=True,
            )

            self.src.build_vocab(train, max_size=max_vocab)
            self.tgt.build_vocab(train, max_size=max_vocab)

    def load_vocab(self, src_vocab, tgt_vocab):
        # Swap in previously saved vocabularies, e.g. when restoring a
        # trained model for inference.
        self.src.vocab = src_vocab
        self.tgt.vocab = tgt_vocab

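# A minimal sketch of how load_vocab() might be used to restore saved
# vocabularies instead of rebuilding them from a corpus; the checkpoint
# layout here is an assumption for illustration, not defined by this module:
#
#   saved = torch.load('model.pth', map_location='cpu')  # hypothetical path
#   loader = DataLoader()               # no corpus given: fields only
#   loader.load_vocab(saved['src_vocab'], saved['tgt_vocab'])
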
class TranslationDataset(data.Dataset):

    def __init__(self, path, exts, fields, max_length=None, **kwargs):
        """Create a TranslationDataset given paths and fields.

        Overrides torchtext's TranslationDataset so that examples longer
        than max_length (in tokens, on either side) are filtered out.

        Arguments:
            path: Common prefix of paths to the data files for both languages.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            max_length: Maximum number of tokens allowed per side; longer
                examples are skipped entirely.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        if not isinstance(fields[0], (tuple, list)):
            # Field names must match the attribute names used by sort_key
            # and by the DataLoader above.
            fields = [('src', fields[0]), ('tgt', fields[1])]

        if not path.endswith('.'):
            path += '.'

        src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)

        examples = []
        with open(src_path, encoding='utf-8') as src_file, \
                open(trg_path, encoding='utf-8') as trg_file:
            for src_line, trg_line in zip(src_file, trg_file):
                src_line, trg_line = src_line.strip(), trg_line.strip()
                # Skip pairs where either side exceeds max_length tokens.
                if max_length and max_length < max(len(src_line.split()),
                                                   len(trg_line.split())):
                    continue
                if src_line != '' and trg_line != '':
                    examples += [data.Example.fromlist([src_line, trg_line],
                                                       fields)]

        super().__init__(examples, fields, **kwargs)

    @staticmethod
    def sort_key(ex):
        # Interleave source and target lengths so that examples with similar
        # lengths on both sides end up close together.
        return data.interleave_keys(len(ex.src), len(ex.tgt))
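

if __name__ == '__main__':
    # A minimal usage sketch, not part of the module proper. The corpus
    # prefixes and extensions below are hypothetical; they assume files such
    # as corpus.train.en / corpus.train.ko sit next to this script.
    loader = DataLoader(
        train_fn='corpus.train',
        valid_fn='corpus.valid',
        exts=('en', 'ko'),
        batch_size=64,
        device=-1,          # negative index keeps everything on the CPU
    )

    print('vocab sizes: src=%d, tgt=%d' % (len(loader.src.vocab),
                                           len(loader.tgt.vocab)))

    for batch in loader.train_iter:
        # Because include_lengths=True, batch.src and batch.tgt are
        # (tensor, lengths) pairs with tensors of shape (batch_size, seq_len).
        x, x_length = batch.src
        y, y_length = batch.tgt
        print(x.size(), y.size())
        break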