Skip to content

Commit

Permalink
Update propnet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ZongjingLi committed Nov 14, 2023
1 parent 620a834 commit fdff6b0
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions autolearner/model/physics/propnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,32 @@ class PropModule(nn.Module):
def __init__(self, input_dim, output_dim, batch = True, residual = False):
super().__init__()

hidden_dim = 132

self.batch = batch
self.state_dim

self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

self.particle_encoder = ParticleEncoder(
input_dim, hidden_dim, output_dim
)

self.relation_encoder = RelationEncoder(
input_dim, hidden_dim, output_dim
)

# Propagator Modules in Action
self.particle_propagator = Propagator

def forward(self,state):
def forward(self,state, Rs, Rr, Ra, itrs = 3):
"""
Args:
state: input states of the input
Rs: the relation feature sender matrix [B,N,N,1]
Rr: the relation feature reciever matrix [B,N,N,1]
Ra: the relation attribute matrix [B,N,N,D]
"""
B, N, Dx = state.shape
# calculate the particle effect
particle_effect = torch.autograd.Variable(
Expand All @@ -121,7 +141,23 @@ def forward(self,state):

# calculate reciever_states and sender_states
if self.batch:
Rrp = 0
state_r = torch.bmm(Rr, state)
state_s = torch.bmm(Rs, state)
# particle encode
particle_encode = self.particle_encoder(state)

# calculate the relation encode
relation_encode = self.relation_encoder(torch.cat([
state_r, state_s, Ra
], dim = 2))

for i in range(itrs):
if self.batch:
effect_r = torch.bmm(Rr, particle_effect)
effect_s = torch.bmm(Rs, particle_effect)
# calculate the relation effect
#relation_effect = self.relation_propagator()

return state

class PropNet(nn.Module):
Expand Down

0 comments on commit fdff6b0

Please sign in to comment.