Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Nerfacto + Aria #2950

Merged
merged 2 commits into from
Feb 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 42 additions & 14 deletions nerfstudio/data/pixel_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,22 +178,50 @@ def sample_method_fisheye(
if isinstance(mask, torch.Tensor) and not self.config.ignore_mask:
indices = self.sample_method(batch_size, num_images, image_height, image_width, mask=mask, device=device)
else:
rand_samples = torch.rand((batch_size, 3), device=device)
# convert random samples tto radius and theta
radii = self.config.fisheye_crop_radius * torch.sqrt(rand_samples[:, 1])
theta = 2.0 * torch.pi * rand_samples[:, 2]

# convert radius and theta to x and y between -radii and radii
x = radii * torch.cos(theta)
y = radii * torch.sin(theta)
# Rejection sampling.
valid: Optional[torch.Tensor] = None
indices = None
while True:
samples_needed = batch_size if valid is None else int(batch_size - torch.sum(valid).item())

# Check if done!
if samples_needed == 0:
break

rand_samples = torch.rand((samples_needed, 2), device=device)
# Convert random samples to radius and theta.
radii = self.config.fisheye_crop_radius * torch.sqrt(rand_samples[:, 0])
theta = 2.0 * torch.pi * rand_samples[:, 1]

# Convert radius and theta to x and y.
x = (radii * torch.cos(theta) + image_width // 2).long()
y = (radii * torch.sin(theta) + image_height // 2).long()
sampled_indices = torch.stack(
[torch.randint(0, num_images, size=(samples_needed,), device=device), y, x], dim=-1
)

# Multiply by the batch size and height/width to get pixel indices.
indices = torch.floor(
torch.stack([rand_samples[:, 0], y, x], dim=1)
* torch.tensor([num_images, image_height // 2, image_width // 2], device=device)
+ torch.tensor([0, image_height // 2, image_width // 2], device=device)
).long()
# Update indices.
if valid is None:
indices = sampled_indices
valid = (
(sampled_indices[:, 1] >= 0)
& (sampled_indices[:, 1] < image_height)
& (sampled_indices[:, 2] >= 0)
& (sampled_indices[:, 2] < image_width)
)
else:
assert indices is not None
not_valid = ~valid
indices[not_valid, :] = sampled_indices
valid[not_valid] = (
(sampled_indices[:, 1] >= 0)
& (sampled_indices[:, 1] < image_height)
& (sampled_indices[:, 2] >= 0)
& (sampled_indices[:, 2] < image_width)
)
assert indices is not None

assert indices.shape == (batch_size, 3)
return indices

def collate_image_dataset_batch(self, batch: Dict, num_rays_per_batch: int, keep_full_image: bool = False):
Expand Down
Loading