Merge pull request ai-forever#3 from sberbank-ai/feature/ruclip
ruclip

shonenkov authored Nov 2, 2021
2 parents 6348c77 + 69a8ee4 commit 3987b78
Showing 6 changed files with 181 additions and 5 deletions.
9 changes: 5 additions & 4 deletions requirements.txt
@@ -1,8 +1,9 @@
 taming-transformers==0.0.1
-more_itertools==8.10.0
-transformers==4.10.2
-youtokentome==1.0.6
-einops==0.3.2
+more_itertools~=8.10.0
+transformers~=4.10.2
+youtokentome~=1.0.6
+omegaconf>=2.0.0
+einops~=0.3.2
 torch
 torchvision
 matplotlib
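The switch from == to ~= loosens the pins to "compatible release" ranges. A quick way to sanity-check what such a spec admits is the packaging library; this snippet is illustrative and not part of the diff:

from packaging.specifiers import SpecifierSet

spec = SpecifierSet('~=4.10.2')   # per PEP 440, equivalent to >=4.10.2, ==4.10.*
print('4.10.5' in spec)           # True  - patch upgrades are allowed
print('4.11.0' in spec)           # False - minor bumps are not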
7 changes: 6 additions & 1 deletion rudalle/__init__.py
@@ -3,17 +3,22 @@
from .dalle import get_rudalle_model
from .tokenizer import get_tokenizer
from .realesrgan import get_realesrgan
from . import vae, dalle, tokenizer, realesrgan, pipelines
from .ruclip import get_ruclip
from . import vae, dalle, tokenizer, realesrgan, pipelines, ruclip


__all__ = [
    'get_vae',
    'get_rudalle_model',
    'get_tokenizer',
    'get_realesrgan',
    'get_ruclip',
    'vae',
    'dalle',
    'ruclip',
    'tokenizer',
    'realesrgan',
    'pipelines',
]

__version__ = '0.0.1-rc1'
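With get_ruclip and the ruclip submodule exported here, the model can be loaded straight from the package top level. A minimal sketch (model name and default cache_dir taken from rudalle/ruclip/__init__.py below):

from rudalle import get_ruclip

ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5', cache_dir='/tmp/rudalle')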
16 changes: 16 additions & 0 deletions rudalle/pipelines.py
@@ -58,6 +58,22 @@ def super_resolution(pil_images, realesrgan):
    return result


def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu', count=4):
    with torch.no_grad():
        inputs = ruclip_processor(text=text, images=pil_images)
        for key in inputs.keys():
            inputs[key] = inputs[key].to(device)
        outputs = ruclip(**inputs)
        sims = outputs.logits_per_image.view(-1).softmax(dim=0)
    items = []
    for index, sim in enumerate(sims.cpu().numpy()):
        items.append({'img_index': index, 'cosine': sim})
    items = sorted(items, key=lambda x: x['cosine'], reverse=True)[:count]
    top_pil_images = [pil_images[x['img_index']] for x in items]
    top_scores = [x['cosine'] for x in items]
    return top_pil_images, top_scores


def show(pil_images, nrow=4):
    imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
    if not isinstance(imgs, list):
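A hedged end-to-end sketch of the new re-ranking step: pil_images stands in for candidate images produced by an earlier generation stage (not shown in this diff), and the caption text is a placeholder.

import torch
from rudalle import get_ruclip
from rudalle.pipelines import cherry_pick_by_clip

ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ruclip = ruclip.to(device)

# pil_images: list of PIL.Image candidates from a previous generation step (assumed)
top_images, top_scores = cherry_pick_by_clip(
    pil_images, text='a rainbow over a night city', ruclip=ruclip,
    ruclip_processor=ruclip_processor, device=device, count=4)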
30 changes: 30 additions & 0 deletions rudalle/ruclip/__init__.py
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
import os

from transformers import CLIPModel
from huggingface_hub import hf_hub_url, cached_download

from .processor import RuCLIPProcessor

MODELS = {
    'ruclip-vit-base-patch32-v5': dict(
        repo_id='sberbank-ai/ru-clip',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
}


def get_ruclip(name, cache_dir='/tmp/rudalle'):
    assert name in MODELS
    config = MODELS[name]
    repo_id = config['repo_id']
    cache_dir = os.path.join(cache_dir, name)
    for filename in config['filenames']:
        config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{name}/{filename}')
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
    ruclip = CLIPModel.from_pretrained(cache_dir)
    ruclip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
    print('ruclip --> ready')
    return ruclip, ruclip_processor
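For orientation, the loader above ends up with a local layout like the following under the default cache_dir (paths follow from the defaults and the force_filename calls; shown purely as an illustration):

/tmp/rudalle/ruclip-vit-base-patch32-v5/
    bpe.model          # used by RuCLIPProcessor (youtokentome BPE tokenizer)
    config.json        # read by CLIPModel and by RuCLIPProcessor.from_pretrained
    pytorch_model.bin  # weights consumed by CLIPModel.from_pretrained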
68 changes: 68 additions & 0 deletions rudalle/ruclip/processor.py
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
import os
import json
import torch
import youtokentome as yttm
import torchvision.transforms as T
from torch.nn.utils.rnn import pad_sequence


class RuCLIPProcessor:
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer_path, image_size=224, text_seq_length=76, mean=None, std=None):

        self.tokenizer = yttm.BPE(tokenizer_path)
        self.mean = mean or [0.485, 0.456, 0.406]
        self.std = std or [0.229, 0.224, 0.225]
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])
        self.text_seq_length = text_seq_length
        self.image_size = image_size

    def encode_text(self, text):
        text = text.lower()
        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
        tokens = [self.bos_id] + tokens + [self.eos_id]
        tokens = tokens[:self.text_seq_length]
        mask = [1] * len(tokens)
        return torch.tensor(tokens).long(), torch.tensor(mask).long()

    def decode_text(self, encoded):
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
            self.eos_id, self.bos_id, self.unk_id, self.pad_id
        ])[0]

    def __call__(self, text=None, images=None, **kwargs):
        inputs = {}
        if text is not None:
            input_ids, masks = [], []
            texts = [text] if isinstance(text, str) else text
            for text in texts:
                tokens, mask = self.encode_text(text)
                input_ids.append(tokens)
                masks.append(mask)
            inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
            inputs['attention_mask'] = pad_sequence(masks, batch_first=True)
        if images is not None:
            pixel_values = []
            for i, image in enumerate(images):
                pixel_values.append(self.image_transform(image))
            inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
        return inputs

    @classmethod
    def from_pretrained(cls, folder):
        tokenizer_path = os.path.join(folder, 'bpe.model')
        config = json.load(open(os.path.join(folder, 'config.json')))
        image_size = config['vision_config']['image_size']
        text_seq_length = config['text_config']['max_position_embeddings'] - 1
        mean, std = config.get('mean'), config.get('std')
        return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)
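A short sketch of direct processor usage, based on the __call__ and from_pretrained definitions above (the image path and texts are placeholders):

from PIL import Image
from rudalle.ruclip.processor import RuCLIPProcessor

processor = RuCLIPProcessor.from_pretrained('/tmp/rudalle/ruclip-vit-base-patch32-v5')
inputs = processor(text=['a cat', 'a dog'], images=[Image.open('example.jpg')])
# inputs['input_ids']      -> LongTensor of shape [2, seq_len], seq_len <= text_seq_length
# inputs['attention_mask'] -> LongTensor of shape [2, seq_len]
# inputs['pixel_values']   -> FloatTensor of shape [1, 3, image_size, image_size]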
56 changes: 56 additions & 0 deletions setup.py
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
import os
import re
from setuptools import setup


def read(filename):
    with open(os.path.join(os.path.dirname(__file__), filename)) as f:
        file_content = f.read()
    return file_content


def get_requirements():
    requirements = []
    for requirement in read('requirements.txt').splitlines():
        if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+'):
            parsed_requires = re.findall(r'#egg=([\w\d\.]+)-([\d\.]+)$', requirement)
            if parsed_requires:
                package, version = parsed_requires[0]
                requirements.append(f'{package}=={version}')
            else:
                print('WARNING! For correct matching dependency links need to specify package name and version '
                      'such as <dependency url>#egg=<package_name>-<version>')
        else:
            requirements.append(requirement)
    return requirements


def get_links():
    return [
        requirement for requirement in read('requirements.txt').splitlines()
        if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+')
    ]


def get_version():
    """ Get version from the package without actually importing it. """
    init = read('rudalle/__init__.py')
    for line in init.split('\n'):
        if line.startswith('__version__'):
            return eval(line.split('=')[1])


setup(
    name='rudalle',
    version=get_version(),
    author='SberAI, SberDevices',
    author_email='',
    description='',
    packages=['rudalle', 'rudalle/dalle', 'rudalle/realesrgan', 'rudalle/ruclip', 'rudalle/vae'],
    package_data={'rudalle/vae': ['*.yml']},
    install_requires=get_requirements(),
    dependency_links=get_links(),
    long_description=read('README.md'),
    long_description_content_type='text/markdown',
)
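With setup.py in place, the package can be installed from a checkout in the usual way (commands are illustrative, not part of the diff):

pip install .      # regular install; dependencies come from requirements.txt via get_requirements()
pip install -e .   # editable install for development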
