loss.py
# Author: Yash Bhalgat
from math import exp, log, floor

import torch

from utils import hash
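
# Losses for regularizing a multiresolution hash-grid NeRF:
#   total_variation_loss -- smoothness penalty on a sampled cube of hash-grid embeddings
#   sigma_sparsity_loss  -- Cauchy sparsity penalty on predicted density (sigma) values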


def total_variation_loss(embeddings, min_resolution, max_resolution, level, log2_hashmap_size, n_levels=16):
    # Per-level resolution, following a geometric progression: the growth
    # factor b takes the grid from min_resolution to max_resolution over
    # n_levels levels.
    b = exp((log(max_resolution) - log(min_resolution)) / (n_levels - 1))
    resolution = torch.tensor(floor(min_resolution * b**level))
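    # Worked example (hypothetical values, not from this file): with
    # min_resolution=16, max_resolution=512, n_levels=16, we get
    # b = exp(log(32)/15) = 2**(1/3) ~ 1.26, i.e. each level is ~26% finer.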
    # Side length of the cube on which the TV loss is computed
    min_cube_size = min_resolution - 1
    max_cube_size = 50  # can be tuned
    if min_cube_size > max_cube_size:
        raise ValueError("minimum cube size is greater than maximum cube size")
    cube_size = torch.floor(torch.clip(resolution / 10.0, min_cube_size, max_cube_size)).int()
    # Sample a random cube: choose a minimum vertex, then enumerate the dense
    # grid of integer coordinates spanning cube_size+1 voxels along each axis.
    min_vertex = torch.randint(0, resolution - cube_size, (3,))
    idx = min_vertex + torch.stack([torch.arange(cube_size + 1) for _ in range(3)], dim=-1)
    cube_indices = torch.stack(torch.meshgrid(idx[:, 0], idx[:, 1], idx[:, 2], indexing='ij'), dim=-1)
    hashed_indices = hash(cube_indices, log2_hashmap_size)
    cube_embeddings = embeddings(hashed_indices)
    # Anisotropic TV loss: sum of squared differences between the embeddings of
    # neighboring voxels along each axis of the cube.
    tv_x = torch.pow(cube_embeddings[1:, :, :, :] - cube_embeddings[:-1, :, :, :], 2).sum()
    tv_y = torch.pow(cube_embeddings[:, 1:, :, :] - cube_embeddings[:, :-1, :, :], 2).sum()
    tv_z = torch.pow(cube_embeddings[:, :, 1:, :] - cube_embeddings[:, :, :-1, :], 2).sum()
    return (tv_x + tv_y + tv_z) / cube_size


def sigma_sparsity_loss(sigmas):
    # Cauchy sparsity loss on the density (sigma) values: log(1 + 2*sigma^2),
    # summed over the last dimension
    return torch.log(1.0 + 2 * sigmas**2).sum(dim=-1)
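

# Usage sketch (not part of the original file): a minimal, hedged example of
# how these losses might be combined with a hash-grid embedding table. The
# embedding table, tensor shapes, and all hyperparameter values below are
# assumptions for illustration, not definitions from this repository; it also
# assumes utils.hash is importable, i.e. the sketch runs inside this repo.
if __name__ == "__main__":
    log2_hashmap_size = 19
    n_levels = 16
    # Assumed table: one embedding per hash bucket, 2 features per entry
    embeddings = torch.nn.Embedding(2**log2_hashmap_size, 2)

    # One TV term per resolution level, as the `level` argument suggests
    tv_loss = sum(
        total_variation_loss(embeddings, min_resolution=16, max_resolution=512,
                             level=i, log2_hashmap_size=log2_hashmap_size,
                             n_levels=n_levels)
        for i in range(n_levels)
    )

    # Sparsity term on placeholder density predictions, shape (rays, samples)
    sigmas = torch.rand(1024, 128)
    sparsity_loss = sigma_sparsity_loss(sigmas).mean()

    # Loss weights are assumptions; in practice they are tuned per scene
    total_aux_loss = 1e-6 * tv_loss + 1e-10 * sparsity_loss
    print(float(tv_loss), float(sparsity_loss), float(total_aux_loss))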