Skip to content

Commit 78cf53b

Browse files
author
i338425
committed
In Progress
0 parents  commit 78cf53b

File tree

3 files changed

+206
-0
lines changed

3 files changed

+206
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
###Tensorflow Implementation of Interaction Networks
2+
----

interaction_network.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import argparse
6+
import sys
7+
8+
import tensorflow as tf
9+
10+
import numpy as np
11+
import time
12+
13+
FLAGS = None
14+
15+
def m(O,Rr,Rs,Ra):
16+
return tf.concat([tf.matmul(O,Rr),tf.matmul(O,Rs),Ra],0);
17+
18+
19+
def phi_R(B):
20+
h_size=150;
21+
B_trans=tf.transpose(B);
22+
w1 = tf.Variable(tf.truncated_normal([(2*FLAGS.Ds+FLAGS.Dr), h_size], stddev=0.1), name="r_w1", dtype=tf.float32);
23+
b1 = tf.Variable(tf.zeros([h_size]), name="r_b1", dtype=tf.float32);
24+
h1 = tf.nn.relu(tf.matmul(B_trans, w1) + b1);
25+
w2 = tf.Variable(tf.truncated_normal([h_size, h_size], stddev=0.1), name="r_w2", dtype=tf.float32);
26+
b2 = tf.Variable(tf.zeros([h_size]), name="r_b2", dtype=tf.float32);
27+
h2 = tf.nn.relu(tf.matmul(h1, w2) + b2);
28+
w3 = tf.Variable(tf.truncated_normal([h_size, h_size], stddev=0.1), name="r_w3", dtype=tf.float32);
29+
b3 = tf.Variable(tf.zeros([h_size]), name="r_b3", dtype=tf.float32);
30+
h3 = tf.nn.relu(tf.matmul(h2, w3) + b3);
31+
w4 = tf.Variable(tf.truncated_normal([h_size, h_size], stddev=0.1), name="r_w4", dtype=tf.float32);
32+
b4 = tf.Variable(tf.zeros([h_size]), name="r_b4", dtype=tf.float32);
33+
h4 = tf.nn.relu(tf.matmul(h3, w4) + b4);
34+
w5 = tf.Variable(tf.truncated_normal([h_size, FLAGS.De], stddev=0.1), name="r_w5", dtype=tf.float32);
35+
b5 = tf.Variable(tf.zeros([FLAGS.De]), name="r_b5", dtype=tf.float32);
36+
h4 = tf.matmul(h4, w5) + b5;
37+
h4_trans=tf.transpose(h4);
38+
return(h4_trans);
39+
40+
def a(O,Rr,X,E):
41+
E_bar=tf.matmul(E,tf.transpose(Rr));
42+
return (tf.concat([O,X,E_bar],0));
43+
44+
def phi_O(C):
45+
h_size=100;
46+
C_trans=tf.transpose(C);
47+
w1 = tf.Variable(tf.truncated_normal([(FLAGS.Ds+FLAGS.Dx+FLAGS.De), h_size], stddev=0.1), name="o_w1", dtype=tf.float32);
48+
b1 = tf.Variable(tf.zeros([h_size]), name="o_b1", dtype=tf.float32);
49+
h1 = tf.nn.relu(tf.matmul(C_trans, w1) + b1);
50+
w2 = tf.Variable(tf.truncated_normal([h_size, FLAGS.Dp], stddev=0.1), name="o_w1", dtype=tf.float32);
51+
b2 = tf.Variable(tf.zeros([FLAGS.Dp]), name="o_b1", dtype=tf.float32);
52+
h2 = tf.matmul(h1, w2) + b2;
53+
h2_trans=tf.transpose(h2);
54+
return(h2_trans);
55+
56+
def phi_A(P):
57+
h_size=25;
58+
p_bar=tf.reduce_sum(P,1);
59+
w1 = tf.Variable(tf.truncated_normal([FLAGS.Dp, h_size], stddev=0.1), name="a_w1", dtype=tf.float32);
60+
b1 = tf.Variable(tf.zeros([h_size]), name="a_b1", dtype=tf.float32);
61+
h1 = tf.nn.relu(tf.matmul([p_bar], w1) + b1);
62+
w2 = tf.Variable(tf.truncated_normal([h_size, FLAGS.Da], stddev=0.1), name="a_w2", dtype=tf.float32);
63+
b2 = tf.Variable(tf.zeros([FLAGS.Da]), name="a_b2", dtype=tf.float32);
64+
h2 = tf.matmul(h1, w2) + b2;
65+
return(h1);
66+
67+
def train():
68+
"""
69+
# Object Matrix
70+
O=np.zeros((FLAGS.Ds,FLAGS.No),dtype=float);
71+
# Relation Matrics R=<Rr,Rs,Ra>
72+
R=np.zeros(3,dtype=object);
73+
R[0]=np.zeros((FLAGS.No,FLAGS.Nr),dtype=float);
74+
R[1]=np.zeros((FLAGS.No,FLAGS.Nr),dtype=float);
75+
R[2]=np.zeros((FLAGS.Dr,FLAGS.Nr),dtype=float);
76+
# External Effects
77+
X=np.zeros((FLAGS.Dx,FLAGS.No),dtype=float);
78+
79+
# marshalling function, m(G)=B, G=<O,R>
80+
B=m(O,R);
81+
"""
82+
83+
# Object Matrix
84+
O = tf.placeholder(tf.float32, [FLAGS.Ds,FLAGS.No], name="O");
85+
# Relation Matrics R=<Rr,Rs,Ra>
86+
Rr = tf.placeholder(tf.float32, [FLAGS.No,FLAGS.Nr], name="Rr");
87+
Rs = tf.placeholder(tf.float32, [FLAGS.No,FLAGS.Nr], name="Rs");
88+
Ra = tf.placeholder(tf.float32, [FLAGS.Dr,FLAGS.Nr], name="Ra");
89+
# External Effects
90+
X = tf.placeholder(tf.float32, [FLAGS.Dx,FLAGS.No], name="X");
91+
92+
# marshalling function, m(G)=B, G=<O,R>
93+
B=m(O,Rr,Rs,Ra);
94+
95+
# relational modeling phi_R(B)=E
96+
E=phi_R(B);
97+
98+
# aggregator
99+
C=a(O,Rr,X,E);
100+
101+
# object modeling phi_O(C)=P
102+
P=phi_O(C);
103+
104+
# abstract modeling phi_A(P)=q
105+
q=phi_A(P);
106+
print(q);exit(1);
107+
108+
def main(_):
109+
FLAGS.log_dir+=str(int(time.time()));
110+
if tf.gfile.Exists(FLAGS.log_dir):
111+
tf.gfile.DeleteRecursively(FLAGS.log_dir)
112+
tf.gfile.MakeDirs(FLAGS.log_dir)
113+
train()
114+
115+
116+
if __name__ == '__main__':
117+
parser = argparse.ArgumentParser()
118+
parser.add_argument('--log_dir', type=str, default='/tmp/interaction-network/',
119+
help='Summaries log directry')
120+
parser.add_argument('--Ds', type=int, default=5,
121+
help='The Number of State')
122+
parser.add_argument('--No', type=int, default=5,
123+
help='The Number of Objects')
124+
parser.add_argument('--Nr', type=int, default=5,
125+
help='The Number of Relations')
126+
parser.add_argument('--Dr', type=int, default=5,
127+
help='The Relationship Dimension')
128+
parser.add_argument('--Dx', type=int, default=3,
129+
help='The External Effect Dimension')
130+
parser.add_argument('--De', type=int, default=50,
131+
help='The Effect Dimension')
132+
parser.add_argument('--Dp', type=int, default=2,
133+
help='The Object Modeling Output Dimension')
134+
parser.add_argument('--Da', type=int, default=1,
135+
help='The Abstract Modeling Output Dimension')
136+
FLAGS, unparsed = parser.parse_known_args()
137+
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

physics_engine.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import argparse
6+
import sys
7+
8+
import numpy as np
9+
import time
10+
from math import sin, cos, radians, pi
11+
12+
# 1000 one-millisecond time steps
13+
total_state=1000;
14+
# 5 features on the state [mass,x,y,x_vel,y_vel]
15+
fea_num=5;
16+
# G
17+
#G = 6.67428e-11;
18+
G=10;
19+
# time step
20+
diff_t=0.000001;
21+
22+
def init(total_state,n_body,fea_num,orbit):
23+
data=np.zeros((total_state,n_body,fea_num),dtype=float);
24+
if(orbit):
25+
print("Not yet");exit(1);
26+
else:
27+
for i in range(n_body):
28+
data[0][i][0]=np.random.rand()*8.98+0.02;
29+
distance=np.random.rand()*90.0+10.0;
30+
theta=np.random.rand()*360;
31+
theta_rad = pi/2 - radians(theta);
32+
data[0][i][1]=distance*cos(theta_rad);
33+
data[0][i][2]=distance*sin(theta_rad);
34+
data[0][i][3]=np.random.rand()*6.0-3.0;
35+
data[0][i][4]=np.random.rand()*6.0-3.0;
36+
return data;
37+
38+
def get_f(reciever,sender):
39+
diff=reciever[1:3]-sender[1:3];
40+
return G*reciever[0]*sender[0]*diff/(np.linalg.norm(diff)**3);
41+
42+
def calc(cur_state,n_body):
43+
next_state=np.zeros((n_body,fea_num),dtype=float);
44+
f_mat=np.zeros((n_body,n_body,2),dtype=float);
45+
f_sum=np.zeros((n_body,2),dtype=float);
46+
acc=np.zeros((n_body,2),dtype=float);
47+
for i in range(n_body):
48+
for j in range(i+1,n_body):
49+
if(j!=i):
50+
f_mat[i,j]+=get_f(cur_state[i][:3],cur_state[j][:3]);
51+
f_mat[j,i]-=f_mat[i,j];
52+
f_sum[i]=np.sum(f_mat[i,:]);
53+
acc[i]=f_sum[i]/cur_state[i][0];
54+
next_state[i][0]=cur_state[i][0];
55+
next_state[i][3:5]=cur_state[i][3:5]+acc[i]*diff_t;
56+
next_state[i][1:3]=cur_state[i][1:3]+next_state[i][3:5]*diff_t;
57+
return next_state;
58+
59+
def gen(n_body,orbit):
60+
# initialization on just first state
61+
data=init(total_state,n_body,fea_num,orbit);
62+
for i in range(1,total_state):
63+
data[i]=calc(data[i-1],n_body);
64+
return data;
65+
66+
if __name__=='__main__':
67+
gen(3,False);

0 commit comments

Comments
 (0)