Commit

image-to-3d

ashawkey committed Apr 17, 2023
1 parent ec08c0d commit d24abd1
Showing 32 changed files with 153 additions and 143 deletions.
Binary file added data/cactus.png
Binary file added data/cactus_depth.png
Binary file added data/cactus_rgba.png
Binary file added data/cake.png
Binary file added data/cake_depth.png
Binary file added data/cake_rgba.png
Binary file added data/catstatue.png
Binary file added data/catstatue_depth.png
Binary file added data/catstatue_rgba.png
Binary file added data/firekeeper.jpg
Binary file added data/firekeeper_depth.png
Binary file added data/firekeeper_rgba.png
Binary file added data/fox.png
Binary file added data/fox_depth.png
Binary file added data/fox_rgba.png
Binary file added data/hamburger.png
Binary file added data/hamburger_depth.png
Binary file added data/hamburger_rgba.png
6 changes: 1 addition & 5 deletions guidance/sd_utils.py
@@ -113,7 +113,7 @@ def get_text_embeds(self, prompt):
return embeddings


def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_clip=None, grad_scale=1):
def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1):

if as_latent:
latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
@@ -161,11 +161,7 @@ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=Fa

# w(t), sigma_t^2
w = (1 - self.alphas[t])
# w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
grad = grad_scale * w * (noise_pred - noise)

if grad_clip is not None:
grad = grad.clamp(-grad_clip, grad_clip)
grad = torch.nan_to_num(grad)

# since we omitted an item in grad, we need to use the custom function to specify the gradient
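The "custom function" mentioned in the comment above is an autograd trick: the forward pass returns a dummy scalar loss, and the backward pass substitutes the precomputed SDS gradient. A minimal sketch of such a function, assuming the name SpecifyGradient and this particular form (the repo's actual helper may differ):

import torch

class SpecifyGradient(torch.autograd.Function):
    # Forward returns a dummy scalar "loss"; backward injects the precomputed
    # SDS gradient into the input tensor instead of a true derivative.
    @staticmethod
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        return torch.zeros([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    def backward(ctx, grad_scale):
        gt_grad, = ctx.saved_tensors
        # grad_scale is the gradient of the dummy loss (usually 1.0).
        return gt_grad * grad_scale, None

# Usage sketch: loss = SpecifyGradient.apply(latents, grad); loss.backward()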
7 changes: 2 additions & 5 deletions guidance/zero123_utils.py
@@ -104,7 +104,7 @@ def get_img_embeds(self, x):
v = self.model.encode_first_stage(x).mode()
return c, v

def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scale=3, as_latent=False, grad_clip=None):
def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scale=3, as_latent=False, grad_scale=1):
# pred_rgb: tensor [1, 3, H, W] in [-1, 1]

if as_latent:
@@ -134,10 +134,7 @@ def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scal
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

w = (1 - self.alphas[t])
grad = w * (noise_pred - noise)

if grad_clip is not None:
grad = grad.clamp(-grad_clip, grad_clip)
grad = grad_scale * w * (noise_pred - noise)
grad = torch.nan_to_num(grad)

# import kiui
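For reference, both guidance modules now apply the same weighting to the score-distillation gradient, w(t) = 1 - alpha_bar_t. A self-contained toy computation of the weighted gradient; the tensor shapes and scheduler values below are illustrative only:

import torch

alphas = torch.linspace(0.999, 0.01, 1000)      # stand-in for the scheduler's alpha_bar
t = torch.tensor([500])                         # sampled timestep
noise = torch.randn(1, 4, 64, 64)               # added noise
noise_pred = torch.randn(1, 4, 64, 64)          # guidance model's prediction
grad_scale = 1.0

w = (1 - alphas[t]).view(-1, 1, 1, 1)           # w(t) = 1 - alpha_bar_t (sigma_t^2)
grad = grad_scale * w * (noise_pred - noise)    # SDS gradient w.r.t. the latents
grad = torch.nan_to_num(grad)                   # guard against NaNs before backprop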
73 changes: 54 additions & 19 deletions main.py
@@ -48,7 +48,7 @@
parser.add_argument('--warmup_iters', type=int, default=2000, help="training iters that only use albedo shading")
parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
parser.add_argument('--uniform_sphere_rate', type=float, default=0, help="likelihood of sampling camera location uniformly on the sphere surface area")
parser.add_argument('--grad_clip', type=float, default=1, help="clip grad of all grad to this limit, negative value disables it")
parser.add_argument('--grad_clip', type=float, default=-1, help="clip grad of all grad to this limit, negative value disables it")
parser.add_argument('--grad_clip_rgb', type=float, default=-1, help="clip grad of rgb space grad to this limit, negative value disables it")
# model options
parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
@@ -68,6 +68,7 @@
parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
parser.add_argument('--known_view_scale', type=float, default=1.5, help="multiply --h/w by this for known view rendering")
parser.add_argument('--known_view_noise_scale', type=float, default=2e-3, help="random camera noise added to rays_o and rays_d")
parser.add_argument('--dmtet_reso_scale', type=float, default=8, help="multiply --h/w by this for dmtet finetuning")

### dataset options
@@ -86,6 +87,7 @@
parser.add_argument('--default_fovy', type=float, default=60, help="fovy for the default view")

parser.add_argument('--progressive_view', action='store_true', help="progressively expand view sampling range from default to full")
parser.add_argument('--progressive_level', action='store_true', help="progressively increase gridencoder's max_level")

parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region")
parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
@@ -95,16 +97,17 @@
parser.add_argument('--lambda_entropy', type=float, default=1e-3, help="loss scale for alpha entropy")
parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")
parser.add_argument('--lambda_tv', type=float, default=1e-7, help="loss scale for total variation")
parser.add_argument('--lambda_tv', type=float, default=0, help="loss scale for total variation")
parser.add_argument('--lambda_wd', type=float, default=0, help="loss scale")
parser.add_argument('--lambda_normal_smooth', type=float, default=0, help="loss scale for 2D normal image smoothness")

parser.add_argument('--lambda_normal', type=float, default=0.5, help="loss scale for mesh normal smoothness")
parser.add_argument('--lambda_lap', type=float, default=0.5, help="loss scale for mesh laplacian")

parser.add_argument('--lambda_guidance', type=float, default=1, help="loss scale for SDS")
parser.add_argument('--lambda_rgb', type=float, default=10, help="loss scale for RGB")
parser.add_argument('--lambda_mask', type=float, default=10, help="loss scale for mask (alpha)")
parser.add_argument('--lambda_depth', type=float, default=0.1, help="loss scale for relative depth")
parser.add_argument('--lambda_mask', type=float, default=1, help="loss scale for mask (alpha)")
parser.add_argument('--lambda_depth', type=float, default=1, help="loss scale for relative depth")

### GUI options
parser.add_argument('--gui', action='store_true', help="start a GUI")
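The lambda_* options above weight the individual loss terms. As a rough illustration of how such weights enter the training objective, here is a hedged sketch using the new default values; the exact set of terms and how the repository combines them is an assumption, not the actual training loss:

import torch

# Dummy scalar losses standing in for per-term losses computed during a
# training step; both the terms and their combination are illustrative.
loss_sds, loss_rgb, loss_mask, loss_depth = (torch.tensor(0.1),) * 4

lambda_guidance, lambda_rgb, lambda_mask, lambda_depth = 1.0, 10.0, 1.0, 1.0  # new defaults

loss = (lambda_guidance * loss_sds      # SDS / guidance term
        + lambda_rgb * loss_rgb         # reference-view RGB reconstruction
        + lambda_mask * loss_mask       # reference-view alpha (mask)
        + lambda_depth * loss_depth)    # relative depth on the reference view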
@@ -126,34 +129,59 @@
opt.fp16 = True
opt.backbone = 'vanilla'

# image-conditioned generation
# parameters for image-conditioned generation
if opt.image is not None:
opt.text = None # disable text-condition
opt.guidance = 'zero123' # use zero123 guidance model
opt.guidance_scale = 3
opt.warmup_iters = 0
# opt.t_range = [0.02, 0.98] # need to do more exps...

opt.uniform_sphere_rate = 0 # do not do this as it disturbs the progressive view expansion
opt.jitter_pose = False # must not jitter view target
opt.fovy_range = [60, 60] # fix fov as zero123 doesn't support changing fov
opt.progressive_view = True
if opt.text is None:
# use zero123 guidance model when only providing image
opt.guidance = 'zero123'
opt.fovy_range = [opt.default_fovy, opt.default_fovy] # fix fov as zero123 doesn't support changing fov

# very important to keep the image's content
opt.guidance_scale = 3
opt.lambda_guidance = 0.01
opt.grad_clip = 1

else:
# use stable-diffusion when providing both text and image
opt.guidance = 'stable-diffusion'

opt.guidance_scale = 100
opt.lambda_guidance = 0.1

opt.lambda_normal_smooth = 50
# enforce surface smoothness in nerf stage
opt.lambda_normal_smooth = 1
opt.lambda_orient = 100

# latent warmup is not needed, we hardcode a 100-iter rgbd loss only warmup.
opt.warmup_iters = 0

# make shape init more stable
opt.progressive_view = True
opt.progressive_level = True

# default parameters for finetuning
if opt.dmtet:
opt.h = int(opt.h * opt.dmtet_reso_scale)
opt.w = int(opt.w * opt.dmtet_reso_scale)

opt.t_range = [0.02, 0.50] # ref: magic3D

# assume finetuning
opt.warmup_iters = 0
opt.t_range = [0.02, 0.50]
opt.progressive_view = False
opt.progressive_level = False

if opt.image is None:
opt.fovy_range = [30, 70] # smaller fovy (zoom in) for better details
if opt.guidance != 'zero123':
# smaller fovy (zoom in) for better details
opt.fovy_range = [opt.fovy_range[0] - 10, opt.fovy_range[1] - 10]

# record full range for progressive view expansion
if opt.progressive_view:
# disable as they disturb progressive view
opt.jitter_pose = False
opt.uniform_sphere_rate = 0
# back up full range
opt.full_radius_range = opt.radius_range
opt.full_theta_range = opt.theta_range
opt.full_phi_range = opt.phi_range
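The full_* ranges backed up here serve as the target for progressive view expansion. One plausible way to widen the sampled range from the default view toward the full range as training progresses is sketched below; the helper name, the linear schedule, and the default-view values are assumptions rather than the repo's exact logic:

def expand_range(full_range, default_value, frac):
    # Linearly interpolate from a degenerate range at the default view
    # (frac = 0) out to the full sampling range (frac = 1).
    lo = default_value + (full_range[0] - default_value) * frac
    hi = default_value + (full_range[1] - default_value) * frac
    return [lo, hi]

# Illustrative: default front view at azimuth 0 deg, full range of +/-180 deg.
full_phi_range = [-180.0, 180.0]
print(expand_range(full_phi_range, 0.0, 0.25))   # -> [-45.0, 45.0]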
@@ -240,7 +268,7 @@
else:
raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')

trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True)
trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True)

trainer.default_view_data = train_loader._data.get_default_view_data()

@@ -254,3 +282,10 @@

max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
trainer.train(train_loader, valid_loader, max_epoch)

# also test at the end
test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
trainer.test(test_loader)

if opt.save_mesh:
trainer.save_mesh()
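The Trainer is now constructed with ema_decay=0.95 instead of None, i.e. an exponential moving average of the model weights is kept for evaluation. A generic sketch of one EMA update step with that decay; this is the standard update rule, not the Trainer's exact implementation:

import torch

param = torch.randn(10)       # a current model parameter (illustrative)
ema_param = param.clone()     # shadow copy tracked by the EMA

# One EMA step with decay 0.95, the value passed to Trainer above.
ema_param.mul_(0.95).add_(param, alpha=1 - 0.95)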
9 changes: 0 additions & 9 deletions nerf/network.py
@@ -115,15 +115,6 @@ def __init__(self,
else:
self.bg_net = None

def density_blob(self, x):
# x: [B, N, 3]

d = (x ** 2).sum(-1)
# g = self.opt.blob_density * torch.exp(- d / (self.opt.blob_radius ** 2))
g = self.opt.blob_density * (1 - torch.sqrt(d) / self.opt.blob_radius)

return g

def common_forward(self, x):
# x: [N, 3], in [-bound, bound]

11 changes: 0 additions & 11 deletions nerf/network_grid.py
@@ -65,17 +65,6 @@ def __init__(self,
else:
self.bg_net = None

# add a density blob to the scene center
@torch.no_grad()
def density_blob(self, x):
# x: [B, N, 3]

d = (x ** 2).sum(-1)
# g = self.opt.blob_density * torch.exp(- d / (self.opt.blob_radius ** 2))
g = self.opt.blob_density * (1 - torch.sqrt(d) / self.opt.blob_radius)

return g

def common_forward(self, x):

# sigma
10 changes: 0 additions & 10 deletions nerf/network_grid_taichi.py
@@ -64,16 +64,6 @@ def __init__(self,
else:
self.bg_net = None

# add a density blob to the scene center
def density_blob(self, x):
# x: [B, N, 3]

d = (x ** 2).sum(-1)
# g = self.opt.blob_density * torch.exp(- d / (self.opt.blob_radius ** 2))
g = self.opt.blob_density * (1 - torch.sqrt(d) / self.opt.blob_radius)

return g

def common_forward(self, x):

# sigma
15 changes: 12 additions & 3 deletions nerf/renderer.py
@@ -335,16 +335,25 @@ def __init__(self, opt):
self.mean_density = 0
self.iter_density = 0

@torch.no_grad()
def density_blob(self, x):
# x: [B, N, 3]

d = (x ** 2).sum(-1)

if self.opt.density_activation == 'exp':
g = self.opt.blob_density * torch.exp(- d / (2 * self.opt.blob_radius ** 2))
else:
g = self.opt.blob_density * (1 - torch.sqrt(d) / self.opt.blob_radius)

return g

def forward(self, x, d):
raise NotImplementedError()

def density(self, x):
raise NotImplementedError()

def color(self, x, d, mask=None, **kwargs):
raise NotImplementedError()

def reset_extra_state(self):
if not (self.cuda_ray or self.taichi_ray):
return
Expand Down
0 comments on commit d24abd1