Skip to content
This repository has been archived by the owner on Apr 30, 2021. It is now read-only.

Commit

Permalink
add segment_long
Browse files Browse the repository at this point in the history
  • Loading branch information
bedapudi6788 committed Aug 13, 2019
1 parent dcd06e0 commit 274f2dc
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 9 deletions.
57 changes: 53 additions & 4 deletions deepsegment/deepsegment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pydload
import pickle
import os
import logging
import time

model_links = {
'en': {
Expand All @@ -14,6 +16,17 @@
}


def chunk(l, n):
    """Split ``l`` into consecutive slices of at most ``n`` items.

    Despite the historical wording ("yield"), this is not a generator: the
    slices are collected eagerly and returned as a list.

    Args:
        l: Any sliceable sequence with a length (list, str, tuple, ...).
        n: Step / slice size. ``0`` raises ``ValueError`` (from ``range``).

    Returns:
        A list of slices ``l[i:i + n]`` taken at steps of ``n``; the final
        slice may be shorter. When no slice is produced (for example ``l``
        is empty), ``[l]`` is returned so the result always holds at least
        one chunk.
    """
    chunked_l = [l[i:i + n] for i in range(0, len(l), n)]

    # No slices were produced: hand the input back as the single chunk.
    return chunked_l or [l]

class DeepSegment():
seqtag_model = None
data_converter = None
Expand Down Expand Up @@ -54,25 +67,61 @@ def segment(self, sents):
if not DeepSegment.seqtag_model:
print('Please load the model first')

string_output = False
if not isinstance(sents, list):
logging.warn("Batch input strings for faster inference.")
string_output = True
sents = [sents]

sents = [sent.strip().split() for sent in sents]

max_len = len(max(sents, key=len))
if max_len >= 40:
logging.warn("Consider using segment_long for longer sentences.")

encoded_sents = DeepSegment.data_converter.transform(sents)
all_tags = DeepSegment.seqtag_model.predict(encoded_sents)
all_tags = [np.argmax(_tags, axis=1).tolist() for _tags in all_tags]

segmented_sentences = []
for sent, tags in zip(sents, all_tags):
segmented_sentences = [[] for _ in sents]
for sent_index, (sent, tags) in enumerate(zip(sents, all_tags)):
segmented_sent = []
for i, (word, tag) in enumerate(zip(sent, tags)):
if tag == 2 and i > 0 and segmented_sent:
segmented_sent = ' '.join(segmented_sent)
segmented_sentences.append(segmented_sent)
segmented_sentences[sent_index].append(segmented_sent)
segmented_sent = []

segmented_sent.append(word)
if segmented_sent:
segmented_sentences.append(' '.join(segmented_sent))
segmented_sentences[sent_index].append(' '.join(segmented_sent))

if string_output:
return segmented_sentences[0]

return segmented_sentences

def segment_long(self, sent, n_window=None):
    """Segment a single long string by sliding a word window over it.

    The text is split on whitespace and passed to ``self.segment`` one
    window at a time. Within each window every segment except the last is
    accepted as final; the last one is carried over as the prefix of the
    next window, because more of the same sentence may still follow.

    Args:
        sent: One input string. Lists are rejected (no batching support).
        n_window: Target number of words per window. A falsy or
            non-positive value falls back to 10 and logs a warning.

    Returns:
        A list of segment strings, or ``None`` when ``sent`` is a list.
    """
    # Non-positive sizes can never consume words, so treat them like "unset".
    if not n_window or n_window < 1:
        logging.warning("Using default n_window=10. Set this parameter based on your data.")
        n_window = 10

    if isinstance(sent, list):
        logging.error("segment_long doesn't support batching as of now. Batching will be added in a future release.")
        return None

    segmented = []
    sent = sent.split()
    prefix = []
    while sent:
        current_n_window = n_window - len(prefix)
        # The carried prefix can fill or exceed the window. A zero or
        # negative slice size would stop consuming words (and a negative
        # one can loop forever), so always take a full window of new words.
        if current_n_window <= 0:
            current_n_window = n_window

        window = prefix + sent[:current_n_window]
        sent = sent[current_n_window:]
        segmented_window = self.segment([' '.join(window)])[0]
        segmented += segmented_window[:-1]
        prefix = segmented_window[-1].split()

    # The segment still held in the prefix when the input runs out is the
    # final sentence; without this it would be silently dropped.
    if prefix:
        segmented.append(' '.join(prefix))

    return segmented
5 changes: 0 additions & 5 deletions deepsegment/test.py

This file was deleted.

0 comments on commit 274f2dc

Please sign in to comment.