Skip to content

working on nerfstudio 1.1.4 #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ __pycache__/
*.so
/segmentation/*
/registration/*
/datasets/*

debug/nerf_output.txt

Expand All @@ -32,6 +33,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
.vscode/

# PyInstaller
# Usually these files are written by a python script from a template
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ Download pretrained weights
cd .. # Download into grounded_sam
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
wget https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth
```

Install SAM-HQ
Expand Down
40 changes: 5 additions & 35 deletions fruit_nerf/data/fruit_datamanager.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,21 @@
"""
Fruit datamanager.
FruitDataManager implementation.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import (
Dict,
Literal,
Optional,
Tuple,
Type,
Union,
)
from typing import Dict, Literal, Optional, Tuple, Type, Union

import torch
from torch.nn import Parameter

from typing_extensions import TypeVar

from nerfstudio.cameras.rays import RayBundle

from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.pixel_samplers import (
PixelSampler,
)

from nerfstudio.data.pixel_samplers import PixelSampler
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig

from fruit_nerf.components.ray_generators import OrthographicRayGenerator
from fruit_nerf.data.fruit_dataset import FruitDataset

Expand Down Expand Up @@ -80,21 +69,14 @@ def sample_surface_points(aabb, n, device, noise=False):
Returns:
torch tensor: Tensor of shape (num_points, 3) containing the sampled 3D coordinates.
"""
# select three corners (must be adjacent!)
corner_1 = aabb[0] # x
corner_2 = aabb[1] # y
corner_3 = aabb[2] # z

# Check if elements are too far apart (check if adjacent)
# assert torch.abs(torch.sum(corner_1 - corner_2)) == 2.0
# assert torch.abs(torch.sum(corner_1 - corner_3)) == 2.0

dx_y_z = torch.abs(torch.max(aabb, axis=0).values - torch.min(aabb, axis=0).values)

# Part where the coordinate does not change
constant_axis_part_pos = int(torch.argmax(torch.logical_and((corner_1 == corner_2), (corner_2 == corner_3)).to(int)))

# Generate meshgrid along XY plane
start_x_pos = torch.argmax(torch.abs(corner_1 - corner_2))
x = torch.linspace(corner_1[start_x_pos], corner_2[start_x_pos],
int(dx_y_z[0] / dx_y_z[constant_axis_part_pos] * n), dtype=torch.float32, device=device)
Expand All @@ -104,13 +86,11 @@ def sample_surface_points(aabb, n, device, noise=False):

xx, yy = torch.meshgrid(x, y)

# Flatten the meshgrid and set Z coordinate to the minimum Z value of the AABB
surface_points = torch.column_stack(
(xx.flatten(),
yy.flatten(),
torch.full_like(xx.flatten(), corner_3[constant_axis_part_pos])))

# Convert to torch tensor
surface_points_tensor = surface_points.clone()

corner_4 = aabb[-1]
Expand All @@ -122,17 +102,7 @@ def sample_surface_points(aabb, n, device, noise=False):


class FruitDataManager(VanillaDataManager):
"""Basic stored data manager implementation.

This is pretty much a port over from our old dataloading utilities, and is a little jank
under the hood. We may clean this up a little bit under the hood with more standard dataloading
components that can be strung together, but it can be just used as a black box for now since
only the constructor is likely to change in the future, or maybe passing in step number to the
next_train and next_eval functions.

Args:
config: the DataManagerConfig used to instantiate class
"""
"""FruitDataManager implementation."""

config: FruitDataManagerConfig
train_dataset: TDataset
Expand Down
6 changes: 3 additions & 3 deletions fruit_nerf/fruit_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from nerfstudio.field_components.mlp import MLP
from nerfstudio.field_components.spatial_distortions import SpatialDistortion
from nerfstudio.fields.base_field import Field, shift_directions_for_tcnn
from nerfstudio.fields.base_field import Field, get_normalized_directions

from fruit_nerf.components.field_heads import SemanticFieldHead

Expand Down Expand Up @@ -206,7 +206,7 @@ def get_inference_outputs(
outputs[FieldHeadNames.SEMANTICS] = self.field_head_semantics(x)

if render_rgb:
directions = shift_directions_for_tcnn(ray_samples.frustums.directions)
directions = get_normalized_directions(ray_samples.frustums.directions)
directions_flat = directions.view(-1, 3)
d = self.direction_encoding(directions_flat)
outputs_shape = ray_samples.frustums.directions.shape[:-1]
Expand Down Expand Up @@ -241,7 +241,7 @@ def get_outputs(
if ray_samples.camera_indices is None:
raise AttributeError("Camera indices are not provided.")
camera_indices = ray_samples.camera_indices.squeeze()
directions = shift_directions_for_tcnn(ray_samples.frustums.directions)
directions = get_normalized_directions(ray_samples.frustums.directions)
directions_flat = directions.view(-1, 3)
d = self.direction_encoding(directions_flat)

Expand Down
Loading