Skip to content

Commit 96eaae2

Browse files
committed
refactor seems sound but emitting a cuda error now
1 parent 85f9cf1 commit 96eaae2

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

src/pytti/LossAug/LossOrchestratorClass.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -127,43 +127,47 @@ def configure_stabilization_augs(img, init_image_pil, params, loss_augs):
127127

128128

129129
def configure_optical_flows(img, params, loss_augs):
130+
logger.debug(params.device)
131+
_device = params.device
132+
optical_flows = []
130133
if params.animation_mode == "Video Source":
131134
if params.flow_stabilization_weight == "":
132135
params.flow_stabilization_weight = "0"
133-
# if flow stabilization weight is 0, shouldn't this next block just get skipped?
136+
# TODO: if flow stabilization weight is 0, shouldn't this next block just get skipped?
134137

135138
for i in range(params.flow_long_term_samples + 1):
136-
name = f"optical flow stabilization (frame {-2**i})"
137-
weight = params.flow_stabilization_weight
138-
comp = torch.zeros(1, 1, 1, 1) # ,device=device)
139139
optical_flow = OpticalFlowLoss(
140-
comp=comp,
141-
weight=weight,
142-
name=f"{name} (direct)",
140+
comp=torch.zeros(1, 1, 1, 1, device=_device), # ,device=DEVICE)
141+
weight=params.flow_stabilization_weight,
142+
name=f"optical flow stabilization (frame {-2**i}) (direct)",
143143
image_shape=img.image_shape,
144+
device=_device,
144145
) # , device=device)
145146
optical_flow.set_enabled(False)
146-
loss_augs.append(optical_flow)
147+
optical_flows.append(optical_flow)
147148

148149
elif params.animation_mode == "3D" and params.flow_stabilization_weight not in [
149150
"0",
150151
"",
151152
]:
152-
optical_flows = [
153-
TargetFlowLoss.TargetImage(
154-
f"optical flow stabilization:{params.flow_stabilization_weight}",
155-
img.image_shape,
156-
device="cuda",
157-
)
158-
]
159-
for optical_flow in optical_flows:
160-
optical_flow.set_enabled(False)
161-
loss_augs.extend(optical_flows)
162-
else:
163-
optical_flows = []
153+
optical_flow = TargetFlowLoss(
154+
comp=torch.zeros(1, 1, 1, 1, device=_device),
155+
weight=params.flow_stabilization_weight,
156+
name="optical flow stabilization (direct)",
157+
image_shape=img.image_shape,
158+
device=_device,
159+
)
160+
optical_flow.set_enabled(False)
161+
optical_flows.append(optical_flow)
162+
163+
loss_augs.extend(optical_flows)
164+
165+
# this shouldn't be in this function based on the name.
164166
# other loss augs
165167
if params.smoothing_weight != 0:
166-
loss_augs.append(TVLoss(weight=params.smoothing_weight))
168+
loss_augs.append(
169+
TVLoss(weight=params.smoothing_weight)
170+
) # , device=params.device))
167171

168172
return img, loss_augs, optical_flows
169173

tests/test_loss_refactoring.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,16 @@ def test_video_optical_flow():
8282
device: '{TEST_DEVICE}'
8383
"""
8484
run_cfg(cfg_str)
85+
86+
87+
def test_3D_optical_flow():
88+
cfg_str = f"""# @package _global_
89+
scenes: a photograph of an apple
90+
animation_mode: 3D
91+
video_path: {video_fpath}
92+
flow_stabilization_weight: 1
93+
steps_per_frame: 10
94+
steps_per_scene: 150
95+
device: '{TEST_DEVICE}'
96+
"""
97+
run_cfg(cfg_str)

0 commit comments

Comments
 (0)