refactor dynamicrafter (#2358)
eaidova authored Sep 4, 2024
1 parent a85afef commit 072c726
Showing 1 changed file with 69 additions and 29 deletions.
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "a30812de-c46e-44a3-8194-b7f6f0fd4707",
"metadata": {},
@@ -63,6 +64,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f9dc9580-da81-47dd-b5d3-3cafa8f5a4b5",
"metadata": {},
@@ -112,6 +114,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a3c0c659-aad3-4962-8db7-7b123379f01a",
"metadata": {},
@@ -205,6 +208,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "be9643c8-a70c-4dba-8259-d4467ae82949",
"metadata": {},
@@ -245,6 +249,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3c63518d-957d-4358-8711-cf6fb935d8be",
"metadata": {},
@@ -259,6 +264,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "d9824415cd5b0ffd",
"metadata": {},
@@ -276,8 +282,9 @@
"source": [
"from dynamicrafter.lvdm.modules.encoders.condition import FrozenOpenCLIPEmbedder\n",
"\n",
"MODEL_DIR = Path(\"models\")\n",
"\n",
"COND_STAGE_MODEL_OV_PATH = Path(\"models/cond_stage_model.xml\")\n",
"COND_STAGE_MODEL_OV_PATH = MODEL_DIR / \"cond_stage_model.xml\"\n",
"\n",
"\n",
"class FrozenOpenCLIPEmbedderWrapper(FrozenOpenCLIPEmbedder):\n",
@@ -300,6 +307,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "63b361937c948711",
"metadata": {},
@@ -316,7 +324,7 @@
"metadata": {},
"outputs": [],
"source": [
"EMBEDDER_OV_PATH = Path(\"models/embedder_ir.xml\")\n",
"EMBEDDER_OV_PATH = MODEL_DIR / \"embedder_ir.xml\"\n",
"\n",
"\n",
"dummy_input = torch.rand([1, 3, 767, 767], dtype=torch.float32)\n",
@@ -331,6 +339,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "eef65d17fec62fa",
"metadata": {},
@@ -346,7 +355,7 @@
"metadata": {},
"outputs": [],
"source": [
"ENCODER_FIRST_STAGE_OV_PATH = Path(\"models/encoder_first_stage_ir.xml\")\n",
"ENCODER_FIRST_STAGE_OV_PATH = MODEL_DIR / \"encoder_first_stage_ir.xml\"\n",
"\n",
"\n",
"dummy_input = torch.rand([1, 3, 256, 256], dtype=torch.float32)\n",
@@ -363,6 +372,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7ec5ee02317d8e77",
"metadata": {},
@@ -378,7 +388,7 @@
"metadata": {},
"outputs": [],
"source": [
"MODEL_OV_PATH = Path(\"models/model_ir.xml\")\n",
"MODEL_OV_PATH = MODEL_DIR / \"model_ir.xml\"\n",
"\n",
"\n",
"class ModelWrapper(torch.nn.Module):\n",
@@ -410,6 +420,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "8d5430af-12b4-4a15-bb7c-c9300f824431",
"metadata": {},
@@ -446,7 +457,7 @@
"metadata": {},
"outputs": [],
"source": [
"DECODER_FIRST_STAGE_OV_PATH = Path(\"models/decoder_first_stage_ir.xml\")\n",
"DECODER_FIRST_STAGE_OV_PATH = MODEL_DIR / \"decoder_first_stage_ir.xml\"\n",
"\n",
"\n",
"dummy_input = torch.rand([16, 4, 32, 32], dtype=torch.float32)\n",
@@ -463,6 +474,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "51ff6eb8-dd58-4820-aae3-85c0b4e487a8",
"metadata": {},
@@ -511,14 +523,21 @@
"metadata": {},
"outputs": [],
"source": [
"compiled_cond_stage_model = core.compile_model(core.read_model(COND_STAGE_MODEL_OV_PATH), device.value)\n",
"compiled_encode_first_stage = core.compile_model(core.read_model(ENCODER_FIRST_STAGE_OV_PATH), device.value)\n",
"compiled_embedder = core.compile_model(core.read_model(EMBEDDER_OV_PATH), device.value)\n",
"compiled_model = core.compile_model(core.read_model(MODEL_OV_PATH), device.value)\n",
"compiled_decoder_first_stage = core.compile_model(core.read_model(DECODER_FIRST_STAGE_OV_PATH), device.value)"
"cond_stage_model = core.read_model(COND_STAGE_MODEL_OV_PATH)\n",
"encoder_first_stage = core.read_model(ENCODER_FIRST_STAGE_OV_PATH)\n",
"embedder = core.read_model(EMBEDDER_OV_PATH)\n",
"model_ov = core.read_model(MODEL_OV_PATH)\n",
"decoder_first_stage = core.read_model(DECODER_FIRST_STAGE_OV_PATH)\n",
"\n",
"compiled_cond_stage_model = core.compile_model(cond_stage_model, device.value)\n",
"compiled_encode_first_stage = core.compile_model(encoder_first_stage, device.value)\n",
"compiled_embedder = core.compile_model(embedder, device.value)\n",
"compiled_model = core.compile_model(model_ov, device.value)\n",
"compiled_decoder_first_stage = core.compile_model(decoder_first_stage, device.value)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "11f2c95b-e872-458b-a6f8-448f8124ffe6",
"metadata": {},
@@ -536,12 +555,12 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Any\n",
"import open_clip\n",
"\n",
"\n",
"class CondStageModelWrapper(torch.nn.Module):\n",
"class CondStageModelWrapper:\n",
" def __init__(self, cond_stage_model):\n",
" super().__init__()\n",
" self.cond_stage_model = cond_stage_model\n",
"\n",
" def encode(self, tokens):\n",
@@ -552,31 +571,34 @@
" return torch.from_numpy(outs)\n",
"\n",
"\n",
"class EncoderFirstStageModelWrapper(torch.nn.Module):\n",
"class EncoderFirstStageModelWrapper:\n",
" def __init__(self, encode_first_stage):\n",
" super().__init__()\n",
" self.encode_first_stage = encode_first_stage\n",
"\n",
" def forward(self, x):\n",
" outs = self.encode_first_stage(x)[0]\n",
"\n",
" return torch.from_numpy(outs)\n",
"\n",
" def __call__(self, *args: Any, **kwargs: Any) -> Any:\n",
" return self.forward(*args, **kwargs)\n",
"\n",
"class EmbedderWrapper(torch.nn.Module):\n",
"\n",
"class EmbedderWrapper:\n",
" def __init__(self, embedder):\n",
" super().__init__()\n",
" self.embedder = embedder\n",
"\n",
" def forward(self, x):\n",
" outs = self.embedder(x)[0]\n",
"\n",
" return torch.from_numpy(outs)\n",
"\n",
" def __call__(self, *args: Any, **kwargs: Any) -> Any:\n",
" return self.forward(*args, **kwargs)\n",
"\n",
"class CModelWrapper(torch.nn.Module):\n",
"\n",
"class CModelWrapper:\n",
" def __init__(self, diffusion_model, out_channels):\n",
" super().__init__()\n",
" self.diffusion_model = diffusion_model\n",
" self.out_channels = out_channels\n",
"\n",
@@ -591,20 +613,26 @@
"\n",
" return torch.from_numpy(outs)\n",
"\n",
" def __call__(self, *args: Any, **kwargs: Any) -> Any:\n",
" return self.forward(*args, **kwargs)\n",
"\n",
"\n",
"class DecoderFirstStageModelWrapper(torch.nn.Module):\n",
"class DecoderFirstStageModelWrapper:\n",
" def __init__(self, decoder_first_stage):\n",
" super().__init__()\n",
" self.decoder_first_stage = decoder_first_stage\n",
"\n",
" def forward(self, x):\n",
" x.float()\n",
" outs = self.decoder_first_stage(x)[0]\n",
"\n",
" return torch.from_numpy(outs)"
" return torch.from_numpy(outs)\n",
"\n",
" def __call__(self, *args: Any, **kwargs: Any) -> Any:\n",
" return self.forward(*args, **kwargs)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "1178a847-eb14-419b-815e-c47628aa6868",
"metadata": {},
@@ -627,6 +655,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "953f6d14",
"metadata": {},
@@ -902,6 +931,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "0e8548d7",
"metadata": {},
@@ -949,6 +979,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "03c3dfb5",
"metadata": {},
@@ -972,11 +1003,17 @@
"open(\"skip_kernel_extension.py\", \"w\").write(r.text)\n",
"\n",
"int8_model = None\n",
"MODEL_INT8_OV_PATH = MODEL_DIR / \"model_ir_int8.xml\"\n",
"COND_STAGE_MODEL_INT8_OV_PATH = MODEL_DIR / \"cond_stage_model_int8.xml\"\n",
"DECODER_FIRST_STAGE_INT8_OV_PATH = MODEL_DIR / \"decoder_first_stage_ir_int8.xml\"\n",
"ENCODER_FIRST_STAGE_INT8_OV_PATH = MODEL_DIR / \"encoder_first_stage_ir_int8.xml\"\n",
"EMBEDDER_INT8_OV_PATH = MODEL_DIR / \"embedder_ir_int8.xml\"\n",
"\n",
"%load_ext skip_kernel_extension"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2c5ff698",
"metadata": {},
@@ -1011,6 +1048,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "fb54119d",
"metadata": {},
@@ -1107,13 +1145,14 @@
"source": [
"%%skip not $to_quantize.value\n",
"\n",
"MODEL_INT8_OV_PATH = Path(\"models/model_ir_int8.xml\")\n",
"\n",
"if not MODEL_INT8_OV_PATH.exists():\n",
" subset_size = 300\n",
" calibration_data = collect_calibration_data(model, subset_size=subset_size)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "12c23abd",
"metadata": {},
@@ -1270,8 +1309,7 @@
"\n",
"\n",
"if MODEL_INT8_OV_PATH.exists():\n",
" print(\"Loading quantized model\")\n",
" quantized_model = core.read_model(MODEL_INT8_OV_PATH)\n",
" print(\"Model already quantized\")\n",
"else:\n",
" ov_model_ir = core.read_model(MODEL_OV_PATH)\n",
" quantized_model = nncf.quantize(\n",
@@ -1289,6 +1327,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "011835c3",
"metadata": {},
@@ -1505,11 +1544,6 @@
"source": [
"%%skip not $to_quantize.value\n",
"\n",
"COND_STAGE_MODEL_INT8_OV_PATH = Path(\"models/cond_stage_model_int8.xml\")\n",
"DECODER_FIRST_STAGE_INT8_OV_PATH = Path(\"models/decoder_first_stage_ir_int8.xml\")\n",
"ENCODER_FIRST_STAGE_INT8_OV_PATH = Path(\"models/encoder_first_stage_ir_int8.xml\")\n",
"EMBEDDER_INT8_OV_PATH = Path(\"models/embedder_ir_int8.xml\")\n",
"\n",
"def compress_model_weights(fp_model_path, int8_model_path):\n",
" if not int8_model_path.exists():\n",
" model = core.read_model(fp_model_path)\n",
@@ -1524,6 +1558,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6c8d7527",
"metadata": {},
@@ -1646,6 +1681,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "81410f5a",
"metadata": {},
@@ -1685,6 +1721,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "fa624ee9",
"metadata": {},
@@ -1757,6 +1794,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "4417db2b-2f65-407c-a384-ec466f18bca0",
"metadata": {},
@@ -1774,6 +1812,8 @@
"metadata": {},
"outputs": [],
"source": [
"from ipywidgets import widgets\n",
"\n",
"quantized_models_present = int8_model is not None\n",
"\n",
"use_quantized_models = widgets.Checkbox(\n",