add different sampling strategies
salykova committed Jun 18, 2021
1 parent 7ba04d0 commit 5d36c05
Showing 3 changed files with 58 additions and 25 deletions.
7 changes: 4 additions & 3 deletions configs/lego.txt
@@ -14,7 +14,8 @@ half_res = True

obs_img_num = 15
dil_iter = 1
- kernel_size = 3
+ kernel_size = 5
start_pose_num = 27
- batch_size = 2050
- lrate = 0.01
+ batch_size = 2048
+ lrate = 0.01
+ sampling_strategy = interest_regions
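
Taken together with the run.py changes below, this switches the lego config to interest-region sampling with a 5x5 window around each detected keypoint. As a rough illustration (not repository code) of how the per-step ray budget compares: the old loop rounded batch_size down to a whole number of kernel_size x kernel_size regions, while the new strategy draws exactly batch_size individual pixels from the pooled region masks.

    # illustration only: old region-based budget vs. new per-pixel budget
    kernel_size_old, batch_size_old = 3, 2050
    kernel_size_new, batch_size_new = 5, 2048
    rays_old = (batch_size_old // kernel_size_old ** 2) * kernel_size_old ** 2  # 227 whole 3x3 regions -> 2043 rays
    rays_new = batch_size_new                                                   # 2048 pixels drawn from the 5x5 masks
    print(rays_old, rays_new)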
66 changes: 47 additions & 19 deletions run.py
@@ -26,6 +26,7 @@ def run():
kernel_size = args.kernel_size
lrate = args.lrate
dataset_type = args.dataset_type
+ sampling_strategy = args.sampling_strategy

# Load and pre-process the observed image
# obs_img - rgb image with elements in range 0...255
@@ -50,10 +51,28 @@ def run():
show_img("Observed image", obs_img)

# Find points of interest
- POI = find_POI(obs_img, DEBUG) # xy pixel coordinates of interest points (N x 2)
+ POI = find_POI(obs_img, DEBUG) # xy pixel coordinates of points of interest (N x 2)

# Filter out the points that do not fit sample region
- POI = [point for point in POI if ((((point[0]-kernel_size) >= 0) and ((point[1]-kernel_size) >= 0 )) and (((point[0]+kernel_size) <= (H-1)) and ((point[1]+kernel_size) <= (W-1))))]
- POI = np.array(POI).astype(int)
+ POI_filtered = [point for point in POI if ((((point[0]-kernel_size) >= 0) and ((point[1]-kernel_size) >= 0 )) and (((point[0]+kernel_size) <= (H-1)) and ((point[1]+kernel_size) <= (W-1))))]
+ POI_filtered = np.array(POI_filtered).astype(int)

+ coords = np.asarray(np.stack(np.meshgrid(np.linspace(0, W - 1, W), np.linspace(0, H - 1, H)), -1),
+                     dtype=np.int)

+ # create sampling masks
+ masks = np.zeros((len(POI_filtered), kernel_size, kernel_size, 2), dtype=np.int)
+ step = int(kernel_size / 2)
+ region_size = kernel_size ** 2
+ for i in range(len(masks)):
+     masks[i] = coords[POI_filtered[i][0] - step: POI_filtered[i][0] + step + 1,
+                       POI_filtered[i][1] - step: POI_filtered[i][1] + step + 1]
+ masks = masks.reshape((masks.shape[0]*region_size), 2)

+ # not_POI contains all points except the POI
+ coords = coords.reshape(H * W, 2)
+ not_POI = set(tuple(point) for point in coords) - set(tuple(point) for point in POI)
+ not_POI = np.array([list(point) for point in not_POI]).astype(int)
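
For intuition, a small self-contained toy version (not repository code) of how one entry of masks is built: it is simply the kernel_size x kernel_size block of grid coordinates sliced around a filtered interest point, later flattened so that single pixels can be drawn from the pool.

    import numpy as np

    H, W, kernel_size = 6, 6, 3
    step = kernel_size // 2
    coords = np.stack(np.meshgrid(np.arange(W), np.arange(H)), -1)  # (H, W, 2) grid of pixel coordinates
    point = np.array([3, 2])                                        # an interest point away from the borders
    mask = coords[point[0] - step: point[0] + step + 1,
                  point[1] - step: point[1] + step + 1]             # (3, 3, 2) window around the point
    print(mask.reshape(-1, 2).shape)                                # (9, 2): nine candidate pixels for this region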

# Dilate the observed image
I = args.dil_iter
@@ -62,9 +81,7 @@
show_img("Dilated image", dil_obs_img)

dil_obs_img = (np.array(dil_obs_img) / 255.).astype(np.float32)

- region_size = kernel_size ** 2
- num_regions = int(batch_size / region_size)
+ obs_img = (np.array(obs_img) / 255.).astype(np.float32)
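
The dilation call itself is outside the visible part of this diff; a plausible OpenCV sketch of how dil_obs_img could be produced from the observed image (an assumption about the hidden code, using a stand-in image) is:

    import cv2
    import numpy as np

    # hypothetical reconstruction of the hidden dilation step
    obs_img_u8 = (np.random.rand(400, 400, 3) * 255).astype(np.uint8)  # stand-in for the observed image
    kernel_size, dil_iter = 5, 1                                       # values from configs/lego.txt
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    dil_obs_img = cv2.dilate(obs_img_u8, kernel, iterations=dil_iter)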

# Load NeRF Model
render_kwargs = load_nerf(args, device)
@@ -81,22 +98,33 @@
testsavedir = os.path.join(output_dir, model_name)
os.makedirs(testsavedir, exist_ok=True)


for k in range(300):

- batch = np.zeros((num_regions, kernel_size, kernel_size, 2), dtype=np.int)
- rand_inds = np.random.choice(POI.shape[0], size=[num_regions], replace=False) # (N_rand,)
- rand_coords = POI[rand_inds]
- step = int(kernel_size / 2)
- coords = np.asarray(np.stack(np.meshgrid(np.linspace(0, W - 1, W), np.linspace(0, H - 1, H)), -1),
-                     dtype=np.int) # (H, W, 2)
- for i in range(len(rand_coords)):
-     batch[i] = coords[rand_coords[i][0] - step: rand_coords[i][0] + step + 1,
-                       rand_coords[i][1] - step: rand_coords[i][1] + step + 1]
- batch = batch.reshape((batch.shape[0] * region_size, 2)) # (num_regions * region_size, 2)
- target_s = dil_obs_img[batch[:, 0], batch[:, 1]]
- target_s = torch.Tensor(target_s).to(device)
+ if sampling_strategy == 'random':
+     rand_inds = np.random.choice(coords.shape[0], size=batch_size, replace=False)
+     batch = coords[rand_inds]
+     target_s = obs_img[batch[:, 0], batch[:, 1]]

+ elif sampling_strategy == 'interest_points':
+     if POI.shape[0] >= batch_size:
+         rand_inds = np.random.choice(POI.shape[0], size=batch_size, replace=False)
+         batch = POI[rand_inds]
+     else:
+         batch = np.zeros((batch_size, 2), dtype=np.int)
+         batch[:POI.shape[0]] = POI
+         rand_inds = np.random.choice(not_POI.shape[0], size=batch_size-POI.shape[0], replace=False)
+         batch[POI.shape[0]:] = not_POI[rand_inds]
+     target_s = obs_img[batch[:, 0], batch[:, 1]]

+ elif sampling_strategy == 'interest_regions':
+     rand_inds = np.random.choice(masks.shape[0], size=batch_size, replace=False)
+     batch = masks[rand_inds]
+     target_s = dil_obs_img[batch[:, 0], batch[:, 1]]
+ else:
+     print('Unknown sampling strategy')
+     return

+ target_s = torch.Tensor(target_s).to(device)
pose = cam_transf(start_pose)

rays_o, rays_d = get_rays(H, W, focal, pose) # (H, W, 3), (H, W, 3)
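
Each branch ends with a (batch_size, 2) array of pixel coordinates and the matching target colours. A small self-contained mock (illustrative values, not repository code) of the interest_points fallback, which pads the batch with non-interest pixels whenever there are fewer interest points than batch_size:

    import numpy as np

    batch_size = 8
    POI = np.array([[10, 12], [20, 22], [30, 32]])                    # fewer interest points than batch_size
    not_POI = np.array([[r, c] for r in range(5) for c in range(5)])  # stand-in for the remaining pixels

    batch = np.zeros((batch_size, 2), dtype=int)
    batch[:POI.shape[0]] = POI
    rand_inds = np.random.choice(not_POI.shape[0], size=batch_size - POI.shape[0], replace=False)
    batch[POI.shape[0]:] = not_POI[rand_inds]
    print(batch.shape)  # (8, 2): interest points first, random filler pixels after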
10 changes: 7 additions & 3 deletions utils.py
@@ -64,15 +64,15 @@ def config_parser():
parser.add_argument("--dataset_type", type=str, default='llff',
help='options: llff / blender / deepvoxels')

- # blender flags
+ # blender options
parser.add_argument("--white_bkgd", action='store_true',
help='set to render synthetic data on a white bkgd (always use for dvoxels)')
parser.add_argument("--half_res", action='store_true',
help='load blender synthetic data at 400x400 instead of 800x800')
parser.add_argument("--lindisp", action='store_true',
help='sampling linearly in disparity rather than depth')

- # llff flags
+ # llff options
parser.add_argument("--llffhold", type=int, default=8,
help='will take every 1/N images as LLFF test set, paper uses 8')
parser.add_argument("--factor", type=int, default=8,
@@ -93,7 +93,8 @@ def config_parser():
help='Number of sampled rays per gradient step')
parser.add_argument("--lrate", type=float, default=0.01,
help='Initial learning rate')

parser.add_argument("--sampling_strategy", type=str, default='random',
help='options: random / interest_point / interest_region')
"""
# llff flags
parser.add_argument("--factor", type=int, default=8,
@@ -162,6 +163,9 @@ def find_POI(img_rgb, DEBUG=False): # img - RGB image in range 0...255
show_img("Detected points", img)
xy = [keypoint.pt for keypoint in keypoints]
xy = np.array(xy).astype(int)
+ # Remove duplicate points
+ xy_set = set(tuple(point) for point in xy)
+ xy = np.array([list(point) for point in xy_set]).astype(int)
return xy # pixel coordinates
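
As a side note, the same duplicate removal can be written with np.unique, which also sorts the rows:

    import numpy as np

    xy = np.array([[3, 4], [3, 4], [7, 2]])
    xy = np.unique(xy, axis=0)  # [[3, 4], [7, 2]]: duplicate keypoint coordinates removed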

# Misc
