Skip to content

Commit

Permalink
load whole dataset in train loop
Browse files Browse the repository at this point in the history
Summary: Loads the whole dataset and moves it to the device and sends it to for sampling to enable full dataset heterogeneous raysampling.

Reviewed By: bottler

Differential Revision: D39263009

fbshipit-source-id: c527537dfc5f50116849656c9e171e868f6845b1
  • Loading branch information
Darijan Gudelj authored and facebook-github-bot committed Oct 3, 2022
1 parent c311a4c commit 37bd280
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
1 change: 1 addition & 0 deletions projects/implicitron_trainer/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def run(self) -> None:
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
# pyre-ignore[6]
train_dataset=datasets.train,
model=model,
optimizer=optimizer,
Expand Down
4 changes: 3 additions & 1 deletion projects/implicitron_trainer/impl/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset

from .utils import seed_all_random_engines

Expand All @@ -44,6 +44,7 @@ def run(
train_loader: DataLoader,
val_loader: Optional[DataLoader],
test_loader: Optional[DataLoader],
train_dataset: Dataset,
model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer,
scheduler: Any,
Expand Down Expand Up @@ -116,6 +117,7 @@ def run(
train_loader: DataLoader,
val_loader: Optional[DataLoader],
test_loader: Optional[DataLoader],
train_dataset: Dataset,
model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer,
scheduler: Any,
Expand Down
3 changes: 2 additions & 1 deletion pytorch3d/implicitron/models/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ def safe_slice_targets(
)

# (1) Sample rendering rays with the ray sampler.
ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29]
# pyre-ignore[29]
ray_bundle: ImplicitronRayBundle = self.raysampler(
target_cameras,
evaluation_mode,
mask=mask_crop[:n_targets]
Expand Down

0 comments on commit 37bd280

Please sign in to comment.