Skip to content

Commit

Permalink
Anisotropic kernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
RS-Coop committed Dec 14, 2023
1 parent c204ad2 commit 8865734
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torch_quadconv/quadconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self,*,

#decay parameter
if decay_param == None:
self.decay_param = (self.in_points/16)**2
self.decay_param = (self.in_points/16)**-2
else:
self.decay_param = decay_param

Expand Down Expand Up @@ -150,7 +150,10 @@ def _create_mlp(self, mlp_channels):
z: evaluation locations, [out_points, in_points, spatial_dim]
'''
def _bump_arg(self, z):
return torch.linalg.vector_norm(z, dim=(2), keepdims = True)**4
# return torch.linalg.vector_norm(z, dim=(2), keepdims = True)
# a, b = 2, 0.5
a, b = 1, 1
return torch.sqrt(z[:,:,0]**2/a**2 + z[:,:,1]**2/b**2)

'''
Compute indices associated with non-zero filters.
Expand All @@ -173,8 +176,9 @@ def _compute_eval_indices(self, mesh):

bump_arg = self._bump_arg(locs)

tf_vec = (bump_arg <= 1/self.decay_param).squeeze()
tf_vec = (bump_arg <= self.decay_param).squeeze()
idx = torch.nonzero(tf_vec, as_tuple=False)
print(idx.dtype)

if self.cache:
self.eval_indices = nn.Parameter(idx, requires_grad=False)
Expand All @@ -183,7 +187,7 @@ def _compute_eval_indices(self, mesh):
if self.verbose:
print(f"QuadConv eval_indices: {idx.numel()}")

hist = torch.histc(idx[:,0], bins=self.out_points, min=0, max=self.out_points-1)
hist = torch.histc(idx[:,1].to(torch.float32), bins=self.out_points, min=0, max=self.in_points-1)

print(f"Max support points: {torch.max(hist)}")
print(f"Min support points: {torch.min(hist)}")
Expand Down

0 comments on commit 8865734

Please sign in to comment.