
GRU

Train a GRU model on stroke therapist and patient data.
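If you are new to GRUs, the recurrence the network is built from can be sketched in plain Python. This follows the standard formulation from Cho et al. (2014). It is not the code in run.py, which presumably relies on a deep learning framework. The parameter names (`W_z`, `U_z`, `b_z` and so on) and the helper `zero_params` are hypothetical. Frameworks also differ slightly in their conventions, for example in where the reset gate is applied.

```python
import math
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gate(p: Dict[str, list], name: str, x: Vector, h: Vector, act) -> Vector:
    """Compute act(W_name @ x + U_name @ h + b_name) element-wise."""
    return [act(a + b + c) for a, b, c in
            zip(matvec(p["W_" + name], x), matvec(p["U_" + name], h), p["b_" + name])]


def gru_step(x: Vector, h: Vector, p: Dict[str, list]) -> Vector:
    """One GRU time step: returns the new hidden state."""
    z = gate(p, "z", x, h, sigmoid)           # update gate
    r = gate(p, "r", x, h, sigmoid)           # reset gate
    rh = [ri * hi for ri, hi in zip(r, h)]    # reset-gated previous state
    h_tilde = gate(p, "h", x, rh, math.tanh)  # candidate state
    # Blend old state and candidate: z close to 1 keeps the old state.
    return [zi * hi + (1.0 - zi) * ci for zi, hi, ci in zip(z, h, h_tilde)]


def gru_sequence(xs: List[Vector], h0: Vector, p: Dict[str, list]) -> Vector:
    """Run the cell over a whole input sequence and return the final state."""
    h = h0
    for x in xs:
        h = gru_step(x, h, p)
    return h


def zero_params(input_size: int, hidden_size: int) -> Dict[str, list]:
    """Hypothetical all-zero parameters, handy for checking shapes."""
    p: Dict[str, list] = {}
    for name in ("z", "r", "h"):
        p["W_" + name] = [[0.0] * input_size for _ in range(hidden_size)]
        p["U_" + name] = [[0.0] * hidden_size for _ in range(hidden_size)]
        p["b_" + name] = [0.0] * hidden_size
    return p
```

With all-zero parameters, both gates equal 0.5 and the candidate is 0, so each step halves the hidden state. For example, `gru_step([1.0, 2.0, 3.0], [1.0, -2.0], zero_params(3, 2))` returns `[0.5, -1.0]`.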

We collected some sample data in the file trial2.csv to train the network.
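Recurrent models like this one are usually trained on fixed-length windows of a time series, each paired with the value that follows it. The sketch below shows that preprocessing. It assumes trial2.csv has a header row and one numeric column per axis. The column names `x`, `y`, `z`, the window length and both helpers are hypothetical, so check them against the actual file and config.

```python
import csv
import io
from typing import List, TextIO, Tuple


def read_column(fh: TextIO, column: str) -> List[float]:
    """Read one numeric column from a CSV file that has a header row."""
    return [float(row[column]) for row in csv.DictReader(fh)]


def make_windows(series: List[float], seq_len: int) -> List[Tuple[List[float], float]]:
    """Split a series into (input window of seq_len values, next value) pairs."""
    if seq_len < 1:
        raise ValueError("seq_len must be at least 1")
    return [(series[i:i + seq_len], series[i + seq_len])
            for i in range(len(series) - seq_len)]


if __name__ == "__main__":
    # In practice: open("trial2.csv", newline="") instead of this inline sample.
    sample = io.StringIO("x,y,z\n1,10,100\n2,20,200\n3,30,300\n4,40,400\n")
    xs = read_column(sample, "x")
    for window, target in make_windows(xs, seq_len=2):
        print(window, "->", target)
```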

Feel free to change the hyperparameters in config_1.json, but we recommend using config_2.json directly.
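If you want to start from config_2.json and change only a few values, a small deep-merge helper avoids editing the file. This sketch assumes the config is plain nested JSON. The key names `training`, `epochs`, `batch_size`, `data` and `filename` used below are hypothetical and may not match the real schema of config_1.json or config_2.json.

```python
import copy
import json
from typing import Any, Dict, Optional


def deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of base with overrides applied, recursing into nested dicts."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def load_config(path: str, overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Load a JSON config file and apply optional overrides on top of it."""
    with open(path, encoding="utf-8") as fh:
        config = json.load(fh)
    return deep_merge(config, overrides or {})
```

Usage, with hypothetical keys: `config = load_config("config_2.json", {"training": {"epochs": 5}})`. Nested values you do not override keep their values from the file.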

Uncomment the lines in run.py to save your predicted results.
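The saving code in run.py is not shown here. As an alternative, the sketch below writes ground truth and predictions side by side to a CSV file using only the standard library. The helper `save_results`, the column names and the output filename are hypothetical, and the format may differ from the saved test&predict results file.

```python
import csv
import io
from typing import Sequence, TextIO


def save_results(fh: TextIO, y_true: Sequence[float], y_pred: Sequence[float]) -> int:
    """Write ground truth and predictions side by side as CSV; return the row count."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    writer = csv.writer(fh)
    writer.writerow(["step", "ground_truth", "prediction"])
    for step, (t, p) in enumerate(zip(y_true, y_pred)):
        writer.writerow([step, t, p])
    return len(y_true)


if __name__ == "__main__":
    # In practice: open("predictions.csv", "w", newline="") (hypothetical filename).
    buf = io.StringIO()
    save_results(buf, [1.0, 2.0], [1.5, 2.5])
    print(buf.getvalue())
```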

Use this Jupyter notebook (Open In Colab) to get a 3D graph of the predicted trajectory. You can use the results I got, which are stored in the file: saved test&predict results
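The per-axis results below suggest that the x, y and z coordinates are predicted separately. If so, the three prediction sequences have to be zipped back into 3D points before plotting, for example with a 3D plotting library such as matplotlib. The point-wise Euclidean distance then gives one error per time step. Both helpers are hypothetical sketches, not code from the notebook.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def combine_axes(xs: Sequence[float], ys: Sequence[float], zs: Sequence[float]) -> List[Point]:
    """Zip separate per-axis sequences into (x, y, z) trajectory points."""
    if not (len(xs) == len(ys) == len(zs)):
        raise ValueError("per-axis sequences must have the same length")
    return list(zip(xs, ys, zs))


def pointwise_distance(truth: Sequence[Point], pred: Sequence[Point]) -> List[float]:
    """Euclidean distance between matching ground-truth and predicted points."""
    if len(truth) != len(pred):
        raise ValueError("trajectories must have the same length")
    return [math.dist(t, p) for t, p in zip(truth, pred)]


if __name__ == "__main__":
    truth = combine_axes([0.0, 1.0], [0.0, 1.0], [0.0, 1.0])
    pred = combine_axes([3.0, 1.0], [4.0, 1.0], [0.0, 1.0])
    print(pointwise_distance(truth, pred))
```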

Model architecture:

(figure: model)

Results:

3D prediction graph (red line: ground truth; green line: our prediction)

(figure: 3d)

z_axis:

(figure: results_2_try)

z_axis model loss:

Point-by-point prediction test loss (MAE): 21.149654617941962

(figure: Model Loss)

y_axis:

(figure: results_2_try)

y_axis model loss:

Point-by-point prediction test loss (MAE): 34.590864224617455

(figure: Model Loss)

x_axis:

(figure: results_2_try)

x_axis model loss:

Point-by-point prediction test loss (MAE): 20.647305644118134

(figure: Model Loss)
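The reported loss is the mean absolute error, MAE = (1/n) Σ |y_i − ŷ_i|, expressed in the same units as the trajectory data. The sketch below assumes "point by point" means one-step-ahead prediction, where every input window is taken from ground truth rather than from earlier predictions. `predict_point_by_point` and the naive last-value baseline are hypothetical stand-ins for the trained GRU.

```python
from typing import Callable, List, Sequence


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error: the average of |y_true[i] - y_pred[i]|."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def predict_point_by_point(model: Callable[[Sequence[float]], float],
                           series: Sequence[float], seq_len: int) -> List[float]:
    """Predict each next value from the preceding seq_len true values."""
    return [model(series[i:i + seq_len]) for i in range(len(series) - seq_len)]


def last_value(window: Sequence[float]) -> float:
    """Hypothetical naive baseline: repeat the most recent value."""
    return window[-1]


if __name__ == "__main__":
    series = [1.0, 2.0, 4.0, 7.0]
    preds = predict_point_by_point(last_value, series, seq_len=2)
    print(mae(series[2:], preds))
```

Here the targets are `series[seq_len:]`, so predictions and targets line up one-to-one.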

More updates to come.