Commit

Release v1.1.1
ZexinHe committed Mar 13, 2024
1 parent 168d6c7 commit b362d58
Showing 14 changed files with 1,200 additions and 3 deletions.
23 changes: 22 additions & 1 deletion README.md
@@ -20,6 +20,7 @@

## News

- [2024.03.13] Update [training code](openlrm/runners/train) and release [OpenLRM v1.1.1](https://github.com/3DTopia/OpenLRM/releases/tag/v1.1.1).
- [2024.03.08] We have released the core [blender script](scripts/data/objaverse/blender_script.py) used to render Objaverse images.
- [2024.03.05] The [Huggingface demo](https://huggingface.co/spaces/zxhezexin/OpenLRM) now uses the `openlrm-mix-base-1.1` model by default. Please refer to the [model card](model_card.md) for details on the updated model architecture and training settings.
- [2024.03.04] Version update to v1.1. Released model weights trained on both Objaverse and MVImgNet. The codebase has been substantially refactored for better usability and extensibility. Please refer to [v1.1.0](https://github.com/3DTopia/OpenLRM/releases/tag/v1.1.0) for details.
@@ -89,7 +90,27 @@ Model cards with additional details can be found in [model_card.md](model_card.md)
- You should be able to see `UserWarning: xFormers is available` if `xFormers` is actually working.

## Training
To be released soon.

### Configuration
- We provide a sample accelerate config file under `configs/accelerate-train.yaml`, which defaults to 8 GPUs with `bf16` mixed precision.
- You may modify the configuration file to fit your own environment.

### Data Preparation
- We provide the core [Blender script](scripts/data/objaverse/blender_script.py) used to render Objaverse images.
- Please refer to [Objaverse Rendering](https://github.com/allenai/objaverse-rendering) for other scripts, including distributed rendering.

### Run Training
- A sample training config file is provided under `configs/train-sample.yaml`.
- Please replace the data-related paths in the config file with your own paths and customize the training settings as needed.
- An example training usage is as follows:

```
# Example usage
ACC_CONFIG="./configs/accelerate-train.yaml"
TRAIN_CONFIG="./configs/train-sample.yaml"
accelerate launch --config_file $ACC_CONFIG -m openlrm.launch train.lrm --config $TRAIN_CONFIG
```

## Acknowledgement

16 changes: 16 additions & 0 deletions configs/accelerate-train.yaml
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
91 changes: 91 additions & 0 deletions configs/train-sample.yaml
@@ -0,0 +1,91 @@

experiment:
    type: lrm
    seed: 42
    parent: lrm-objaverse
    child: small-dummyrun

model:
    camera_embed_dim: 1024
    rendering_samples_per_ray: 96
    transformer_dim: 512
    transformer_layers: 12
    transformer_heads: 8
    triplane_low_res: 32
    triplane_high_res: 64
    triplane_dim: 32
    encoder_type: dinov2
    encoder_model_name: dinov2_vits14_reg
    encoder_feat_dim: 384
    encoder_freeze: false

dataset:
    subsets:
        - name: objaverse
          root_dirs:
              - <REPLACE_WITH_RENDERING_ROOT>
          meta_path:
              train: <TRAIN_UIDS_IN_JSON>
              val: <VAL_UIDS_IN_JSON>
          sample_rate: 1.0
    sample_side_views: 3
    source_image_res: 224
    render_image:
        low: 64
        high: 192
        region: 64
    normalize_camera: true
    normed_dist_to_center: auto
    num_train_workers: 4
    num_val_workers: 2
    pin_mem: true

train:
    mixed_precision: bf16  # REPLACE THIS BASED ON GPU TYPE
    find_unused_parameters: false
    loss:
        pixel_weight: 1.0
        perceptual_weight: 1.0
        tv_weight: 5e-4
    optim:
        lr: 4e-4
        weight_decay: 0.05
        beta1: 0.9
        beta2: 0.95
        clip_grad_norm: 1.0
    scheduler:
        type: cosine
        warmup_real_iters: 3000
    batch_size: 16  # REPLACE THIS (PER GPU)
    accum_steps: 1  # REPLACE THIS
    epochs: 60  # REPLACE THIS
    debug_global_steps: null

val:
    batch_size: 4
    global_step_period: 1000
    debug_batches: null

saver:
    auto_resume: true
    load_model: null
    checkpoint_root: ./exps/checkpoints
    checkpoint_global_steps: 1000
    checkpoint_keep_level: 5

logger:
    stream_level: WARNING
    log_level: INFO
    log_root: ./exps/logs
    tracker_root: ./exps/trackers
    enable_profiler: false
    trackers:
        - tensorboard
    image_monitor:
        train_global_steps: 100
        samples_per_log: 4

compile:
    suppress_errors: true
    print_specializations: true
    disable: true
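
As a quick sanity check on these settings, the global batch size implied by `train-sample.yaml` together with `accelerate-train.yaml` can be worked out as below. This is plain data-parallel arithmetic rather than anything OpenLRM-specific; adjust the numbers if you change `batch_size`, `accum_steps`, or `num_processes`.

```python
# Back-of-envelope check of the effective (global) batch size implied by the
# two sample configs above. Illustrative only, not code from the repository.
per_gpu_batch = 16    # train.batch_size in configs/train-sample.yaml
accum_steps = 1       # train.accum_steps in configs/train-sample.yaml
num_processes = 8     # num_processes in configs/accelerate-train.yaml

effective_batch = per_gpu_batch * accum_steps * num_processes
print(effective_batch)  # 128 samples per optimizer step
```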
2 changes: 1 addition & 1 deletion openlrm/datasets/__init__.py
@@ -13,4 +13,4 @@
# limitations under the License.


# from .mixer import MixerDataset
from .mixer import MixerDataset
26 changes: 26 additions & 0 deletions openlrm/datasets/cam_utils.py
@@ -61,6 +61,32 @@ def decompose_extrinsic_RT(E: torch.Tensor):
    return E[:, :3, :]


def camera_normalization_objaverse(normed_dist_to_center, poses: torch.Tensor, ret_transform: bool = False):
    assert normed_dist_to_center is not None
    pivotal_pose = compose_extrinsic_RT(poses[:1])
    dist_to_center = pivotal_pose[:, :3, 3].norm(dim=-1, keepdim=True).item() \
        if normed_dist_to_center == 'auto' else normed_dist_to_center

    # compute camera norm (new version)
    canonical_camera_extrinsics = torch.tensor([[
        [1, 0, 0, 0],
        [0, 0, -1, -dist_to_center],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ]], dtype=torch.float32)
    pivotal_pose_inv = torch.inverse(pivotal_pose)
    camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv)

    # normalize all views
    poses = compose_extrinsic_RT(poses)
    poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses)
    poses = decompose_extrinsic_RT(poses)

    if ret_transform:
        return poses, camera_norm_matrix.squeeze(dim=0)
    return poses


def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
"""
intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
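
A minimal sketch of how the new `camera_normalization_objaverse` helper behaves, assuming `compose_extrinsic_RT` accepts a batch of `[R|t]` poses as it does inside the function above; this is an illustrative check rather than code from the repository.

```python
import torch
from openlrm.datasets.cam_utils import (
    camera_normalization_objaverse, compose_extrinsic_RT,
)

# One toy camera pose in (1, 3, 4) [R|t] form: identity rotation, 2 units from the origin.
pose = torch.tensor([[[1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 2.]]])

normed = camera_normalization_objaverse('auto', pose)
# After normalization the pivotal (first) view should coincide with the
# canonical extrinsic hard-coded in the function, at the same distance (2.0).
print(compose_extrinsic_RT(normed[:1]))
```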
69 changes: 69 additions & 0 deletions openlrm/datasets/mixer.py
@@ -0,0 +1,69 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from functools import partial
import torch

__all__ = ['MixerDataset']


class MixerDataset(torch.utils.data.Dataset):

    def __init__(self,
                 split: str,
                 subsets: list[dict],
                 **dataset_kwargs,
                 ):
        self.subsets = [
            self._dataset_fn(subset, split)(**dataset_kwargs)
            for subset in subsets
        ]
        self.virtual_lens = [
            math.ceil(subset_config['sample_rate'] * len(subset_obj))
            for subset_config, subset_obj in zip(subsets, self.subsets)
        ]

    @staticmethod
    def _dataset_fn(subset_config: dict, split: str):
        name = subset_config['name']

        dataset_cls = None
        if name == "objaverse":
            from .objaverse import ObjaverseDataset
            dataset_cls = ObjaverseDataset
        # elif name == 'mvimgnet':
        #     from .mvimgnet import MVImgNetDataset
        #     dataset_cls = MVImgNetDataset
        else:
            raise NotImplementedError(f"Dataset {name} not implemented")

        return partial(
            dataset_cls,
            root_dirs=subset_config['root_dirs'],
            meta_path=subset_config['meta_path'][split],
        )

    def __len__(self):
        return sum(self.virtual_lens)

    def __getitem__(self, idx):
        subset_idx = 0
        virtual_idx = idx
        while virtual_idx >= self.virtual_lens[subset_idx]:
            virtual_idx -= self.virtual_lens[subset_idx]
            subset_idx += 1
        real_idx = virtual_idx % len(self.subsets[subset_idx])
        return self.subsets[subset_idx][real_idx]
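
To make the `virtual_lens` bookkeeping in `MixerDataset.__getitem__` concrete, here is a small standalone illustration (toy values, not OpenLRM code) of how a global index maps to a subset and a local index when `sample_rate` oversamples a subset:

```python
import math

# Toy setup: subset A has 4 real items oversampled at rate 1.5, subset B has 3 at rate 1.0.
real_lens = [4, 3]
sample_rates = [1.5, 1.0]
virtual_lens = [math.ceil(r * n) for r, n in zip(sample_rates, real_lens)]  # [6, 3]

def locate(idx):
    # Same walk as MixerDataset.__getitem__: subtract virtual lengths until idx fits,
    # then wrap the remaining virtual index into the subset's real length.
    subset_idx, virtual_idx = 0, idx
    while virtual_idx >= virtual_lens[subset_idx]:
        virtual_idx -= virtual_lens[subset_idx]
        subset_idx += 1
    return subset_idx, virtual_idx % real_lens[subset_idx]

print([locate(i) for i in range(sum(virtual_lens))])
# Indices 0-5 land in subset 0 (items 0,1,2,3,0,1); indices 6-8 land in subset 1 (items 0,1,2).
```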
125 changes: 125 additions & 0 deletions openlrm/datasets/objaverse.py
@@ -0,0 +1,125 @@
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from typing import Union
import random
import numpy as np
import torch
from megfile import smart_path_join, smart_open

from .base import BaseDataset
from .cam_utils import build_camera_standard, build_camera_principle, camera_normalization_objaverse
from ..utils.proxy import no_proxy

__all__ = ['ObjaverseDataset']


class ObjaverseDataset(BaseDataset):

    def __init__(self, root_dirs: list[str], meta_path: str,
                 sample_side_views: int,
                 render_image_res_low: int, render_image_res_high: int, render_region_size: int,
                 source_image_res: int, normalize_camera: bool,
                 normed_dist_to_center: Union[float, str] = None, num_all_views: int = 32):
        super().__init__(root_dirs, meta_path)
        self.sample_side_views = sample_side_views
        self.render_image_res_low = render_image_res_low
        self.render_image_res_high = render_image_res_high
        self.render_region_size = render_region_size
        self.source_image_res = source_image_res
        self.normalize_camera = normalize_camera
        self.normed_dist_to_center = normed_dist_to_center
        self.num_all_views = num_all_views

    @staticmethod
    def _load_pose(file_path):
        pose = np.load(smart_open(file_path, 'rb'))
        pose = torch.from_numpy(pose).float()
        return pose

    @no_proxy
    def inner_get_item(self, idx):
        """
        Loaded contents:
            rgbs: [M, 3, H, W]
            poses: [M, 3, 4], [R|t]
            intrinsics: [3, 2], [[fx, fy], [cx, cy], [width, height]]
        """
        uid = self.uids[idx]
        root_dir = self._locate_datadir(self.root_dirs, uid, locator="intrinsics.npy")

        pose_dir = os.path.join(root_dir, uid, 'pose')
        rgba_dir = os.path.join(root_dir, uid, 'rgba')
        intrinsics_path = os.path.join(root_dir, uid, 'intrinsics.npy')

        # load intrinsics
        intrinsics = np.load(smart_open(intrinsics_path, 'rb'))
        intrinsics = torch.from_numpy(intrinsics).float()

        # sample views (incl. source view and side views)
        sample_views = np.random.choice(range(self.num_all_views), self.sample_side_views + 1, replace=False)
        poses, rgbs, bg_colors = [], [], []
        source_image = None
        for view in sample_views:
            pose_path = smart_path_join(pose_dir, f'{view:03d}.npy')
            rgba_path = smart_path_join(rgba_dir, f'{view:03d}.png')
            pose = self._load_pose(pose_path)
            bg_color = random.choice([0.0, 0.5, 1.0])
            rgb = self._load_rgba_image(rgba_path, bg_color=bg_color)
            poses.append(pose)
            rgbs.append(rgb)
            bg_colors.append(bg_color)
            if source_image is None:
                source_image = self._load_rgba_image(rgba_path, bg_color=1.0)
        assert source_image is not None, "Really bad luck!"
        poses = torch.stack(poses, dim=0)
        rgbs = torch.cat(rgbs, dim=0)

        if self.normalize_camera:
            poses = camera_normalization_objaverse(self.normed_dist_to_center, poses)

        # build source and target camera features
        source_camera = build_camera_principle(poses[:1], intrinsics.unsqueeze(0)).squeeze(0)
        render_camera = build_camera_standard(poses, intrinsics.repeat(poses.shape[0], 1, 1))

        # adjust source image resolution
        source_image = torch.nn.functional.interpolate(
            source_image, size=(self.source_image_res, self.source_image_res), mode='bicubic', align_corners=True).squeeze(0)
        source_image = torch.clamp(source_image, 0, 1)

        # adjust render image resolution and sample intended rendering region
        render_image_res = np.random.randint(self.render_image_res_low, self.render_image_res_high + 1)
        render_image = torch.nn.functional.interpolate(
            rgbs, size=(render_image_res, render_image_res), mode='bicubic', align_corners=True)
        render_image = torch.clamp(render_image, 0, 1)
        anchors = torch.randint(
            0, render_image_res - self.render_region_size + 1, size=(self.sample_side_views + 1, 2))
        crop_indices = torch.arange(0, self.render_region_size, device=render_image.device)
        index_i = (anchors[:, 0].unsqueeze(1) + crop_indices).view(-1, self.render_region_size, 1)
        index_j = (anchors[:, 1].unsqueeze(1) + crop_indices).view(-1, 1, self.render_region_size)
        batch_indices = torch.arange(self.sample_side_views + 1, device=render_image.device).view(-1, 1, 1)
        cropped_render_image = render_image[batch_indices, :, index_i, index_j].permute(0, 3, 1, 2)

        return {
            'uid': uid,
            'source_camera': source_camera,
            'render_camera': render_camera,
            'source_image': source_image,
            'render_image': cropped_render_image,
            'render_anchors': anchors,
            'render_full_resolutions': torch.tensor([[render_image_res]], dtype=torch.float32).repeat(self.sample_side_views + 1, 1),
            'render_bg_colors': torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1),
        }
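
The anchor-based cropping at the end of `inner_get_item` uses advanced indexing to cut a per-view `render_region_size` window out of the resized renders. A small self-contained illustration of that indexing pattern (toy tensors, not OpenLRM code):

```python
import torch

views, channels, region, res = 4, 3, 64, 192  # mirrors sample_side_views+1, RGB, region size, one render res
imgs = torch.rand(views, channels, res, res)

# One random top-left corner (i, j) per view, kept inside the image bounds.
anchors = torch.randint(0, res - region + 1, size=(views, 2))
offsets = torch.arange(region)
index_i = (anchors[:, 0].unsqueeze(1) + offsets).view(-1, region, 1)
index_j = (anchors[:, 1].unsqueeze(1) + offsets).view(-1, 1, region)
batch = torch.arange(views).view(-1, 1, 1)

# The advanced indices broadcast to (views, region, region); the channel slice stays intact,
# so the gathered tensor is (views, region, region, channels) before the permute.
crops = imgs[batch, :, index_i, index_j].permute(0, 3, 1, 2)
print(crops.shape)  # torch.Size([4, 3, 64, 64]) -> one (region x region) crop per view
```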
2 changes: 1 addition & 1 deletion openlrm/runners/__init__.py
@@ -17,5 +17,5 @@

REGISTRY_RUNNERS = Registry()

# from .train import *
from .train import *
from .infer import *