From 072c726f2e0eec5fbe43f864a818753fe1b800ce Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Wed, 4 Sep 2024 08:53:49 +0400
Subject: [PATCH] refactor dynamicrafter (#2358)

---
 .../dynamicrafter-animating-images.ipynb | 98 +++++++++++++------
 1 file changed, 69 insertions(+), 29 deletions(-)

diff --git a/notebooks/dynamicrafter-animating-images/dynamicrafter-animating-images.ipynb b/notebooks/dynamicrafter-animating-images/dynamicrafter-animating-images.ipynb
index 22cfdc5abb7..3af46e4b15e 100644
--- a/notebooks/dynamicrafter-animating-images/dynamicrafter-animating-images.ipynb
+++ b/notebooks/dynamicrafter-animating-images/dynamicrafter-animating-images.ipynb
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "a30812de-c46e-44a3-8194-b7f6f0fd4707",
    "metadata": {},
@@ -63,6 +64,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "f9dc9580-da81-47dd-b5d3-3cafa8f5a4b5",
    "metadata": {},
@@ -112,6 +114,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "a3c0c659-aad3-4962-8db7-7b123379f01a",
    "metadata": {},
@@ -205,6 +208,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "be9643c8-a70c-4dba-8259-d4467ae82949",
    "metadata": {},
@@ -245,6 +249,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "3c63518d-957d-4358-8711-cf6fb935d8be",
    "metadata": {},
@@ -259,6 +264,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "d9824415cd5b0ffd",
    "metadata": {},
@@ -276,8 +282,9 @@
    "source": [
     "from dynamicrafter.lvdm.modules.encoders.condition import FrozenOpenCLIPEmbedder\n",
     "\n",
+    "MODEL_DIR = Path(\"models\")\n",
     "\n",
-    "COND_STAGE_MODEL_OV_PATH = Path(\"models/cond_stage_model.xml\")\n",
+    "COND_STAGE_MODEL_OV_PATH = MODEL_DIR / \"cond_stage_model.xml\"\n",
     "\n",
     "\n",
     "class FrozenOpenCLIPEmbedderWrapper(FrozenOpenCLIPEmbedder):\n",
@@ -300,6 +307,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "63b361937c948711",
    "metadata": {},
@@ -316,7 +324,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "EMBEDDER_OV_PATH = Path(\"models/embedder_ir.xml\")\n",
+    "EMBEDDER_OV_PATH = MODEL_DIR / \"embedder_ir.xml\"\n",
     "\n",
     "\n",
     "dummy_input = torch.rand([1, 3, 767, 767], dtype=torch.float32)\n",
@@ -331,6 +339,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "eef65d17fec62fa",
    "metadata": {},
@@ -346,7 +355,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ENCODER_FIRST_STAGE_OV_PATH = Path(\"models/encoder_first_stage_ir.xml\")\n",
+    "ENCODER_FIRST_STAGE_OV_PATH = MODEL_DIR / \"encoder_first_stage_ir.xml\"\n",
     "\n",
     "\n",
     "dummy_input = torch.rand([1, 3, 256, 256], dtype=torch.float32)\n",
@@ -363,6 +372,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "7ec5ee02317d8e77",
    "metadata": {},
@@ -378,7 +388,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "MODEL_OV_PATH = Path(\"models/model_ir.xml\")\n",
+    "MODEL_OV_PATH = MODEL_DIR / \"model_ir.xml\"\n",
     "\n",
     "\n",
     "class ModelWrapper(torch.nn.Module):\n",
@@ -410,6 +420,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "8d5430af-12b4-4a15-bb7c-c9300f824431",
    "metadata": {},
@@ -446,7 +457,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "DECODER_FIRST_STAGE_OV_PATH = Path(\"models/decoder_first_stage_ir.xml\")\n",
+    "DECODER_FIRST_STAGE_OV_PATH = MODEL_DIR / \"decoder_first_stage_ir.xml\"\n",
     "\n",
     "\n",
     "dummy_input = torch.rand([16, 4, 32, 32], dtype=torch.float32)\n",
@@ -463,6 +474,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "51ff6eb8-dd58-4820-aae3-85c0b4e487a8",
    "metadata": {},
@@ -511,14 +523,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "compiled_cond_stage_model = core.compile_model(core.read_model(COND_STAGE_MODEL_OV_PATH), device.value)\n",
-    "compiled_encode_first_stage = core.compile_model(core.read_model(ENCODER_FIRST_STAGE_OV_PATH), device.value)\n",
-    "compiled_embedder = core.compile_model(core.read_model(EMBEDDER_OV_PATH), device.value)\n",
-    "compiled_model = core.compile_model(core.read_model(MODEL_OV_PATH), device.value)\n",
-    "compiled_decoder_first_stage = core.compile_model(core.read_model(DECODER_FIRST_STAGE_OV_PATH), device.value)"
+    "cond_stage_model = core.read_model(COND_STAGE_MODEL_OV_PATH)\n",
+    "encoder_first_stage = core.read_model(ENCODER_FIRST_STAGE_OV_PATH)\n",
+    "embedder = core.read_model(EMBEDDER_OV_PATH)\n",
+    "model_ov = core.read_model(MODEL_OV_PATH)\n",
+    "decoder_first_stage = core.read_model(DECODER_FIRST_STAGE_OV_PATH)\n",
+    "\n",
+    "compiled_cond_stage_model = core.compile_model(cond_stage_model, device.value)\n",
+    "compiled_encode_first_stage = core.compile_model(encoder_first_stage, device.value)\n",
+    "compiled_embedder = core.compile_model(embedder, device.value)\n",
+    "compiled_model = core.compile_model(model_ov, device.value)\n",
+    "compiled_decoder_first_stage = core.compile_model(decoder_first_stage, device.value)"
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "11f2c95b-e872-458b-a6f8-448f8124ffe6",
    "metadata": {},
@@ -536,12 +555,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "from typing import Any\n",
     "import open_clip\n",
     "\n",
     "\n",
-    "class CondStageModelWrapper(torch.nn.Module):\n",
+    "class CondStageModelWrapper:\n",
     "    def __init__(self, cond_stage_model):\n",
-    "        super().__init__()\n",
     "        self.cond_stage_model = cond_stage_model\n",
     "\n",
     "    def encode(self, tokens):\n",
@@ -552,9 +571,8 @@
     "        return torch.from_numpy(outs)\n",
     "\n",
     "\n",
-    "class EncoderFirstStageModelWrapper(torch.nn.Module):\n",
+    "class EncoderFirstStageModelWrapper:\n",
     "    def __init__(self, encode_first_stage):\n",
-    "        super().__init__()\n",
     "        self.encode_first_stage = encode_first_stage\n",
     "\n",
     "    def forward(self, x):\n",
@@ -562,10 +580,12 @@
     "\n",
     "        return torch.from_numpy(outs)\n",
     "\n",
+    "    def __call__(self, *args: Any, **kwargs: Any) -> Any:\n",
+    "        return self.forward(*args, **kwargs)\n",
     "\n",
-    "class EmbedderWrapper(torch.nn.Module):\n",
+    "\n",
+    "class EmbedderWrapper:\n",
     "    def __init__(self, embedder):\n",
-    "        super().__init__()\n",
     "        self.embedder = embedder\n",
     "\n",
     "    def forward(self, x):\n",
@@ -573,10 +593,12 @@
     "\n",
     "        return torch.from_numpy(outs)\n",
     "\n",
+    "    def __call__(self, *args: Any, **kwargs: Any) -> Any:\n",
+    "        return self.forward(*args, **kwargs)\n",
     "\n",
-    "class CModelWrapper(torch.nn.Module):\n",
+    "\n",
+    "class CModelWrapper:\n",
     "    def __init__(self, diffusion_model, out_channels):\n",
-    "        super().__init__()\n",
     "        self.diffusion_model = diffusion_model\n",
     "        self.out_channels = out_channels\n",
     "\n",
@@ -591,20 +613,26 @@
     "\n",
     "        return torch.from_numpy(outs)\n",
     "\n",
+    "    def __call__(self, *args: Any, **kwargs: Any) -> Any:\n",
+    "        return self.forward(*args, **kwargs)\n",
+    "\n",
     "\n",
-    "class DecoderFirstStageModelWrapper(torch.nn.Module):\n",
+    "class DecoderFirstStageModelWrapper:\n",
    "    def __init__(self, decoder_first_stage):\n",
-    "        super().__init__()\n",
     "        self.decoder_first_stage = decoder_first_stage\n",
     "\n",
     "    def forward(self, x):\n",
     "        x.float()\n",
     "        outs = self.decoder_first_stage(x)[0]\n",
     "\n",
-    "        return torch.from_numpy(outs)"
+    "        return torch.from_numpy(outs)\n",
+    "\n",
+    "    def __call__(self, *args: Any, **kwargs: Any) -> Any:\n",
+    "        return self.forward(*args, **kwargs)"
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "1178a847-eb14-419b-815e-c47628aa6868",
    "metadata": {},
@@ -627,6 +655,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "953f6d14",
    "metadata": {},
@@ -902,6 +931,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "0e8548d7",
    "metadata": {},
@@ -949,6 +979,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "03c3dfb5",
    "metadata": {},
@@ -972,11 +1003,17 @@
     "open(\"skip_kernel_extension.py\", \"w\").write(r.text)\n",
     "\n",
     "int8_model = None\n",
+    "MODEL_INT8_OV_PATH = MODEL_DIR / \"model_ir_int8.xml\"\n",
+    "COND_STAGE_MODEL_INT8_OV_PATH = MODEL_DIR / \"cond_stage_model_int8.xml\"\n",
+    "DECODER_FIRST_STAGE_INT8_OV_PATH = MODEL_DIR / \"decoder_first_stage_ir_int8.xml\"\n",
+    "ENCODER_FIRST_STAGE_INT8_OV_PATH = MODEL_DIR / \"encoder_first_stage_ir_int8.xml\"\n",
+    "EMBEDDER_INT8_OV_PATH = MODEL_DIR / \"embedder_ir_int8.xml\"\n",
     "\n",
     "%load_ext skip_kernel_extension"
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "2c5ff698",
    "metadata": {},
@@ -1011,6 +1048,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "fb54119d",
    "metadata": {},
@@ -1107,13 +1145,14 @@
    "source": [
     "%%skip not $to_quantize.value\n",
     "\n",
-    "MODEL_INT8_OV_PATH = Path(\"models/model_ir_int8.xml\")\n",
+    "\n",
     "if not MODEL_INT8_OV_PATH.exists():\n",
     "    subset_size = 300\n",
     "    calibration_data = collect_calibration_data(model, subset_size=subset_size)"
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "12c23abd",
    "metadata": {},
@@ -1270,8 +1309,7 @@
     "\n",
     "\n",
     "if MODEL_INT8_OV_PATH.exists():\n",
-    "    print(\"Loading quantized model\")\n",
-    "    quantized_model = core.read_model(MODEL_INT8_OV_PATH)\n",
+    "    print(\"Model already quantized\")\n",
     "else:\n",
     "    ov_model_ir = core.read_model(MODEL_OV_PATH)\n",
     "    quantized_model = nncf.quantize(\n",
@@ -1289,6 +1327,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "011835c3",
    "metadata": {},
@@ -1505,11 +1544,6 @@
    "source": [
     "%%skip not $to_quantize.value\n",
     "\n",
-    "COND_STAGE_MODEL_INT8_OV_PATH = Path(\"models/cond_stage_model_int8.xml\")\n",
-    "DECODER_FIRST_STAGE_INT8_OV_PATH = Path(\"models/decoder_first_stage_ir_int8.xml\")\n",
-    "ENCODER_FIRST_STAGE_INT8_OV_PATH = Path(\"models/encoder_first_stage_ir_int8.xml\")\n",
-    "EMBEDDER_INT8_OV_PATH = Path(\"models/embedder_ir_int8.xml\")\n",
-    "\n",
     "def compress_model_weights(fp_model_path, int8_model_path):\n",
     "    if not int8_model_path.exists():\n",
     "        model = core.read_model(fp_model_path)\n",
@@ -1524,6 +1558,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "6c8d7527",
    "metadata": {},
@@ -1646,6 +1681,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "81410f5a",
    "metadata": {},
@@ -1685,6 +1721,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "fa624ee9",
    "metadata": {},
@@ -1757,6 +1794,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "4417db2b-2f65-407c-a384-ec466f18bca0",
    "metadata": {},
@@ -1774,6 +1812,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "from ipywidgets import widgets\n",
+    "\n",
    "quantized_models_present = int8_model is not None\n",
     "\n",
     "use_quantized_models = widgets.Checkbox(\n",
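
For context on the change itself: the refactor replaces the torch.nn.Module-based pipeline wrappers with plain Python classes that route __call__ through forward, and it splits core.read_model from core.compile_model so the raw and compiled OpenVINO models are kept separately. The sketch below illustrates that wrapper pattern only; it is not part of the patch. It assumes the converted IR at models/embedder_ir.xml from the hunks above and a fixed "CPU" target instead of the notebook's device.value widget.

import torch
import openvino as ov

core = ov.Core()

# Read the converted IR once, then compile it for the target device
# ("CPU" is assumed here; the notebook selects the device via a widget).
embedder = core.read_model("models/embedder_ir.xml")
compiled_embedder = core.compile_model(embedder, "CPU")


class EmbedderWrapper:
    # Plain-Python stand-in for the former torch.nn.Module wrapper.
    def __init__(self, embedder):
        self.embedder = embedder

    def forward(self, x):
        # The compiled OpenVINO model returns numpy arrays; convert back to
        # torch so downstream pipeline code keeps receiving tensors.
        outs = self.embedder(x)[0]
        return torch.from_numpy(outs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


wrapped_embedder = EmbedderWrapper(compiled_embedder)
# Mirrors the tracing input shape used when the embedder was converted.
dummy_input = torch.rand([1, 3, 767, 767], dtype=torch.float32)
features = wrapped_embedder(dummy_input)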