add image_prompts (ai-forever#2)
* add image_prompts
* add image prompts jupyter
oriBetelgeuse authored Nov 2, 2021
1 parent cb886c5 commit 8f34865
Showing 5 changed files with 353 additions and 14 deletions.
253 changes: 253 additions & 0 deletions jupyters/ruDALLE-image-prompts-A100.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions rudalle/__init__.py
@@ -4,7 +4,7 @@
 from .tokenizer import get_tokenizer
 from .realesrgan import get_realesrgan
 from .ruclip import get_ruclip
-from . import vae, dalle, tokenizer, realesrgan, pipelines, ruclip
+from . import vae, dalle, tokenizer, realesrgan, pipelines, ruclip, image_prompts


 __all__ = [
@@ -19,6 +19,7 @@
     'tokenizer',
     'realesrgan',
     'pipelines',
+    'image_prompts',
 ]

-__version__ = '0.0.1-rc1'
+__version__ = '0.0.1-rc2'
53 changes: 53 additions & 0 deletions rudalle/image_prompts.py
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
import torch
import numpy as np


class ImagePrompts:

def __init__(self, pil_image, borders, vae, device='cpu', crop_first=False):
"""
Args:
pil_image (PIL.Image): image in PIL format
            borders (dict[str, int]): width of each border taken from pil_image as the prompt,
                example: {'up': 4, 'right': 0, 'left': 0, 'down': 0} (1 unit equals 8 pixels)
            vae (VQGanGumbelVAE): VQGAN model for image encoding
            device (str): cpu or cuda
            crop_first (bool): if True, crop the image to the 'up' border before VQGAN encoding
"""
self.device = device
img = self._preprocess_img(pil_image)
self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first)

def _preprocess_img(self, pil_img):
img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255.
img = img.unsqueeze(0).to(self.device, dtype=torch.float32)
img = (2 * img) - 1
return img

@staticmethod
def _get_image_prompts(img, borders, vae, crop_first):
if crop_first:
assert borders['right'] + borders['left'] + borders['down'] == 0
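            # the VQGAN encoder downsamples by a factor of 8, so one border unit spans 8 pixels of the input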
up_border = borders['up'] * 8
_, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :])
else:
_, _, [_, _, vqg_img] = vae.model.encode(img)

bs, vqg_img_w, vqg_img_h = vqg_img.shape
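        # mark the grid cells covered by the requested borders; these become fixed prompt positions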
mask = torch.zeros(vqg_img_w, vqg_img_h)
if borders['up'] != 0:
mask[:borders['up'], :] = 1.
if borders['down'] != 0:
mask[-borders['down']:, :] = 1.
if borders['right'] != 0:
mask[:, :borders['right']] = 1.
if borders['left'] != 0:
mask[:, -borders['left']:] = 1.
mask = mask.reshape(-1).bool()

image_prompts = vqg_img.reshape((bs, -1))
image_prompts_idx = np.arange(vqg_img_w * vqg_img_h)
image_prompts_idx = set(image_prompts_idx[mask])

return image_prompts_idx, image_prompts
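
For reference, a minimal sketch of constructing the new class (the get_vae loader and the file name are assumptions, not part of this diff):

from PIL import Image

from rudalle import get_vae
from rudalle.image_prompts import ImagePrompts

vae = get_vae()  # VQGanGumbelVAE: maps a 256x256 image to a 32x32 grid of codes

pil_img = Image.open('example.png').resize((256, 256))  # hypothetical input image
borders = {'up': 4, 'left': 0, 'right': 0, 'down': 0}  # keep the top 4 grid rows (32 pixels)
image_prompt = ImagePrompts(pil_img, borders, vae, device='cpu', crop_first=True)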
34 changes: 22 additions & 12 deletions rudalle/pipelines.py
@@ -10,8 +10,8 @@
 from . import utils


-def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, temperature=1.0, bs=8, seed=None,
-                    use_cache=True):
+def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image_prompts=None, temperature=1.0, bs=8,
+                    seed=None, use_cache=True):
     # TODO docstring
     if seed is not None:
         utils.seed_everything(seed)
@@ -32,16 +32,26 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, tempe
         out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device)
         has_cache = False
         sample_scores = []
-        for _ in tqdm(range(out.shape[1], total_seq_length)):
-            logits, has_cache = dalle(out, attention_mask,
-                                      has_cache=has_cache, use_cache=use_cache, return_loss=False)
-            logits = logits[:, -1, vocab_size:]
-            logits /= temperature
-            filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
-            probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
-            sample = torch.multinomial(probs, 1)
-            sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)])
-            out = torch.cat((out, sample), dim=-1)
+        if image_prompts is not None:
+            prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
+            prompts = prompts.repeat(images_num, 1)
+            if use_cache:
+                use_cache = False
+                print('Warning: use_cache changed to False')
+        for idx in tqdm(range(out.shape[1], total_seq_length)):
+            idx -= text_seq_length
+            if image_prompts is not None and idx in prompts_idx:
+                out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1)
+            else:
+                logits, has_cache = dalle(out, attention_mask,
+                                          has_cache=has_cache, use_cache=use_cache, return_loss=False)
+                logits = logits[:, -1, vocab_size:]
+                logits /= temperature
+                filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+                probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
+                sample = torch.multinomial(probs, 1)
+                sample_scores.append(probs[torch.arange(probs.size(0)), sample.transpose(0, 1)])
+                out = torch.cat((out, sample), dim=-1)
         codebooks = out[:, -image_seq_length:]
         images = vae.decode(codebooks)
         pil_images += utils.torch_tensors_to_pil_list(images)
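
In the new loop, grid positions listed in image_prompts_idx are copied from the prompt instead of being sampled, which is why the incremental cache has to be disabled. A sketch of an end-to-end call, assuming the standard rudalle loaders (get_rudalle_model, get_tokenizer, get_vae) and a placeholder image; none of these names are introduced by this diff:

from PIL import Image

from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.image_prompts import ImagePrompts
from rudalle.pipelines import generate_images

device = 'cuda'
dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)

pil_img = Image.open('sunset_border.png').resize((256, 256))  # hypothetical image
borders = {'up': 4, 'left': 0, 'right': 0, 'down': 0}
image_prompts = ImagePrompts(pil_img, borders, vae, device=device, crop_first=True)

pil_images, scores = generate_images(
    'красивый закат над морем',  # 'a beautiful sunset over the sea'
    tokenizer, dalle, vae, top_k=1024, top_p=0.99, images_num=4,
    image_prompts=image_prompts)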
22 changes: 22 additions & 0 deletions tests/test_image_prompts.py
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
import pytest

from rudalle.image_prompts import ImagePrompts


@pytest.mark.parametrize('borders, crop_first', [
({'up': 4, 'right': 0, 'left': 0, 'down': 0}, False),
({'up': 4, 'right': 0, 'left': 0, 'down': 0}, True),
({'up': 4, 'right': 3, 'left': 3, 'down': 3}, False)
])
def test_image_prompts(sample_image, vae, borders, crop_first):
img = sample_image.copy()
img = img.resize((256, 256))
image_prompt = ImagePrompts(img, borders, vae, crop_first=crop_first)
if crop_first:
assert image_prompt.image_prompts.shape[1] == borders['up'] * 32
assert len(image_prompt.image_prompts_idx) == borders['up'] * 32
else:
assert image_prompt.image_prompts.shape[1] == 32 * 32
assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \
+ (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down'])
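
The expected index counts follow from counting border cells on the 32x32 code grid without double-counting the corners: 'up' and 'down' cover full rows, while 'left' and 'right' only cover the rows in between. A quick check of that arithmetic (the helper is ours, not part of the diff):

def expected_cells(borders, grid=32):
    # full rows for up/down; left/right strips exclude those rows
    return (borders['up'] + borders['down']) * grid \
        + (borders['left'] + borders['right']) * (grid - borders['up'] - borders['down'])

assert expected_cells({'up': 4, 'right': 0, 'left': 0, 'down': 0}) == 4 * 32  # 128
assert expected_cells({'up': 4, 'right': 3, 'left': 3, 'down': 3}) == 7 * 32 + 6 * 25  # 374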
