Skip to content

Commit 97f51f0

Browse files
YazhiGao and Leon Gao authored
expose tfrecord dataset transformation function for LinkedIn usage (#10)
Co-authored-by: Leon Gao <legao@linkedin.com>
1 parent bc571a6 commit 97f51f0

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

src/detext/train/data_fn.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -271,6 +271,33 @@ def input_fn_tfrecord(input_pattern,
271271
else:
272272
dataset = tf.data.TFRecordDataset(input_files[0])
273273

274+
dataset = tfrecord_transform_fn(dataset,
275+
batch_size,
276+
mode,
277+
vocab_table, vocab_table_for_id_ftr,
278+
feature_names,
279+
CLS, SEP, PAD, PAD_FOR_ID_FTR,
280+
output_buffer_size,
281+
max_len,
282+
min_len,
283+
cnn_filter_window_size,
284+
prefetch_size,
285+
num_data_process_threads)
286+
return dataset
287+
288+
289+
def tfrecord_transform_fn(dataset,
290+
batch_size,
291+
mode,
292+
vocab_table, vocab_table_for_id_ftr,
293+
feature_names,
294+
CLS, SEP, PAD, PAD_FOR_ID_FTR,
295+
output_buffer_size,
296+
max_len=None,
297+
min_len=None,
298+
cnn_filter_window_size=0,
299+
prefetch_size=100,
300+
num_data_process_threads=32):
274301
if mode == tf.estimator.ModeKeys.TRAIN:
275302
dataset = dataset.shuffle(output_buffer_size)
276303
dataset = dataset.repeat()

0 commit comments

Comments
 (0)