This project is a generic, reusable ("golden") program for deep learning with TensorFlow.
The following features are supported:
- Data Format
- Predict Server
- Predict Client
- Use Cases
  - Train model
  - Export model
  - Validate accuracy/AUC
  - Inference online
  - Inference offline
- Network Model
  - Logistic regression
  - Deep neural network
  - Convolutional neural network
  - Wide and deep model
  - Customized models
- Others
  - Checkpoint
  - TensorBoard
  - Exporter
  - Dropout
  - Optimizers
  - Learning rate decay
  - Batch normalization
  - Distributed training
If your data is in CSV format, generate TFRecords like this.
cd ./data/cancer/
./generate_csv_tfrecords.py
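Conceptually, the CSV converter turns each line into a label and a dense vector of float features, which is roughly what each TFRecord example holds. The sketch below shows only that parsing step in plain Python. It assumes the label is the last column and every other column is numeric (check your own file), and `parse_csv_line` is a hypothetical helper, not a function from the script.

```python
# A minimal sketch of the CSV parsing step, not the project's script.
# Assumption: the label is the LAST column; all other columns are numeric features.
def parse_csv_line(line):
    """Hypothetical helper: split one CSV line into (label, features)."""
    columns = [float(value) for value in line.strip().split(",")]
    label = int(columns[-1])   # class id stored as a float in the file
    features = columns[:-1]    # dense feature vector
    return label, features


if __name__ == "__main__":
    # A made-up row with three features and label 1.
    label, features = parse_csv_line("0.5,1.0,2.0,1\n")
    print(label, features)  # prints: 1 [0.5, 1.0, 2.0]
```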
If your data is in LIBSVM format, generate TFRecords like this.
cd ./data/a8a/
./generate_libsvm_tfrecord.py
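LIBSVM lines have the form `label index:value index:value ...`, so a sparse example is naturally stored as a label plus parallel lists of feature ids and values. The sketch below shows that parsing in plain Python. `parse_libsvm_line` is a hypothetical helper, and mapping `+1`/`-1` labels to `1`/`0` is an assumption for binary data such as a8a, not necessarily what the script does.

```python
# A minimal sketch of LIBSVM parsing, not the project's script.
def parse_libsvm_line(line):
    """Hypothetical helper: return (label, ids, values) for one LIBSVM line."""
    tokens = line.split()
    # Assumption: binary labels such as "+1"/"-1" are mapped to 1/0.
    label = 1 if float(tokens[0]) > 0 else 0
    ids, values = [], []
    for token in tokens[1:]:
        index, value = token.split(":")
        ids.append(int(index))      # sparse feature id
        values.append(float(value)) # value for that id
    return label, ids, values


if __name__ == "__main__":
    print(parse_libsvm_line("-1 3:1 11:1 14:1"))  # prints: (0, [3, 11, 14], [1.0, 1.0, 1.0])
```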
For large datasets, you can use Spark to generate the TFRecords instead. Please refer to the data directory.
You can train with the default configuration.
./dense_classifier.py
./sparse_classifier.py
Using different models or hyperparameters is easy with TensorFlow flags.
./dense_classifier.py --batch_size 1024 --epoch_number 1000 --step_to_validate 10 --optimizer adagrad --model dnn --model_network "128 32 8"
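The `--model_network` value is presumably a space-separated list of hidden-layer widths, so `"128 32 8"` describes three hidden layers. Assuming a fully connected network whose input width is `--feature_size` and whose output width is `--label_size`, the weight shapes and parameter count follow directly. `dnn_weight_shapes` and `count_parameters` are hypothetical helpers, and the sizes 9 and 2 are placeholders.

```python
# A minimal sketch of how a "--model_network" string could map to layer shapes.
# Assumption: fully connected layers, each with a weight matrix and a bias vector.
def dnn_weight_shapes(model_network, feature_size, label_size):
    """Hypothetical helper: (input, output) shape of each weight matrix."""
    hidden_units = [int(width) for width in model_network.split()]
    layer_sizes = [feature_size] + hidden_units + [label_size]
    # Each consecutive pair of layer sizes gives one weight matrix.
    return list(zip(layer_sizes[:-1], layer_sizes[1:]))


def count_parameters(shapes):
    """Weights (rows * cols) plus one bias per output unit, summed over layers."""
    return sum(rows * cols + cols for rows, cols in shapes)


if __name__ == "__main__":
    shapes = dnn_weight_shapes("128 32 8", feature_size=9, label_size=2)
    print(shapes)                    # prints: [(9, 128), (128, 32), (32, 8), (8, 2)]
    print(count_parameters(shapes))  # prints: 5690
```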
If you use another dataset such as iris, there is no need to modify the code. Just pass flags that specify the TFRecords files and the feature and label sizes.
./dense_classifier.py --train_tfrecords_file ./data/iris/iris_train.csv.tfrecords --validate_tfrecords_file ./data/iris/iris_test.csv.tfrecords --feature_size 4 --label_size 3
If you want to use the CNN model, try this command.
./dense_classifier.py --train_tfrecords_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --validate_tfrecords_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --feature_size 262144 --label_size 2 --batch_size 2 --validate_batch_size 2 --epoch_number -1 --model cnn
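In the CNN command, `--feature_size 262144` equals 512 × 512, which matches a 512×512 single-channel image flattened into one vector. Assuming the CNN reshapes the flat features back into a square image before its convolution layers (check the model code for the exact layout), the sketch below shows that reshape in plain Python. `to_square_image` is a hypothetical helper.

```python
import math


# A minimal sketch, assuming row-major flattening of a square single-channel image.
def to_square_image(flat_features):
    """Hypothetical helper: reshape a flat list of N*N values into N rows of N."""
    side = math.isqrt(len(flat_features))
    if side * side != len(flat_features):
        raise ValueError("feature_size must be a perfect square")
    return [flat_features[row * side:(row + 1) * side] for row in range(side)]


if __name__ == "__main__":
    print(math.isqrt(262144))               # prints: 512
    print(to_square_image(list(range(9))))  # prints: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```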
After training, the program exports the model automatically. You can also export it manually.
./dense_classifier.py --mode export
To run inference and validate the model, run this command.
./dense_classifier.py --mode inference
The program will generate TensorFlow event files automatically.
tensorboard --logdir ./tensorboard/
Then open http://127.0.0.1:6006 in the browser.
The exported model is compatible with TensorFlow Serving. You can follow the TensorFlow Serving documentation to build and run tensorflow_model_server.
./tensorflow_model_server --port=9000 --model_name=dense --model_base_path=./model/
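TensorFlow Serving looks for numeric version subdirectories under `--model_base_path` and, with its default version policy, serves the highest one; that version number is what the predict clients pass as their model version. The sketch below mimics the lookup in plain Python so you can check what an export produced. `latest_model_version` is a hypothetical helper, and it assumes the exporter writes numeric (possibly zero-padded) directory names under `./model/`.

```python
import os


# A minimal sketch of picking the newest numeric version directory.
def latest_model_version(model_base_path):
    """Hypothetical helper: highest numeric subdirectory name, or None."""
    if not os.path.isdir(model_base_path):
        return None
    versions = [
        int(name)  # int("00000002") == 2, so zero-padded names also work
        for name in os.listdir(model_base_path)
        if name.isdigit() and os.path.isdir(os.path.join(model_base_path, name))
    ]
    return max(versions) if versions else None


if __name__ == "__main__":
    print(latest_model_version("./model/"))
```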
We provide gRPC clients for dense and sparse models, such as a Python predict client and a Java predict client.
./predict_client.py --host 127.0.0.1 --port 9000 --model_name dense --model_version 1
mvn compile exec:java -Dexec.mainClass="com.tobe.DensePredictClient" -Dexec.args="127.0.0.1 9000 dense 1"
This project is widely used for different tasks with dense or sparse data.
If you want to make contributions, feel free to open an issue or pull request.