Skip to content

Commit 116a3ec

Browse files
author
i338425
committed
In Progress
1 parent 3b27f17 commit 116a3ec

File tree

6 files changed

+50
-15
lines changed

6 files changed

+50
-15
lines changed

interaction_network.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import numpy as np
1111
import time
12-
from physics_engine import gen
12+
from physics_engine import gen, make_video
1313
FLAGS = None
1414

1515
def m(O,Rr,Rs,Ra):
@@ -108,11 +108,12 @@ def train():
108108
tf.global_variables_initializer().run();
109109

110110
# Data Generation
111-
set_num=2000;
111+
#set_num=2000;
112+
set_num=1;
112113
total_data=np.zeros((999*set_num,FLAGS.Ds,FLAGS.No),dtype=object);
113114
total_label=np.zeros((999*set_num,FLAGS.Dp,FLAGS.No),dtype=object);
114115
for i in range(set_num):
115-
raw_data=gen(FLAGS.No,False);
116+
raw_data=gen(FLAGS.No,True);
116117
data=np.zeros((999,FLAGS.Ds,FLAGS.No),dtype=object);
117118
label=np.zeros((999,FLAGS.Dp,FLAGS.No),dtype=object);
118119
for j in range(1000-1):
@@ -121,8 +122,10 @@ def train():
121122
total_label[i*999:(i+1)*999,:]=label;
122123

123124
# Shuffle
124-
tr_data_num=1000000;
125-
val_data_num=200000;
125+
#tr_data_num=1000000;
126+
#val_data_num=200000;
127+
tr_data_num=400;
128+
val_data_num=300;
126129
total_idx=range(len(total_data));np.random.shuffle(total_idx);
127130
mixed_data=total_data[total_idx];
128131
mixed_label=total_label[total_idx];
@@ -149,7 +152,8 @@ def train():
149152
cnt+=1;
150153

151154
# Training
152-
for i in range(2000):
155+
max_epoches=2000;
156+
for i in range(max_epoches):
153157
for j in range(int(len(train_data)/mini_batch_num)):
154158
batch_data=train_data[j*mini_batch_num:(j+1)*mini_batch_num];
155159
batch_label=train_label[j*mini_batch_num:(j+1)*mini_batch_num];
@@ -160,6 +164,23 @@ def train():
160164
batch_label=val_label[j*mini_batch_num:(j+1)*mini_batch_num];
161165
val_loss+=sess.run(loss,feed_dict={O:batch_data,Rr:Rr_data,Rs:Rs_data,Ra:Ra_data,P_label:batch_label,X:X_data});
162166
print("Epoch "+str(i+1)+" Validation MSE: "+str(val_loss/(j+1)));
167+
168+
# Make Video
169+
frame_len=250;
170+
raw_data=gen(FLAGS.No,True);
171+
xy_origin=raw_data[:frame_len,:,1:3];
172+
estimated_data=np.zeros((frame_len,FLAGS.No,FLAGS.Ds),dtype=float);
173+
estimated_data[0]=raw_data[0];
174+
for i in range(1,frame_len):
175+
velocities=sess.run(P,feed_dict={O:[np.transpose(estimated_data[i-1])],Rr:[Rr_data[0]],Rs:[Rs_data[0]],Ra:[Ra_data[0]],X:[X_data[0]]})[0];
176+
estimated_data[i,:,0]=estimated_data[i-1][:,0];
177+
estimated_data[i,:,3:5]=np.transpose(velocities);
178+
estimated_data[i,:,1:3]=estimated_data[i-1,:,1:3]+estimated_data[i,:,3:5]*0.001;
179+
xy_estimated=estimated_data[:,:,1:3];
180+
print("Video Recording");
181+
make_video(xy_origin,"true.mp4");
182+
make_video(xy_estimated,"modeling.mp4");
183+
163184

164185
def main(_):
165186
FLAGS.log_dir+=str(int(time.time()));

modeling.mp4

44.6 KB
Binary file not shown.

physics_engine.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,27 @@
1919
# 5 features on the state [mass,x,y,x_vel,y_vel]
2020
fea_num=5;
2121
# G
22-
G = 6.67428e-11;
22+
#G = 6.67428e-11;
23+
G=10**5;
2324
# time step
2425
diff_t=0.001;
2526

2627
def init(total_state,n_body,fea_num,orbit):
2728
data=np.zeros((total_state,n_body,fea_num),dtype=float);
2829
if(orbit):
29-
print("Not yet");exit(1);
30+
data[0][0][0]=100;
31+
data[0][0][1:5]=0.0;
32+
for i in range(1,n_body):
33+
data[0][i][0]=np.random.rand()*8.98+0.02;
34+
distance=np.random.rand()*90.0+10.0;
35+
theta=np.random.rand()*360;
36+
theta_rad = pi/2 - radians(theta);
37+
data[0][i][1]=distance*cos(theta_rad);
38+
data[0][i][2]=distance*sin(theta_rad);
39+
data[0][i][3]=-1*data[0][i][2]/norm(data[0][i][1:3])*(G*data[0][0][0]/distance**2)*0.05;
40+
data[0][i][4]=data[0][i][1]/norm(data[0][i][1:3])*(G*data[0][0][0]/distance**2)*0.05;
41+
#data[0][i][3]=np.random.rand()*10.0-5.0;
42+
#data[0][i][4]=np.random.rand()*10.0-5.0;
3043
else:
3144
for i in range(n_body):
3245
data[0][i][0]=np.random.rand()*8.98+0.02;
@@ -69,26 +82,27 @@ def gen(n_body,orbit):
6982
data=init(total_state,n_body,fea_num,orbit);
7083
for i in range(1,total_state):
7184
data[i]=calc(data[i-1],n_body);
85+
print(norm(data[i][0,1:3]-data[i][1,1:3]));
7286
return data;
7387

74-
def make_video(xy):
88+
def make_video(xy,filename):
7589
os.system("rm -rf pics/*");
7690
FFMpegWriter = manimation.writers['ffmpeg']
7791
metadata = dict(title='Movie Test', artist='Matplotlib',
7892
comment='Movie support!')
7993
writer = FFMpegWriter(fps=15, metadata=metadata)
8094
fig = plt.figure()
81-
plt.xlim(-1000, 1000)
82-
plt.ylim(-1000, 1000)
95+
plt.xlim(-200, 200)
96+
plt.ylim(-200, 200)
8397
fig_num=len(xy);
8498
color=['ro','bo','go','ko','yo','mo','co'];
85-
with writer.saving(fig, "video.mp4", len(xy)):
99+
with writer.saving(fig, filename, len(xy)):
86100
for i in range(len(xy)):
87101
for j in range(len(xy[0])):
88102
plt.plot(xy[i,j,1],xy[i,j,0],color[j%len(color)]);
89103
writer.grab_frame();
90104

91105
if __name__=='__main__':
92-
data=gen(6,False);
93-
#xy=data[:,:,1:3];
94-
#make_video(xy);
106+
data=gen(6,True);
107+
xy=data[:,:,1:3];
108+
make_video(xy,"test.mp4");

physics_engine.pyc

100755100644
665 Bytes
Binary file not shown.

test.mp4

139 KB
Binary file not shown.

true.mp4

43.9 KB
Binary file not shown.

0 commit comments

Comments
 (0)