Commit 71ba423

test2
1 parent 9c6c9be

3 files changed: +203 -48 lines changed
Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
import os

from torchtext import data  # legacy data API; requires torchtext < 0.9


PAD, BOS, EOS = 1, 2, 3


class DataLoader:

    def __init__(
        self,
        train_fn=None,
        valid_fn=None,
        exts=None,
        batch_size=64,
        device=-1,          # GPU index; any negative value falls back to CPU
        max_vocab=9999999,
        max_length=255,
        fix_length=None,
        use_bos=True,
        use_eos=True,
        shuffle=True,
    ):
        # Source side gets no BOS/EOS markers.
        self.src = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            include_lengths=True,
            fix_length=fix_length,
            init_token=None,
            eos_token=None,
        )
        # Target side gets <BOS>/<EOS> so the decoder sees explicit start-
        # and end-of-sequence markers (vocab indices 2 and 3, matching the
        # BOS/EOS constants above).
        self.tgt = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            include_lengths=True,
            fix_length=fix_length,
            init_token='<BOS>' if use_bos else None,
            eos_token='<EOS>' if use_eos else None,
        )

        if train_fn is not None and valid_fn is not None and exts is not None:
            train = TranslationDataset(
                path=train_fn,
                exts=exts,
                fields=[('src', self.src), ('tgt', self.tgt)],
                max_length=max_length,
            )
            valid = TranslationDataset(
                path=valid_fn,
                exts=exts,
                fields=[('src', self.src), ('tgt', self.tgt)],
                max_length=max_length,
            )

            self.train_iter = data.BucketIterator(
                train,
                batch_size=batch_size,
                device='cuda:%d' % device if device >= 0 else 'cpu',
                shuffle=shuffle,
                # Sort so that mini-batches are built from samples of similar length.
                sort_key=lambda x: len(x.tgt) + (max_length * len(x.src)),
                sort_within_batch=True,
            )
            self.valid_iter = data.BucketIterator(
                valid,
                batch_size=batch_size,
                device='cuda:%d' % device if device >= 0 else 'cpu',
                shuffle=False,
                # Sort so that mini-batches are built from samples of similar length.
                sort_key=lambda x: len(x.tgt) + (max_length * len(x.src)),
                sort_within_batch=True,
            )

            self.src.build_vocab(train, max_size=max_vocab)
            self.tgt.build_vocab(train, max_size=max_vocab)

    def load_vocab(self, src_vocab, tgt_vocab):
        self.src.vocab = src_vocab
        self.tgt.vocab = tgt_vocab


class TranslationDataset(data.Dataset):

    def __init__(self, path, exts, fields, max_length=None, **kwargs):
        """Create a TranslationDataset given paths and fields.

        Overrides torchtext's TranslationDataset so that examples longer
        than max_length can be dropped.

        Arguments:
            path: Common prefix of paths to the data files for both languages.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('tgt', fields[1])]

        if not path.endswith('.'):
            path += '.'

        src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)

        examples = []
        with open(src_path, encoding='utf-8') as src_file, open(trg_path, encoding='utf-8') as trg_file:
            for src_line, trg_line in zip(src_file, trg_file):
                src_line, trg_line = src_line.strip(), trg_line.strip()
                # Drop pairs where either side exceeds max_length tokens.
                if max_length and max_length < max(len(src_line.split()), len(trg_line.split())):
                    continue
                if src_line != '' and trg_line != '':
                    examples += [data.Example.fromlist([src_line, trg_line], fields)]

        super().__init__(examples, fields, **kwargs)

    @staticmethod
    def sort_key(ex):
        return data.interleave_keys(len(ex.src), len(ex.tgt))
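
For reference, below is a minimal sketch of how this DataLoader might be driven. The corpus prefixes and the ('en', 'ko') extensions are illustrative assumptions, not part of this commit. Note the iterator sort_key, len(x.tgt) + max_length * len(x.src): it orders examples primarily by source length and breaks ties by target length, so each bucket holds similarly sized pairs and padding waste stays small.

# Usage sketch (assumption, not part of this commit): expects
# corpus.train.en / corpus.train.ko and corpus.valid.en / corpus.valid.ko,
# one whitespace-tokenized sentence per line.
loader = DataLoader(
    train_fn='corpus.train',
    valid_fn='corpus.valid',
    exts=('en', 'ko'),
    batch_size=64,
    device=-1,          # negative index selects CPU
    max_length=64,
)

for batch in loader.train_iter:
    src, src_lengths = batch.src    # include_lengths=True yields (tensor, lengths)
    tgt, tgt_lengths = batch.tgt
    print(src.shape, tgt.shape)     # (batch_size, seq_len) each, batch_first=True
    break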
