Skip to content

Commit

Permalink
add polish example
Browse files Browse the repository at this point in the history
  • Loading branch information
lipiji committed Nov 10, 2020
1 parent a0a14e6 commit 56a2c90
Show file tree
Hide file tree
Showing 5 changed files with 27,913 additions and 2 deletions.
63 changes: 63 additions & 0 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,69 @@ def parse_line(line, max_len, min_len):
return None
return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos

def s2xy_polish(lines, vocab, max_len, min_len):
    """Parse raw polish-task lines and collate the usable ones into a batch.

    Every line is handed to ``parse_line_polish``. That parser returns
    ``None`` for lines it rejects (for example blank or too-short lines);
    such results are dropped here so that ``batchify`` only ever receives
    complete example tuples.

    Args:
        lines: iterable of raw text lines.
        vocab: vocabulary object, passed through to ``batchify`` untouched.
        max_len: maximum body length, forwarded to the parser.
        min_len: minimum length, forwarded to the parser.

    Returns:
        Whatever ``batchify(data, vocab)`` returns for the kept examples.
    """
    parsed = (parse_line_polish(line, max_len, min_len) for line in lines)
    # NOTE(review): batchify is defined outside this block; this assumes it
    # expects every entry to be a parsed 7-tuple and cannot digest None.
    data = [example for example in parsed if example is not None]
    return batchify(data, vocab)

def parse_line_polish(line, max_len, min_len):
    """Turn one raw line of the polish data set into model sequences.

    Expected layout::

        <author><s1><cipai><s2><sent></s><sent></s>...

    Only the text between the first and second ``<s2>`` marker is used as
    the body. The author field must be present but is otherwise ignored.

    Args:
        line: raw text line.
        max_len: the body is cut to this many characters *before* it is
            split into sentences.
        min_len: lower bound applied twice: to the character count of the
            (possibly truncated) body, and to the final length of ``ys``.

    Returns:
        ``None`` when the line is blank, lacks the ``<s2>`` marker, does not
        have exactly one ``<s1>`` in its header, or is too short.
        Otherwise the tuple
        ``(xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos)`` where the
        three ``xs_*`` lists are the context prefix (the ``cipai`` characters
        plus ``EOC``) followed by the matching ``ys_*`` list.
    """
    line = line.strip()
    if not line:
        return None

    # Malformed lines are skipped the same way as blank/short ones instead
    # of aborting the whole load with IndexError / ValueError.
    fields = line.split("<s2>")
    if len(fields) < 2:
        return None
    header = fields[0].split("<s1>")
    if len(header) != 2:
        return None
    cipai = header[1]

    body = fields[1].strip()
    if len(body) > max_len:
        # Truncation is character based, so it may cut through a sentence
        # (or a "</s>" marker) at the end of the body.
        body = body[:max_len]
    if len(body) < min_len:
        return None

    # Context prefix: one token per character of the cipai field.
    ctx_chars = list(cipai)
    xs_tpl = ctx_chars + [EOC]
    xs_seg = [SS[0] for _ in ctx_chars] + [EOC]
    # NOTE(review): context positions are taken from SS at offset 300, not
    # from PS; kept exactly as in the original - confirm this is intended.
    xs_pos = [SS[i + 300] for i in range(len(ctx_chars))] + [EOC]

    ys, ys_tpl, ys_seg, ys_pos = [], [], [], []
    # si counts every "</s>"-separated chunk, including empty ones that are
    # skipped, so segment ids follow the raw chunk index.
    for si, sent in enumerate(body.split("</s>")):
        sent = sent.strip()
        if not sent:
            continue
        chars = list(sent)
        ys += chars + [RS]
        # "_" characters are replaced by CS[2] in the template only; the
        # target sequence ys keeps the raw character.
        ys_tpl += [CS[2] if ch == "_" else ch for ch in chars] + [RS]
        ys_seg += [SS[si + 1] for _ in chars] + [RS]
        # Positions count down within the sentence: PS[n], ..., PS[1].
        ys_pos += [PS[k] for k in range(len(chars), 0, -1)] + [RS]

    ys += [EOS]
    ys_tpl += [EOS]
    ys_seg += [EOS]
    ys_pos += [EOS]

    xs_tpl += ys_tpl
    xs_seg += ys_seg
    xs_pos += ys_pos

    # Separator/EOS tokens count towards this second length check.
    if len(ys) < min_len:
        return None

    return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos

class DataLoader(object):
def __init__(self, vocab, filename, batch_size, max_len_y, min_len_y):
Expand Down
Loading

0 comments on commit 56a2c90

Please sign in to comment.