diff --git a/requirements.txt b/requirements.txt
index 321b95d..c358382 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,9 @@
 taming-transformers==0.0.1
-more_itertools==8.10.0
-transformers==4.10.2
-youtokentome==1.0.6
-einops==0.3.2
+more_itertools~=8.10.0
+transformers~=4.10.2
+youtokentome~=1.0.6
+omegaconf>=2.0.0
+einops~=0.3.2
 torch
 torchvision
 matplotlib
\ No newline at end of file
diff --git a/rudalle/__init__.py b/rudalle/__init__.py
index b5698c4..4daea0d 100644
--- a/rudalle/__init__.py
+++ b/rudalle/__init__.py
@@ -3,7 +3,8 @@
 from .dalle import get_rudalle_model
 from .tokenizer import get_tokenizer
 from .realesrgan import get_realesrgan
-from . import vae, dalle, tokenizer, realesrgan, pipelines
+from .ruclip import get_ruclip
+from . import vae, dalle, tokenizer, realesrgan, pipelines, ruclip
 
 
 __all__ = [
@@ -11,9 +12,13 @@
     'get_rudalle_model',
     'get_tokenizer',
     'get_realesrgan',
+    'get_ruclip',
     'vae',
     'dalle',
+    'ruclip',
     'tokenizer',
     'realesrgan',
     'pipelines',
 ]
+
+__version__ = '0.0.1-rc1'
diff --git a/rudalle/pipelines.py b/rudalle/pipelines.py
index 6ca1592..a28d169 100644
--- a/rudalle/pipelines.py
+++ b/rudalle/pipelines.py
@@ -58,6 +58,22 @@ def super_resolution(pil_images, realesrgan):
     return result
 
 
+def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu', count=4):
+    with torch.no_grad():
+        inputs = ruclip_processor(text=text, images=pil_images)
+        for key in inputs.keys():
+            inputs[key] = inputs[key].to(device)
+        outputs = ruclip(**inputs)
+    sims = outputs.logits_per_image.view(-1).softmax(dim=0)
+    items = []
+    for index, sim in enumerate(sims.cpu().numpy()):
+        items.append({'img_index': index, 'cosine': sim})
+    items = sorted(items, key=lambda x: x['cosine'], reverse=True)[:count]
+    top_pil_images = [pil_images[x['img_index']] for x in items]
+    top_scores = [x['cosine'] for x in items]
+    return top_pil_images, top_scores
+
+
 def show(pil_images, nrow=4):
     imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
     if not isinstance(imgs, list):
diff --git a/rudalle/ruclip/__init__.py b/rudalle/ruclip/__init__.py
new file mode 100644
index 0000000..b7b538c
--- /dev/null
+++ b/rudalle/ruclip/__init__.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+import os
+
+from transformers import CLIPModel
+from huggingface_hub import hf_hub_url, cached_download
+
+from .processor import RuCLIPProcessor
+
+MODELS = {
+    'ruclip-vit-base-patch32-v5': dict(
+        repo_id='sberbank-ai/ru-clip',
+        filenames=[
+            'bpe.model', 'config.json', 'pytorch_model.bin'
+        ]
+    ),
+}
+
+
+def get_ruclip(name, cache_dir='/tmp/rudalle'):
+    assert name in MODELS
+    config = MODELS[name]
+    repo_id = config['repo_id']
+    cache_dir = os.path.join(cache_dir, name)
+    for filename in config['filenames']:
+        config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{name}/{filename}')
+        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
+    ruclip = CLIPModel.from_pretrained(cache_dir)
+    ruclip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
+    print('ruclip --> ready')
+    return ruclip, ruclip_processor
diff --git a/rudalle/ruclip/processor.py b/rudalle/ruclip/processor.py
new file mode 100644
index 0000000..ef2bdc7
--- /dev/null
+++ b/rudalle/ruclip/processor.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+import os
+import json
+import torch
+import youtokentome as yttm
+import torchvision.transforms as T
+from torch.nn.utils.rnn import pad_sequence
+
+
+class RuCLIPProcessor:
+    eos_id = 3
+    bos_id = 2
+    unk_id = 1
+    pad_id = 0
+
+    def __init__(self, tokenizer_path, image_size=224, text_seq_length=76, mean=None, std=None):
+
+        self.tokenizer = yttm.BPE(tokenizer_path)
+        self.mean = mean or [0.485, 0.456, 0.406]
+        self.std = std or [0.229, 0.224, 0.225]
+        self.image_transform = T.Compose([
+            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+            T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
+            T.ToTensor(),
+            T.Normalize(mean=self.mean, std=self.std)
+        ])
+        self.text_seq_length = text_seq_length
+        self.image_size = image_size
+
+    def encode_text(self, text):
+        text = text.lower()
+        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
+        tokens = [self.bos_id] + tokens + [self.eos_id]
+        tokens = tokens[:self.text_seq_length]
+        mask = [1] * len(tokens)
+        return torch.tensor(tokens).long(), torch.tensor(mask).long()
+
+    def decode_text(self, encoded):
+        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
+            self.eos_id, self.bos_id, self.unk_id, self.pad_id
+        ])[0]
+
+    def __call__(self, text=None, images=None, **kwargs):
+        inputs = {}
+        if text is not None:
+            input_ids, masks = [], []
+            texts = [text] if isinstance(text, str) else text
+            for text in texts:
+                tokens, mask = self.encode_text(text)
+                input_ids.append(tokens)
+                masks.append(mask)
+            inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
+            inputs['attention_mask'] = pad_sequence(masks, batch_first=True)
+        if images is not None:
+            pixel_values = []
+            for i, image in enumerate(images):
+                pixel_values.append(self.image_transform(image))
+            inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
+        return inputs
+
+    @classmethod
+    def from_pretrained(cls, folder):
+        tokenizer_path = os.path.join(folder, 'bpe.model')
+        config = json.load(open(os.path.join(folder, 'config.json')))
+        image_size = config['vision_config']['image_size']
+        text_seq_length = config['text_config']['max_position_embeddings'] - 1
+        mean, std = config.get('mean'), config.get('std')
+        return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..6fdb599
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+import os
+import re
+from setuptools import setup
+
+
+def read(filename):
+    with open(os.path.join(os.path.dirname(__file__), filename)) as f:
+        file_content = f.read()
+    return file_content
+
+
+def get_requirements():
+    requirements = []
+    for requirement in read('requirements.txt').splitlines():
+        if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+'):
+            parsed_requires = re.findall(r'#egg=([\w\d\.]+)-([\d\.]+)$', requirement)
+            if parsed_requires:
+                package, version = parsed_requires[0]
+                requirements.append(f'{package}=={version}')
+            else:
+                print('WARNING! To match dependency links correctly, specify the package name and version, e.g. #egg=<package_name>-<version>')
+        else:
+            requirements.append(requirement)
+    return requirements
+
+
+def get_links():
+    return [
+        requirement for requirement in read('requirements.txt').splitlines()
+        if requirement.startswith('git+') or requirement.startswith('svn+') or requirement.startswith('hg+')
+    ]
+
+
+def get_version():
+    """ Get version from the package without actually importing it.
""" + init = read('rudalle/__init__.py') + for line in init.split('\n'): + if line.startswith('__version__'): + return eval(line.split('=')[1]) + + +setup( + name='rudalle', + version=get_version(), + author='SberAI, SberDevices', + author_email='', + description='', + packages=['rudalle', 'rudalle/dalle', 'rudalle/realesrgan', 'rudalle/ruclip', 'rudalle/vae'], + package_data={'rudalle/vae': ['*.yml']}, + install_requires=get_requirements(), + dependency_links=get_links(), + long_description=read('README.md'), + long_description_content_type='text/markdown', +)