@@ -7,9 +7,9 @@ class DataLoader:
77
88 def __init__ (
99 self ,
10- train_fn ,
11- valid_fn ,
12- exts ,
10+ train_fn = None ,
11+ valid_fn = None ,
12+ exts = None ,
1313 batch_size = 64 ,
1414 device = 'cpu' ,
1515 max_vocab = 9999999 ,
@@ -39,40 +39,41 @@ def __init__(
3939 eos_token = None ,
4040 )
4141
42- train = TranslationDataset (
43- path = train_fn ,
44- exts = exts ,
45- fields = [('src' , self .src ), ('tgt' , self .tgt )],
46- max_length = max_length
47- )
48- valid = TranslationDataset (
49- path = valid_fn ,
50- exts = exts ,
51- fields = [('src' , self .src ), ('tgt' , self .tgt )],
52- max_length = max_length ,
53- )
54-
55- self .train_iter = data .BucketIterator (
56- train ,
57- batch_size = batch_size ,
58- device = 'cuda:%d' % device if device >= 0 else 'cpu' ,
59- shuffle = shuffle ,
60- # 비슷한 길이끼리 미니 배치를 만들도록 정렬
61- sort_key = lambda x : len (x .tgt ) + (max_length * len (x .src )),
62- sort_within_batch = True ,
63- )
64- self .valid_iter = data .BucketIterator (
65- valid ,
66- batch_size = batch_size ,
67- device = 'cuda:%d' % device if device >= 0 else 'cpu' ,
68- shuffle = False ,
69- # 비슷한 길이끼리 미니 배치를 만들도록 정렬
70- sort_key = lambda x : len (x .tgt ) + (max_length * len (x .src )),
71- sort_within_batch = True ,
72- )
73-
74- self .src .build_vocab (train , max_size = max_vocab )
75- self .tgt .build_vocab (train , max_size = max_vocab )
42+ if train_fn is not None and valid_fn is not None and exts is not None :
43+ train = TranslationDataset (
44+ path = train_fn ,
45+ exts = exts ,
46+ fields = [('src' , self .src ), ('tgt' , self .tgt )],
47+ max_length = max_length
48+ )
49+ valid = TranslationDataset (
50+ path = valid_fn ,
51+ exts = exts ,
52+ fields = [('src' , self .src ), ('tgt' , self .tgt )],
53+ max_length = max_length ,
54+ )
55+
56+ self .train_iter = data .BucketIterator (
57+ train ,
58+ batch_size = batch_size ,
59+ device = 'cuda:%d' % device if device >= 0 else 'cpu' ,
60+ shuffle = shuffle ,
61+ # 비슷한 길이끼리 미니 배치를 만들도록 정렬
62+ sort_key = lambda x : len (x .tgt ) + (max_length * len (x .src )),
63+ sort_within_batch = True ,
64+ )
65+ self .valid_iter = data .BucketIterator (
66+ valid ,
67+ batch_size = batch_size ,
68+ device = 'cuda:%d' % device if device >= 0 else 'cpu' ,
69+ shuffle = False ,
70+ # 비슷한 길이끼리 미니 배치를 만들도록 정렬
71+ sort_key = lambda x : len (x .tgt ) + (max_length * len (x .src )),
72+ sort_within_batch = True ,
73+ )
74+
75+ self .src .build_vocab (train , max_size = max_vocab )
76+ self .tgt .build_vocab (train , max_size = max_vocab )
7677
7778 def load_vocab (self , src_vocab , tgt_vocab ):
7879 self .src .vocab = src_vocab
0 commit comments