Skip to content

Commit cee66cd

Browse files
committed
make training configurable
1 parent 4fb494f commit cee66cd

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed

tf-ner-poc/src/main/python/namefinder/namefinder.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
# This poc is based on source code taken from:
2121
# https://github.com/guillaumegenthial/sequence_tagging
2222

23-
import sys
2423
from math import floor
2524
import tensorflow as tf
2625
import re
2726
import numpy as np
2827
import zipfile
2928
import os
3029
from tempfile import TemporaryDirectory
30+
import argparse
3131

3232
# global variables for unknown words and numbers
3333
__UNK__ = '__UNK__'
@@ -68,12 +68,16 @@ def __str__(self):
6868
class NameFinder:
6969
label_dict = {}
7070

71-
def __init__(self, use_lower_case_embeddings=False, vector_size=100):
71+
def __init__(self, use_lower_case_embeddings, allow_unk, allow_num, digit_pattern, encoding, vector_size=100):
7272
self.__vector_size = vector_size
7373
self.__use_lower_case_embeddings = use_lower_case_embeddings
74+
self.__allow_unk = allow_unk
75+
self.__allow_num = allow_num
76+
self.__digit_pattern = re.compile(digit_pattern)
77+
self.__encoding = encoding
7478

7579
def load_data(self, word_dict, file):
76-
with open(file) as f:
80+
with open(file, encoding=self.__encoding) as f:
7781
raw_data = f.readlines()
7882

7983
sentences = []
@@ -96,7 +100,8 @@ def load_data(self, word_dict, file):
96100
if self.__use_lower_case_embeddings:
97101
token = token.lower()
98102

99-
# TODO: implement NUM encoding
103+
if self.__allow_num and self.__digit_pattern.match(token):
104+
token = __NUM__
100105

101106
if word_dict.get(token) is not None:
102107
vector = word_dict[token]
@@ -340,8 +345,8 @@ def write_mapping(tags, output_filename):
340345
f.write('{}\n'.format(tag))
341346

342347

343-
def load_glove(glove_file):
344-
with open(glove_file) as f:
348+
def load_glove(glove_file, encoding='utf-8'):
349+
with open(glove_file, encoding=encoding) as f:
345350

346351
word_dict = {}
347352
embeddings = []
@@ -381,16 +386,28 @@ def load_glove(glove_file):
381386

382387

383388
def main():
384-
if len(sys.argv) != 5:
385-
print("Usage namefinder.py embedding_file train_file dev_file test_file")
386-
return
387-
388-
word_dict, rev_word_dict, embeddings, vector_size = load_glove(sys.argv[1])
389-
390-
name_finder = NameFinder(vector_size)
391-
392-
sentences, labels, char_set = name_finder.load_data(word_dict, sys.argv[2])
393-
sentences_dev, labels_dev, char_set_dev = name_finder.load_data(word_dict, sys.argv[3])
389+
parser = argparse.ArgumentParser()
390+
parser.add_argument("embedding_file", help="path to the embeddings file.")
391+
parser.add_argument("train_file", help="path to the training file.")
392+
parser.add_argument("dev_file", help="path to the dev file.")
393+
parser.add_argument("--allow_unk", help="use general UNK token and vector for unknown tokens.", default=True)
394+
parser.add_argument("--allow_num", help="use general NUM token and vector for all numeric tokens.", default=False)
395+
parser.add_argument("--lower_case_embeddings", help="convert tokens to lowercase for embeddings lookup.",
396+
default=False)
397+
parser.add_argument("--digit_pattern", help="regex to use for identifying numeric tokens.",
398+
default='^\\d+(,\\d+)*(\\.\\d+)?$')
399+
parser.add_argument("--data_encoding", help="set encoding of train and dev data.", default='utf-8')
400+
parser.add_argument("--embeddings_encoding", help="set encoding of the embeddings.", default='utf-8')
401+
args = parser.parse_args()
402+
403+
word_dict, rev_word_dict, embeddings, vector_size = load_glove(args.embedding_file, args.embeddings_encoding)
404+
405+
name_finder = NameFinder(use_lower_case_embeddings=args.lower_case_embeddings, allow_unk=args.allow_unk,
406+
allow_num=args.allow_num, digit_pattern=args.digit_pattern,
407+
encoding=args.data_encoding, vector_size=vector_size)
408+
409+
sentences, labels, char_set = name_finder.load_data(word_dict, args.train_file)
410+
sentences_dev, labels_dev, char_set_dev = name_finder.load_data(word_dict, args.dev_file)
394411

395412
char_dict = {k: v for v, k in enumerate(char_set | char_set_dev)}
396413

@@ -472,6 +489,11 @@ def main():
472489
write_mapping(name_finder.label_dict, temp_model_dir + "/label_dict.txt")
473490
write_mapping(char_dict, temp_model_dir + "/char_dict.txt")
474491

492+
write_mapping({'lower_case_embeddings=' + str(args.lower_case_embeddings).lower(): 0,
493+
'allow_unk=' + str(args.allow_unk).lower(): 1,
494+
'allow_num=' + str(args.allow_num).lower(): 2,
495+
'digit_pattern=' + re.escape(args.digit_pattern): 3 }, temp_model_dir + "/config.properties")
496+
475497
zipf = zipfile.ZipFile("namefinder-" + str(epoch) + ".zip", 'w', zipfile.ZIP_DEFLATED)
476498

477499
for root, dirs, files in os.walk(temp_model_dir):

0 commit comments

Comments
 (0)