
Commit 492f713

cleaned up deprecated code
1 parent 503bd72 commit 492f713

File tree

5 files changed, +12 -355 lines changed


src/pytti/LossAug/LossOrchestratorClass.py

Lines changed: 1 addition & 242 deletions
@@ -4,7 +4,6 @@
 
 from pytti.image_models import PixelImage, RGBImage
 
-# from pytti.LossAug import build_loss
 from pytti.LossAug import TVLoss, HSVLoss, OpticalFlowLoss, TargetFlowLoss
 from pytti.Perceptor.Prompt import parse_prompt
 from pytti.eval_tools import parse_subprompt
@@ -30,7 +29,6 @@ def build_loss(
     pil_target: Image,
     device=None,
 ):
-    # from pytti.LossAug import LOSS_DICT
     if device is None:
         device = img.device
 
@@ -39,9 +37,7 @@ def build_loss(
         loss = type(img).get_preferred_loss()
     else:
         loss = LOSS_DICT[weight_name]
-    # out = Loss.TargetImage(
-    #     f"{weight_name} {name}:{weight}", img.image_shape, pil_target
-    # )
+
     if pil_target is not None:
         resized = pil_target.resize(img.image_shape, Image.LANCZOS)
         comp = loss.make_comp(resized, device=device)
@@ -172,247 +168,10 @@ def configure_optical_flows(img, params, loss_augs):
     return img, loss_augs, optical_flows
 
 
-#######################################
-
-
-# class LossBuilder:
-
-#     LOSS_DICT = {"edge": EdgeLoss, "depth": DepthLoss}
-
-#     def __init__(self, weight_name, weight, name, img, pil_target):
-#         self.weight_name = weight_name
-#         self.weight = weight
-#         self.name = name
-#         self.img = img
-#         self.pil_target = pil_target
-
-#     # uh.... should the places this is beind used maybe just use Loss.__init__?
-#     # TO DO: let's make this a class attribute on something
-
-#     @property
-#     def weight_category(self):
-#         return self.weight_name.split("_")[0]
-
-#     @property
-#     def loss_factory(self):
-#         weight_name = self.weight_category
-#         if weight_name == "direct":
-#             Loss = type(self.img).get_preferred_loss()
-#         else:
-#             Loss = self.LOSS_DICT[weight_name]
-#         return Loss
-
-#     def build_loss(self) -> Loss:
-#         """
-#         Given a weight name, weight, name, image, and target image, returns a loss object
-
-#         :param weight_name: The name of the loss function
-#         :param weight: The weight of the loss
-#         :param name: The name of the loss function
-#         :param img: The image to be optimized
-#         :param pil_target: The target image
-#         :return: The loss function.
-#         """
-#         Loss = self.loss_factory
-#         out = Loss.TargetImage(
-#             f"{self.weight_category} {self.name}:{self.weight}",
-#             self.img.image_shape,
-#             self.pil_target,
-#         )
-#         out.set_enabled(self.pil_target is not None)
-#         return out
-
-
 def _standardize_null(weight):
     weight = str(weight).strip()
     if weight in ("", "None"):
         weight = "0"
     if float(weight) == 0:
         weight = ""
     return weight
-
-
-# class LossConfigurator:
-#     """
-#     Groups together procedures for initializing losses
-#     """
-
-#     def __init__(
-#         self,
-#         init_image_pil: Image.Image,
-#         restore: bool,
-#         img: PixelImage,
-#         embedder,
-#         prompts,
-#         # params,
-#         ########
-#         direct_image_prompts,
-#         semantic_stabilization_weight,
-#         init_image,
-#         semantic_init_weight,
-#         animation_mode,
-#         flow_stabilization_weight,
-#         flow_long_term_samples,
-#         smoothing_weight,
-#         ###########
-#         direct_init_weight,
-#         direct_stabilization_weight,
-#         depth_stabilization_weight,
-#         edge_stabilization_weight,
-#     ):
-#         self.init_image_pil = init_image_pil
-#         self.img = img
-#         self.embedder = embedder
-#         self.prompts = prompts
-
-#         self.init_augs = []
-#         self.loss_augs = []
-#         self.optical_flows = []
-#         self.last_frame_semantic = None
-#         self.semantic_init_prompt = None
-
-#         # self.params = params
-#         self.restore = restore
-
-#         ### params
-#         self.direct_image_prompts = direct_image_prompts
-#         self.semantic_stabilization_weight = _standardize_null(
-#             semantic_stabilization_weight
-#         )
-#         self.init_image = init_image
-#         self.semantic_init_weight = _standardize_null(semantic_init_weight)
-#         self.animation_mode = animation_mode
-#         self.flow_stabilization_weight = _standardize_null(flow_stabilization_weight)
-#         self.flow_long_term_samples = flow_long_term_samples
-#         self.smoothing_weight = _standardize_null(smoothing_weight)
-
-#         ######
-#         self.direct_init_weight = _standardize_null(direct_init_weight)
-#         self.direct_stabilization_weight = _standardize_null(
-#             direct_stabilization_weight
-#         )
-#         self.depth_stabilization_weight = _standardize_null(depth_stabilization_weight)
-#         self.edge_stabilization_weight = _standardize_null(edge_stabilization_weight)
-
-#     def process_direct_image_prompts(self):
-#         # prompt parsing shouldn't go here.
-#         self.loss_augs.extend(
-#             type(self.img)
-#             .get_preferred_loss()
-#             .TargetImage(p.strip(), self.img.image_shape, is_path=True)
-#             for p in self.direct_image_prompts.split("|")
-#             if p.strip()
-#         )
-
-#     def process_semantic_stabilization(self):
-#         last_frame_pil = self.init_image_pil
-#         if not last_frame_pil:
-#             last_frame_pil = self.img.decode_image()
-#         self.last_frame_semantic = parse_prompt(
-#             self.embedder,
-#             f"stabilization:{self.semantic_stabilization_weight}",
-#             last_frame_pil,
-#         )
-#         self.last_frame_semantic.set_enabled(self.init_image_pil is not None)
-#         for scene in self.prompts:
-#             scene.append(self.last_frame_semantic)
-
-#     def configure_losses(self):
-#         if self.init_image_pil is not None:
-#             self.configure_init_image()
-#         self.process_direct_image_prompts()
-#         if self.semantic_stabilization_weight:
-#             self.process_semantic_stabilization()
-#         self.configure_stabilization_augs()
-#         self.configure_optical_flows()
-#         self.configure_aesthetic_losses()
-
-#         return (
-#             self.loss_augs,
-#             self.init_augs,
-#             self.stabilization_augs,
-#             self.optical_flows,
-#             self.semantic_init_prompt,
-#             self.last_frame_semantic,
-#             self.img,
-#         )
-
-#     def configure_init_image(self):
-
-#         if not self.restore:
-#             # move these logging statements into .encode_image()
-#             logger.info("Encoding image...")
-#             self.img.encode_image(self.init_image_pil)
-#             logger.info("Encoded Image:")
-#             # pretty sure this assumes we're in a notebook
-#             display.display(self.img.decode_image())
-
-#         ## wrap this for the flexibility that the loop is pretending to provide...
-#         # set up init image prompt
-#         if self.direct_init_weight:
-#             init_aug = LossBuilder(
-#                 "direct_init_weight",
-#                 self.direct_init_weight,
-#                 f"init image ({self.init_image})",
-#                 self.img,
-#                 self.init_image_pil,
-#             ).build_loss()
-#             self.loss_augs.append(init_aug)
-#             self.init_augs.append(init_aug)
-
-#         ########
-#         if self.semantic_init_weight:
-#             self.semantic_init_prompt = parse_prompt(
-#                 self.embedder,
-#                 f"init image [{self.init_image}]:{self.semantic_init_weight}",
-#                 self.init_image_pil,
-#             )
-#             self.prompts[0].append(self.semantic_init_prompt)
-
-#     # stabilization
-#     def configure_stabilization_augs(self):
-#         d_augs = {
-#             "direct_stabilization_weight": self.direct_stabilization_weight,
-#             "depth_stabilization_weight": self.depth_stabilization_weight,
-#             "edge_stabilization_weight": self.edge_stabilization_weight,
-#         }
-#         stabilization_augs = [
-#             LossBuilder(
-#                 k, v, "stabilization", self.img, self.init_image_pil
-#             ).build_loss()
-#             for k, v in d_augs.items()
-#             if v
-#         ]
-#         self.stabilization_augs = stabilization_augs
-#         self.loss_augs.extend(stabilization_augs)
-
-#     def configure_optical_flows(self):
-#         optical_flows = None
-
-#         if self.animation_mode == "Video Source":
-#             if self.flow_stabilization_weight == "":
-#                 self.flow_stabilization_weight = "0"
-#             optical_flows = [
-#                 OpticalFlowLoss.TargetImage(
-#                     f"optical flow stabilization (frame {-2**i}):{self.flow_stabilization_weight}",
-#                     self.img.image_shape,
-#                 )
-#                 for i in range(self.flow_long_term_samples + 1)
-#             ]
-
-#         elif self.animation_mode == "3D" and self.flow_stabilization_weight:
-#             optical_flows = [
-#                 TargetFlowLoss.TargetImage(
-#                     f"optical flow stabilization:{self.flow_stabilization_weight}",
-#                     self.img.image_shape,
-#                 )
-#             ]
-
-#         if optical_flows is not None:
-#             for optical_flow in optical_flows:
-#                 optical_flow.set_enabled(False)
-#             self.loss_augs.extend(optical_flows)
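
Note on the retained helper: _standardize_null collapses empty, None, and zero-valued weights to an empty string so callers can treat them as disabled with a plain truthiness check. Below is a minimal usage sketch; the example values are illustrative only, and the import assumes the module is importable under the path shown in the file header.

# Hedged sketch: exercises the retained _standardize_null helper with hypothetical inputs.
from pytti.LossAug.LossOrchestratorClass import _standardize_null

for raw in (None, "", "None", 0, "0.0", " 2 ", 1.5):
    print(repr(raw), "->", repr(_standardize_null(raw)))
# Expected: None, "", "None", 0, and "0.0" all normalize to "";
# " 2 " normalizes to "2" and 1.5 to "1.5".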

src/pytti/LossAug/MSELossClass.py

Lines changed: 1 addition & 2 deletions
@@ -4,10 +4,9 @@
 from torch.nn import functional as F
 from pytti.LossAug.BaseLossClass import Loss
 
-# from pytti.Notebook import Rotoscoper
 from pytti.rotoscoper import Rotoscoper
 from pytti import fetch, vram_usage_mode
-from pytti.eval_tools import parse, parse_subprompt
+from pytti.eval_tools import parse_subprompt
 import torch
 
 
src/pytti/LossAug/OpticalFlowLossClass.py

Lines changed: 4 additions & 26 deletions
@@ -16,10 +16,8 @@
 import gma
 from gma.core.network import RAFTGMA
 
-# from gma.core.utils import flow_viz
 from gma.core.utils.utils import InputPadder
 
-# from pytti import fetch, to_pil, DEVICE, vram_usage_mode
 from pytti import fetch, vram_usage_mode
 from pytti.LossAug.MSELossClass import MSELoss
 from pytti.rotoscoper import Rotoscoper
@@ -100,7 +98,6 @@ def init_GMA(checkpoint_path=None, device=None):
     args = parser.parse_args([])
 
     # create new OrderedDict that does not contain `module.` prefix
-    # state_dict = torch.load(checkpoint_path)
     state_dict = torch.load(checkpoint_path, map_location=device)
     from collections import OrderedDict
 
@@ -110,18 +107,9 @@
         k = k[7:]  # remove `module.`
         new_state_dict[k] = v
 
-    # GMA = torch.nn.DataParallel(RAFTGMA(args), device_ids=[device])
     GMA = RAFTGMA(args)
-    # GMA = torch.nn.parallel.DistributedDataParallel(RAFTGMA(args).to(device), device_ids=[device])
-    # GMA = RAFTGMA(args)
-    # GMA.load_state_dict(torch.load(checkpoint_path, map_location=device))
-    # GMA.load_state_dict(torch.load(checkpoint_path))
     GMA.load_state_dict(new_state_dict)
     logger.debug("gma state_dict loaded")
-    ###########################
-    # 1. Fix state dict (remove module prefixes)
-    # 2. load state dict into model without DataParallel
-    ###########################
     GMA.to(device)  # redundant?
     GMA.eval()
 
@@ -209,7 +197,6 @@ def get_loss(self, input, img, device=None):
         if device is None:
             device = getattr(self, "device", self.device)
         init_GMA(
-            # "GMA/checkpoints/gma-sintel.pth"
             device=device,
         )  # update this to use model dir from config
         image1 = self.last_step
@@ -220,8 +207,6 @@
         logger.debug(device)
         logger.debug((flow.shape, flow.device))
         logger.debug((self.comp.shape, self.comp.device))
-        # logger.debug(GMA.device) # ugh... I bet this is another dataparallel thing.
-        # logger.debug(GMA.module.device)
         flow = flow.to(device, memory_format=torch.channels_last)
         return super().get_loss(TF.resize(flow, self.comp.shape[-2:]), img) / self.mag
 
@@ -232,9 +217,9 @@ class OpticalFlowLoss(MSELoss):
     def motion_edge_map(
         flow_forward,
         flow_backward,
-        img,  # is this even being used anywhere here?
-        border_mode="smear",
-        sampling_mode="bilinear",
+        img,  # unused
+        border_mode="smear",  # unused
+        sampling_mode="bilinear",  # unused
         device=None,
     ):
         """
@@ -325,7 +310,6 @@ def get_flow(image1, image2, device=None):
         """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
-       # init_GMA("GMA/checkpoints/gma-sintel.pth")
        init_GMA(
            device=device,
        )
@@ -386,11 +370,9 @@ def set_flow(
        )
        logger.debug(device)
        if path is not None:
-           # img = img.clone()
           img = img.clone().to(device)
           if not isinstance(device, torch.device):
               device = torch.device(device)
-          # logger.debug(device)
          state_dict = torch.load(path, map_location=device)
          img.load_state_dict(state_dict)
 
@@ -412,13 +394,9 @@ def set_flow(
        image1.add_(noise)
        image2.add_(noise)
 
-       # flow_forward = OpticalFlowLoss.get_flow(image1, image2)
-       # flow_backward = OpticalFlowLoss.get_flow(image2, image1)
-       # flow_forward = self.get_flow(image1, image2, device=device)
-       # flow_backward = self.get_flow(image2, image1, device=device)
        flow_forward = OpticalFlowLoss.get_flow(image1, image2, device=device)
        flow_backward = OpticalFlowLoss.get_flow(image2, image1, device=device)
-       unwarped_target_direct = img.decode_tensor()
+       unwarped_target_direct = img.decode_tensor()  # unused
        flow_target_direct = apply_flow(
           img, -flow_backward, border_mode=border_mode, sampling_mode=sampling_mode
        )
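
The init_GMA hunks keep the working pattern the deleted comments were circling: the GMA checkpoint was saved from a torch.nn.DataParallel wrapper, so its keys carry a "module." prefix that must be stripped before loading into a bare RAFTGMA. A generic, hedged sketch of that idea follows; MyModel and "checkpoint.pth" are placeholders, not part of pytti or GMA.

# Hedged sketch of loading a DataParallel-saved checkpoint into a bare module.
# MyModel and "checkpoint.pth" are illustrative placeholders only.
from collections import OrderedDict

import torch


def strip_module_prefix(state_dict):
    """Drop the 'module.' prefix that nn.DataParallel adds to parameter names."""
    return OrderedDict(
        (key[len("module."):] if key.startswith("module.") else key, value)
        for key, value in state_dict.items()
    )


# state_dict = torch.load("checkpoint.pth", map_location="cpu")
# model = MyModel()
# model.load_state_dict(strip_module_prefix(state_dict))
# model.eval()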
