This repository implements the Sliced Recurrent Neural Networks (SRNN) of Zeping Yu in TensorFlow 2. It is based on the original Keras implementation.
- Implemented in TensorFlow 2 and Python 3
- Provides a custom TimeDistributed module, since the built-in one has problems working with GRU in the current TensorFlow version
- Efficient data processing pipeline
- Training pipelines for Yelp2013, Yelp2014, Yelp2015, and custom datasets
- Inference pipeline
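The core idea of SRNN is to cut each input sequence into equal-length contiguous slices so that separate recurrent subnetworks can process them in parallel. A minimal pure-Python sketch of the slicing step (the function name `slice_sequence` is illustrative, not from this repository):

```python
def slice_sequence(seq, k):
    """Split a sequence into k equal-length contiguous slices.

    Assumes len(seq) is divisible by k; SRNN pads sequences to a
    fixed length, so this holds after preprocessing.
    """
    n = len(seq) // k
    return [seq[i * n:(i + 1) * n] for i in range(k)]


# An 8-token sequence sliced into 2 halves, then each half sliced again,
# mirroring the recursive slicing SRNN applies across its layers.
seq = list(range(8))
halves = slice_sequence(seq, 2)        # [[0, 1, 2, 3], [4, 5, 6, 7]]
quarters = [slice_sequence(h, 2) for h in halves]
```

In the full model, each slice is fed to its own (weight-shared) GRU, and the final hidden states of one level become the input sequence of the next level up.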
It's strongly recommended to use Anaconda to build the environment.
```shell
conda env create -f environment.yml
conda activate tensorflow
```
When training on a new dataset for the first time, remember to pass the `--create_tfrecord` option to create TFRecords from the dataset. Otherwise, the old TFRecords saved in `./tfrecord` are reused, which means the model still trains on the old dataset. Once the TFRecords have been created, you can skip this option on subsequent runs.
The model with the highest evaluation accuracy is saved in `./saved_model`.
This implementation uses GloVe embeddings as the pretrained embedding weights. Please download glove.6B.200d.txt from this link and save it as `./data/glove.6B.200d.txt`.
Download the Yelp2013 dataset from this link and save it as `./data/yelp_2013.csv`.
```shell
python train.py --dataset Yelp2013 \
    --val_size 0.1 \
    --test_size 0.1 \
    --epochs 10 \
    --create_tfrecord \
    --batch_size 2048  # RTX 2080 Ti
```
Download the Yelp2014 dataset from this link and save it as `./data/yelp_2014.csv`.
```shell
python train.py --dataset Yelp2014 \
    --val_size 0.1 \
    --test_size 0.1 \
    --epochs 10 \
    --create_tfrecord \
    --batch_size 2048
```
Download the Yelp2015 dataset from this link and save it as `./data/yelp_2015.csv`.
```shell
python train.py --dataset Yelp2015 \
    --val_size 0.1 \
    --test_size 0.1 \
    --epochs 10 \
    --create_tfrecord \
    --batch_size 2048
```
Please refer to `data/train_sample.csv` to prepare custom data. The training dataset should be a CSV file containing two columns:
- The first column contains the label, ranging from 1 to `class_num`.
- The second column contains the sentences.
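As a quick illustration, the two-column layout described above can be produced with Python's standard `csv` module. This is only a sketch: the labels and sentences are made up, and the exact header convention should be checked against `data/train_sample.csv`.

```python
import csv

# Hypothetical example rows: (label, sentence). Labels run from 1 to
# class_num; here class_num is assumed to be 5 (star ratings).
rows = [
    (3, "The food was decent but the service was slow."),
    (5, "Absolutely loved this place, will come back!"),
]

# Write the two-column CSV expected by the training pipeline.
with open("train_custom.csv", "w", newline="") as f:
    csv.writer(f).writerows(rows)
```

You would then point `--train_data_path` at the generated file, as in the command below.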
```shell
python train.py --dataset Custom \
    --train_data_path ./data/train_sample.csv \
    --val_size 0.1 \
    --test_size 0.1 \
    --epochs 10 \
    --create_tfrecord \
    --batch_size 2048
```
Dataset | Accuracy | Seconds per epoch
---|---|---
Yelp 2013 | 0.650 | 89
Yelp 2014 | 0.689 | 128
Yelp 2015 | 0.720 | 171
Dataset | Accuracy | Seconds per epoch
---|---|---
Yelp 2013 | 0.653 | 141
Yelp 2014 | 0.691 | 189
Yelp 2015 | 0.722 | 264
The accuracy of the RNN is about 1% lower than reported in the paper. This is because this implementation uses a batch size of 2048 to make full use of the GPU, while the paper uses a batch size of 50.