9
9
10
10
import numpy as np
11
11
import time
12
- from physics_engine import gen
12
+ from physics_engine import gen , make_video
13
13
FLAGS = None
14
14
15
15
def m (O ,Rr ,Rs ,Ra ):
@@ -108,11 +108,12 @@ def train():
108
108
tf .global_variables_initializer ().run ();
109
109
110
110
# Data Generation
111
- set_num = 2000 ;
111
+ #set_num=2000;
112
+ set_num = 1 ;
112
113
total_data = np .zeros ((999 * set_num ,FLAGS .Ds ,FLAGS .No ),dtype = object );
113
114
total_label = np .zeros ((999 * set_num ,FLAGS .Dp ,FLAGS .No ),dtype = object );
114
115
for i in range (set_num ):
115
- raw_data = gen (FLAGS .No ,False );
116
+ raw_data = gen (FLAGS .No ,True );
116
117
data = np .zeros ((999 ,FLAGS .Ds ,FLAGS .No ),dtype = object );
117
118
label = np .zeros ((999 ,FLAGS .Dp ,FLAGS .No ),dtype = object );
118
119
for j in range (1000 - 1 ):
@@ -121,8 +122,10 @@ def train():
121
122
total_label [i * 999 :(i + 1 )* 999 ,:]= label ;
122
123
123
124
# Shuffle
124
- tr_data_num = 1000000 ;
125
- val_data_num = 200000 ;
125
+ #tr_data_num=1000000;
126
+ #val_data_num=200000;
127
+ tr_data_num = 400 ;
128
+ val_data_num = 300 ;
126
129
total_idx = range (len (total_data ));np .random .shuffle (total_idx );
127
130
mixed_data = total_data [total_idx ];
128
131
mixed_label = total_label [total_idx ];
@@ -149,7 +152,8 @@ def train():
149
152
cnt += 1 ;
150
153
151
154
# Training
152
- for i in range (2000 ):
155
+ max_epoches = 2000 ;
156
+ for i in range (max_epoches ):
153
157
for j in range (int (len (train_data )/ mini_batch_num )):
154
158
batch_data = train_data [j * mini_batch_num :(j + 1 )* mini_batch_num ];
155
159
batch_label = train_label [j * mini_batch_num :(j + 1 )* mini_batch_num ];
@@ -160,6 +164,23 @@ def train():
160
164
batch_label = val_label [j * mini_batch_num :(j + 1 )* mini_batch_num ];
161
165
val_loss += sess .run (loss ,feed_dict = {O :batch_data ,Rr :Rr_data ,Rs :Rs_data ,Ra :Ra_data ,P_label :batch_label ,X :X_data });
162
166
print ("Epoch " + str (i + 1 )+ " Validation MSE: " + str (val_loss / (j + 1 )));
167
+
168
+ # Make Video
169
+ frame_len = 250 ;
170
+ raw_data = gen (FLAGS .No ,True );
171
+ xy_origin = raw_data [:frame_len ,:,1 :3 ];
172
+ estimated_data = np .zeros ((frame_len ,FLAGS .No ,FLAGS .Ds ),dtype = float );
173
+ estimated_data [0 ]= raw_data [0 ];
174
+ for i in range (1 ,frame_len ):
175
+ velocities = sess .run (P ,feed_dict = {O :[np .transpose (estimated_data [i - 1 ])],Rr :[Rr_data [0 ]],Rs :[Rs_data [0 ]],Ra :[Ra_data [0 ]],X :[X_data [0 ]]})[0 ];
176
+ estimated_data [i ,:,0 ]= estimated_data [i - 1 ][:,0 ];
177
+ estimated_data [i ,:,3 :5 ]= np .transpose (velocities );
178
+ estimated_data [i ,:,1 :3 ]= estimated_data [i - 1 ,:,1 :3 ]+ estimated_data [i ,:,3 :5 ]* 0.001 ;
179
+ xy_estimated = estimated_data [:,:,1 :3 ];
180
+ print ("Video Recording" );
181
+ make_video (xy_origin ,"true.mp4" );
182
+ make_video (xy_estimated ,"modeling.mp4" );
183
+
163
184
164
185
def main (_ ):
165
186
FLAGS .log_dir += str (int (time .time ()));
0 commit comments