Train a GRU model with stroke therapist and patient data
Sample data for training the network is provided in trial2.csv.
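The README does not describe the layout of trial2.csv. A common way to prepare a single-axis trajectory for a sequence model is to slice it into fixed-length windows, each paired with the point that follows it as the training target. The sketch below is written under assumptions, not taken from this repository: it assumes the CSV has a header row and one numeric column per axis, and the column name `z`, the helper names `load_axis` and `make_windows`, and the window length are all hypothetical.

```python
import csv
import io
from typing import List, Tuple


def load_axis(csv_text: str, column: str) -> List[float]:
    """Read one numeric column from CSV text (a header row is assumed)."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return [float(row[column]) for row in reader]


def make_windows(series: List[float], window: int) -> List[Tuple[List[float], float]]:
    """Pair each run of `window` consecutive points with the point that follows it."""
    return [(series[i:i + window], series[i + window])
            for i in range(len(series) - window)]


if __name__ == "__main__":
    # Hypothetical stand-in for the contents of trial2.csv.
    sample = "x,y,z\n0,0,1.0\n1,0,2.0\n2,0,3.0\n3,0,4.0\n"
    z = load_axis(sample, "z")
    for inputs, target in make_windows(z, window=2):
        print(inputs, "->", target)
```

To read the real file instead, pass `open("trial2.csv", newline="").read()` to `load_axis`, after checking which column names the file actually uses.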
Feel free to change the hyperparameters in config_1.json. We recommend using config_2.json directly.
Uncomment the relevant lines in run.py to save the predicted results.
Use this Jupyter notebook to plot the predicted trajectory in 3D. You can also use the results I obtained, stored in: saved test&predict results
Model architecture:
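The architecture diagram is not reproduced in this text, so the exact layers are unknown. The results list a separate model for each axis, so one plausible shape is a GRU layer that reads a window of past points followed by a dense layer that outputs the next coordinate. The sketch below shows that forward pass in plain Python. It is a generic GRU cell with the reset gate applied to the recurrent term, similar to PyTorch's `nn.GRU` formulation, and is not the author's actual model. The class name `GRURegressor`, the hidden size, and the single-layer design are assumptions, and the weights are random rather than trained.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def rand_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


class GRURegressor:
    """Hypothetical single GRU layer plus a dense layer that outputs one coordinate.

    Weights are random (untrained); this only illustrates the forward pass.
    """

    def __init__(self, input_size: int, hidden_size: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.hidden_size = hidden_size
        # Input-to-hidden (W) and hidden-to-hidden (U) weights and biases for the
        # update gate (z), reset gate (r) and candidate state (n).
        self.W = {g: rand_matrix(hidden_size, input_size, rng) for g in "zrn"}
        self.U = {g: rand_matrix(hidden_size, hidden_size, rng) for g in "zrn"}
        self.b = {g: [0.0] * hidden_size for g in "zrn"}
        # Dense output layer: final hidden state -> one scalar.
        self.w_out = [rng.uniform(-0.1, 0.1) for _ in range(hidden_size)]
        self.b_out = 0.0

    def step(self, x: Vector, h: Vector) -> Vector:
        """Advance the hidden state by one time step."""
        wz, uz = matvec(self.W["z"], x), matvec(self.U["z"], h)
        wr, ur = matvec(self.W["r"], x), matvec(self.U["r"], h)
        wn, un = matvec(self.W["n"], x), matvec(self.U["n"], h)
        z = [sigmoid(a + b + c) for a, b, c in zip(wz, uz, self.b["z"])]
        r = [sigmoid(a + b + c) for a, b, c in zip(wr, ur, self.b["r"])]
        # Candidate state: the reset gate scales the recurrent contribution.
        n = [math.tanh(a + ri * b + c) for a, ri, b, c in zip(wn, r, un, self.b["n"])]
        # Interpolate between the candidate and the previous state.
        return [(1.0 - zi) * ni + zi * hi for zi, ni, hi in zip(z, n, h)]

    def predict(self, sequence: List[Vector]) -> float:
        """Run the whole input window, then map the final state to one value."""
        h = [0.0] * self.hidden_size
        for x in sequence:
            h = self.step(x, h)
        return sum(w * hi for w, hi in zip(self.w_out, h)) + self.b_out


if __name__ == "__main__":
    model = GRURegressor(input_size=1, hidden_size=8)
    window = [[0.1], [0.2], [0.3], [0.4]]  # four past values of one axis
    print(model.predict(window))
```

In a real training setup the weights would come from a framework such as Keras or PyTorch; writing the cell out by hand only makes the gating explicit.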
Results:
(red line: ground truth; green line: our prediction)
z-axis model loss:
point-by-point prediction test loss (MAE): 21.149654617941962
y-axis model loss:
point-by-point prediction test loss (MAE): 34.590864224617455
x-axis model loss:
point-by-point prediction test loss (MAE): 20.647305644118134
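The README reports a point-by-point test MAE for each axis model. One plausible reading is that each next point is predicted from a window of ground-truth points, never from the model's own earlier predictions, and the MAE is the mean absolute difference from the ground truth. That reading is an assumption, not confirmed by the source. The sketch below computes the metric that way. It uses a hypothetical persistence predictor (repeat the last observed value) as a stand-in for the trained GRU, and the names `point_by_point_mae` and `persistence` are illustrative only.

```python
from typing import Callable, List, Sequence


def point_by_point_mae(series: Sequence[float], window: int,
                       predict: Callable[[Sequence[float]], float]) -> float:
    """Predict each point from the `window` true points before it; return the MAE."""
    errors: List[float] = []
    for i in range(window, len(series)):
        prediction = predict(series[i - window:i])  # ground-truth history only
        errors.append(abs(prediction - series[i]))
    return sum(errors) / len(errors)


def persistence(history: Sequence[float]) -> float:
    """Hypothetical baseline: predict the most recent observed value."""
    return history[-1]


if __name__ == "__main__":
    trajectory = [0.0, 1.0, 3.0, 6.0, 10.0]
    print(point_by_point_mae(trajectory, window=2, predict=persistence))  # 3.0
```

Here the errors are 2, 3 and 4, so the MAE is 3.0. The MAE is in whatever units the recorded coordinates use, which the README does not state. Comparing a trained model's MAE against a simple baseline like this one helps show how much the model actually adds.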
More updates to come.