@@ -127,43 +127,47 @@ def configure_stabilization_augs(img, init_image_pil, params, loss_augs):
127127
128128
129129def configure_optical_flows (img , params , loss_augs ):
130+ logger .debug (params .device )
131+ _device = params .device
132+ optical_flows = []
130133 if params .animation_mode == "Video Source" :
131134 if params .flow_stabilization_weight == "" :
132135 params .flow_stabilization_weight = "0"
133- # if flow stabilization weight is 0, shouldn't this next block just get skipped?
136+ # TODO: if flow stabilization weight is 0, shouldn't this next block just get skipped?
134137
135138 for i in range (params .flow_long_term_samples + 1 ):
136- name = f"optical flow stabilization (frame { - 2 ** i } )"
137- weight = params .flow_stabilization_weight
138- comp = torch .zeros (1 , 1 , 1 , 1 ) # ,device=device)
139139 optical_flow = OpticalFlowLoss (
140- comp = comp ,
141- weight = weight ,
142- name = f"{ name } (direct)" ,
140+ comp = torch . zeros ( 1 , 1 , 1 , 1 , device = _device ), # ,device=DEVICE)
141+ weight = params . flow_stabilization_weight ,
142+ name = f"optical flow stabilization (frame { - 2 ** i } ) (direct)" ,
143143 image_shape = img .image_shape ,
144+ device = _device ,
144145 ) # , device=device)
145146 optical_flow .set_enabled (False )
146- loss_augs .append (optical_flow )
147+ optical_flows .append (optical_flow )
147148
148149 elif params .animation_mode == "3D" and params .flow_stabilization_weight not in [
149150 "0" ,
150151 "" ,
151152 ]:
152- optical_flows = [
153- TargetFlowLoss .TargetImage (
154- f"optical flow stabilization:{ params .flow_stabilization_weight } " ,
155- img .image_shape ,
156- device = "cuda" ,
157- )
158- ]
159- for optical_flow in optical_flows :
160- optical_flow .set_enabled (False )
161- loss_augs .extend (optical_flows )
162- else :
163- optical_flows = []
153+ optical_flow = TargetFlowLoss (
154+ comp = torch .zeros (1 , 1 , 1 , 1 , device = _device ),
155+ weight = params .flow_stabilization_weight ,
156+ name = "optical flow stabilization (direct)" ,
157+ image_shape = img .image_shape ,
158+ device = _device ,
159+ )
160+ optical_flow .set_enabled (False )
161+ optical_flows .append (optical_flow )
162+
163+ loss_augs .extend (optical_flows )
164+
165+ # this shouldn't be in this function based on the name.
164166 # other loss augs
165167 if params .smoothing_weight != 0 :
166- loss_augs .append (TVLoss (weight = params .smoothing_weight ))
168+ loss_augs .append (
169+ TVLoss (weight = params .smoothing_weight )
170+ ) # , device=params.device))
167171
168172 return img , loss_augs , optical_flows
169173
0 commit comments