Skip to content

Commit

Permalink
Make import of external SWIN and EmoNet packages optional
Browse files Browse the repository at this point in the history
  • Loading branch information
radekd91 committed Jan 17, 2022
1 parent d678184 commit b695dcb
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 28 deletions.
10 changes: 8 additions & 2 deletions gdl/layers/losses/EmoNetLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@
from gdl.layers.losses.EmonetLoader import get_emonet
from pathlib import Path
import torch.nn.functional as F
from gdl.models.EmoNetModule import EmoNetModule
from gdl.models.EmoSwinModule import EmoSwinModule
try:
from gdl.models.EmoNetModule import EmoNetModule
except ImportError as e:
print(f"Could not import EmoNetModule. EmoNet models will not be available.")
try:
from gdl.models.EmoSwinModule import EmoSwinModule
except ImportError as e:
print(f"Could not import EmoSwinModule. SWIN models will not be available")
from gdl.models.EmoCnnModule import EmoCnnModule
from gdl.models.IO import get_checkpoint_with_kwargs
from gdl.utils.other import class_from_str
Expand Down
5 changes: 4 additions & 1 deletion gdl/models/DecaEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import torch.nn.functional as F
import gdl.models.ResNet as resnet

from .Swin import create_swin_backbone, swin_cfg_from_name
try:
from .Swin import create_swin_backbone, swin_cfg_from_name
except ImportError as e:
print("SWIN not found, will not be able to use SWIN models")

class BaseEncoder(nn.Module):
def __init__(self, outsize, last_op=None):
Expand Down
38 changes: 19 additions & 19 deletions gdl_apps/EMOCA/demos/test_emoca_on_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,29 +59,29 @@ def main():
## 3) Get the data loadeer with the detected faces
dl = dm.test_dataloader()

# ## 4) Run the model on the data
# for j, batch in enumerate (auto.tqdm( dl)):
## 4) Run the model on the data
for j, batch in enumerate (auto.tqdm( dl)):

# current_bs = batch["image"].shape[0]
# img = batch
# vals, visdict = test(emoca, img)
# for i in range(current_bs):
# # name = f"{(j*batch_size + i):05d}"
# name = batch["image_name"][i]
current_bs = batch["image"].shape[0]
img = batch
vals, visdict = test(emoca, img)
for i in range(current_bs):
# name = f"{(j*batch_size + i):05d}"
name = batch["image_name"][i]

# sample_output_folder = Path(outfolder) /name
# sample_output_folder.mkdir(parents=True, exist_ok=True)
sample_output_folder = Path(outfolder) /name
sample_output_folder.mkdir(parents=True, exist_ok=True)

# if args.save_mesh:
# save_obj(emoca, str(sample_output_folder / "mesh_coarse.obj"), vals, i)
# if args.save_images:
# save_images(outfolder, name, visdict, i)
# if args.save_codes:
# save_codes(Path(outfolder), name, vals, i)
if args.save_mesh:
save_obj(emoca, str(sample_output_folder / "mesh_coarse.obj"), vals, i)
if args.save_images:
save_images(outfolder, name, visdict, i)
if args.save_codes:
save_codes(Path(outfolder), name, vals, i)

# ## 5) Create the reconstruction video (reconstructions overlayed on the original video)
# dm.create_reconstruction_video(0, rec_method=model_name, image_type=image_type, overwrite=True)
# print("Done")
## 5) Create the reconstruction video (reconstructions overlayed on the original video)
dm.create_reconstruction_video(0, rec_method=model_name, image_type=image_type, overwrite=True)
print("Done")


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

from gdl.models.DECA import DECA, ExpDECA, DecaModule
from gdl.models.IO import locate_checkpoint
from gdl.models.EmoNetModule import EmoNetModule
try:
from gdl.models.EmoNetModule import EmoNetModule
except ImportError as e:
print("Skipping EmoNetModule because EmoNet it is not installed")
from gdl.utils.other import class_from_str
import datetime
from pytorch_lightning import Trainer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from gdl.models.IO import locate_checkpoint, get_checkpoint_with_kwargs

from gdl.models.EmoDECA import EmoDECA
from gdl.models.EmoNetModule import EmoNetModule
try:
from gdl.models.EmoNetModule import EmoNetModule
except ImportError as e:
print("Skipping EmoNetModule because EmoNet it is not installed")
from gdl.utils.other import class_from_str
import datetime
from pytorch_lightning import Trainer
Expand Down
10 changes: 8 additions & 2 deletions gdl_apps/EmotionRecognition/training/train_emodeca.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@
from gdl.models.IO import locate_checkpoint, get_checkpoint_with_kwargs

from gdl.models.EmoDECA import EmoDECA
from gdl.models.EmoSwinModule import EmoSwinModule
try:
from gdl.models.EmoSwinModule import EmoSwinModule
except ImportError as e:
print(f"Could not import EmoSwinModule. SWIN models will not be available")
from gdl.models.EmoCnnModule import EmoCnnModule
from gdl.models.EmoNetModule import EmoNetModule
try:
from gdl.models.EmoNetModule import EmoNetModule
except ImportError as e:
print(f"Could not import EmoNet. EmoNet models will not be available")
from gdl.models.EmoMLP import EmoMLP

from gdl.utils.other import class_from_str
Expand Down
10 changes: 8 additions & 2 deletions gdl_apps/EmotionRecognition/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from gdl.models.EmoDECA import EmoDECA
from gdl.models.EmoCnnModule import EmoCnnModule
from gdl.models.EmoSwinModule import EmoSwinModule
from gdl.models.EmoNetModule import EmoNetModule
try:
from gdl.models.EmoSwinModule import EmoSwinModule
except ImportError as e:
print(f"Could not import EmoSwinModule. SWIN models will not be available")
try:
from gdl.models.EmoNetModule import EmoNetModule
except ImportError as e:
print(f"Could not import EmoNetModule. EmoNet models will not be available")
from gdl.models.IO import locate_checkpoint
from gdl.utils.other import class_from_str
from pathlib import Path
Expand Down

0 comments on commit b695dcb

Please sign in to comment.