44
55from pytti .image_models import PixelImage , RGBImage
66
7- # from pytti.LossAug import build_loss
87from pytti .LossAug import TVLoss , HSVLoss , OpticalFlowLoss , TargetFlowLoss
98from pytti .Perceptor .Prompt import parse_prompt
109from pytti .eval_tools import parse_subprompt
@@ -30,7 +29,6 @@ def build_loss(
3029 pil_target : Image ,
3130 device = None ,
3231):
33- # from pytti.LossAug import LOSS_DICT
3432 if device is None :
3533 device = img .device
3634
@@ -39,9 +37,7 @@ def build_loss(
3937 loss = type (img ).get_preferred_loss ()
4038 else :
4139 loss = LOSS_DICT [weight_name ]
42- # out = Loss.TargetImage(
43- # f"{weight_name} {name}:{weight}", img.image_shape, pil_target
44- # )
40+
4541 if pil_target is not None :
4642 resized = pil_target .resize (img .image_shape , Image .LANCZOS )
4743 comp = loss .make_comp (resized , device = device )
@@ -172,247 +168,10 @@ def configure_optical_flows(img, params, loss_augs):
172168 return img , loss_augs , optical_flows
173169
174170
175- #######################################
176-
177-
178- # class LossBuilder:
179-
180- # LOSS_DICT = {"edge": EdgeLoss, "depth": DepthLoss}
181-
182- # def __init__(self, weight_name, weight, name, img, pil_target):
183- # self.weight_name = weight_name
184- # self.weight = weight
185- # self.name = name
186- # self.img = img
187- # self.pil_target = pil_target
188-
189- # # uh.... should the places this is beind used maybe just use Loss.__init__?
190- # # TO DO: let's make this a class attribute on something
191-
192- # @property
193- # def weight_category(self):
194- # return self.weight_name.split("_")[0]
195-
196- # @property
197- # def loss_factory(self):
198- # weight_name = self.weight_category
199- # if weight_name == "direct":
200- # Loss = type(self.img).get_preferred_loss()
201- # else:
202- # Loss = self.LOSS_DICT[weight_name]
203- # return Loss
204-
205- # def build_loss(self) -> Loss:
206- # """
207- # Given a weight name, weight, name, image, and target image, returns a loss object
208-
209- # :param weight_name: The name of the loss function
210- # :param weight: The weight of the loss
211- # :param name: The name of the loss function
212- # :param img: The image to be optimized
213- # :param pil_target: The target image
214- # :return: The loss function.
215- # """
216- # Loss = self.loss_factory
217- # out = Loss.TargetImage(
218- # f"{self.weight_category} {self.name}:{self.weight}",
219- # self.img.image_shape,
220- # self.pil_target,
221- # )
222- # out.set_enabled(self.pil_target is not None)
223- # return out
224-
225-
226171def _standardize_null (weight ):
227172 weight = str (weight ).strip ()
228173 if weight in ("" , "None" ):
229174 weight = "0"
230175 if float (weight ) == 0 :
231176 weight = ""
232177 return weight
233-
234-
235- # class LossConfigurator:
236- # """
237- # Groups together procedures for initializing losses
238- # """
239-
240- # def __init__(
241- # self,
242- # init_image_pil: Image.Image,
243- # restore: bool,
244- # img: PixelImage,
245- # embedder,
246- # prompts,
247- # # params,
248- # ########
249- # direct_image_prompts,
250- # semantic_stabilization_weight,
251- # init_image,
252- # semantic_init_weight,
253- # animation_mode,
254- # flow_stabilization_weight,
255- # flow_long_term_samples,
256- # smoothing_weight,
257- # ###########
258- # direct_init_weight,
259- # direct_stabilization_weight,
260- # depth_stabilization_weight,
261- # edge_stabilization_weight,
262- # ):
263- # self.init_image_pil = init_image_pil
264- # self.img = img
265- # self.embedder = embedder
266- # self.prompts = prompts
267-
268- # self.init_augs = []
269- # self.loss_augs = []
270- # self.optical_flows = []
271- # self.last_frame_semantic = None
272- # self.semantic_init_prompt = None
273-
274- # # self.params = params
275- # self.restore = restore
276-
277- # ### params
278- # self.direct_image_prompts = direct_image_prompts
279- # self.semantic_stabilization_weight = _standardize_null(
280- # semantic_stabilization_weight
281- # )
282- # self.init_image = init_image
283- # self.semantic_init_weight = _standardize_null(semantic_init_weight)
284- # self.animation_mode = animation_mode
285- # self.flow_stabilization_weight = _standardize_null(flow_stabilization_weight)
286- # self.flow_long_term_samples = flow_long_term_samples
287- # self.smoothing_weight = _standardize_null(smoothing_weight)
288-
289- # ######
290- # self.direct_init_weight = _standardize_null(direct_init_weight)
291- # self.direct_stabilization_weight = _standardize_null(
292- # direct_stabilization_weight
293- # )
294- # self.depth_stabilization_weight = _standardize_null(depth_stabilization_weight)
295- # self.edge_stabilization_weight = _standardize_null(edge_stabilization_weight)
296-
297- # def process_direct_image_prompts(self):
298- # # prompt parsing shouldn't go here.
299- # self.loss_augs.extend(
300- # type(self.img)
301- # .get_preferred_loss()
302- # .TargetImage(p.strip(), self.img.image_shape, is_path=True)
303- # for p in self.direct_image_prompts.split("|")
304- # if p.strip()
305- # )
306-
307- # def process_semantic_stabilization(self):
308- # last_frame_pil = self.init_image_pil
309- # if not last_frame_pil:
310- # last_frame_pil = self.img.decode_image()
311- # self.last_frame_semantic = parse_prompt(
312- # self.embedder,
313- # f"stabilization:{self.semantic_stabilization_weight}",
314- # last_frame_pil,
315- # )
316- # self.last_frame_semantic.set_enabled(self.init_image_pil is not None)
317- # for scene in self.prompts:
318- # scene.append(self.last_frame_semantic)
319-
320- # def configure_losses(self):
321- # if self.init_image_pil is not None:
322- # self.configure_init_image()
323- # self.process_direct_image_prompts()
324- # if self.semantic_stabilization_weight:
325- # self.process_semantic_stabilization()
326- # self.configure_stabilization_augs()
327- # self.configure_optical_flows()
328- # self.configure_aesthetic_losses()
329-
330- # return (
331- # self.loss_augs,
332- # self.init_augs,
333- # self.stabilization_augs,
334- # self.optical_flows,
335- # self.semantic_init_prompt,
336- # self.last_frame_semantic,
337- # self.img,
338- # )
339-
340- # def configure_init_image(self):
341-
342- # if not self.restore:
343- # # move these logging statements into .encode_image()
344- # logger.info("Encoding image...")
345- # self.img.encode_image(self.init_image_pil)
346- # logger.info("Encoded Image:")
347- # # pretty sure this assumes we're in a notebook
348- # display.display(self.img.decode_image())
349-
350- # ## wrap this for the flexibility that the loop is pretending to provide...
351- # # set up init image prompt
352- # if self.direct_init_weight:
353- # init_aug = LossBuilder(
354- # "direct_init_weight",
355- # self.direct_init_weight,
356- # f"init image ({self.init_image})",
357- # self.img,
358- # self.init_image_pil,
359- # ).build_loss()
360- # self.loss_augs.append(init_aug)
361- # self.init_augs.append(init_aug)
362-
363- # ########
364- # if self.semantic_init_weight:
365- # self.semantic_init_prompt = parse_prompt(
366- # self.embedder,
367- # f"init image [{self.init_image}]:{self.semantic_init_weight}",
368- # self.init_image_pil,
369- # )
370- # self.prompts[0].append(self.semantic_init_prompt)
371-
372- # # stabilization
373- # def configure_stabilization_augs(self):
374- # d_augs = {
375- # "direct_stabilization_weight": self.direct_stabilization_weight,
376- # "depth_stabilization_weight": self.depth_stabilization_weight,
377- # "edge_stabilization_weight": self.edge_stabilization_weight,
378- # }
379- # stabilization_augs = [
380- # LossBuilder(
381- # k, v, "stabilization", self.img, self.init_image_pil
382- # ).build_loss()
383- # for k, v in d_augs.items()
384- # if v
385- # ]
386- # self.stabilization_augs = stabilization_augs
387- # self.loss_augs.extend(stabilization_augs)
388-
389- # def configure_optical_flows(self):
390- # optical_flows = None
391-
392- # if self.animation_mode == "Video Source":
393- # if self.flow_stabilization_weight == "":
394- # self.flow_stabilization_weight = "0"
395- # optical_flows = [
396- # OpticalFlowLoss.TargetImage(
397- # f"optical flow stabilization (frame {-2**i}):{self.flow_stabilization_weight}",
398- # self.img.image_shape,
399- # )
400- # for i in range(self.flow_long_term_samples + 1)
401- # ]
402-
403- # elif self.animation_mode == "3D" and self.flow_stabilization_weight:
404- # optical_flows = [
405- # TargetFlowLoss.TargetImage(
406- # f"optical flow stabilization:{self.flow_stabilization_weight}",
407- # self.img.image_shape,
408- # )
409- # ]
410-
411- # if optical_flows is not None:
412- # for optical_flow in optical_flows:
413- # optical_flow.set_enabled(False)
414- # self.loss_augs.extend(optical_flows)
415-
416- # def configure_aesthetic_losses(self):
417- # if self.smoothing_weight != 0:
418- # self.loss_augs.append(TVLoss(weight=self.smoothing_weight))
0 commit comments