forked from ai-forever/ru-dalle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request ai-forever#3 from sberbank-ai/feature/ruclip
ruclip
- Loading branch information
Showing
6 changed files
with
181 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
taming-transformers==0.0.1 | ||
more_itertools==8.10.0 | ||
transformers==4.10.2 | ||
youtokentome==1.0.6 | ||
einops==0.3.2 | ||
more_itertools~=8.10.0 | ||
transformers~=4.10.2 | ||
youtokentome~=1.0.6 | ||
omegaconf>=2.0.0 | ||
einops~=0.3.2 | ||
torch | ||
torchvision | ||
matplotlib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# -*- coding: utf-8 -*- | ||
import os | ||
|
||
from transformers import CLIPModel | ||
from huggingface_hub import hf_hub_url, cached_download | ||
|
||
from .processor import RuCLIPProcessor | ||
|
||
# Registry of downloadable ru-clip checkpoints:
# model name -> hub repo and the files that must be fetched before loading.
MODELS = {
    'ruclip-vit-base-patch32-v5': {
        'repo_id': 'sberbank-ai/ru-clip',
        'filenames': ['bpe.model', 'config.json', 'pytorch_model.bin'],
    },
}
|
||
|
||
def get_ruclip(name, cache_dir='/tmp/rudalle'):
    """Download (if missing) and load a ru-clip model plus its processor.

    Args:
        name: key into MODELS identifying the checkpoint to fetch.
        cache_dir: local root directory for cached files; the model is stored
            under ``cache_dir/name``.

    Returns:
        Tuple ``(CLIPModel, RuCLIPProcessor)`` ready for inference.
    """
    assert name in MODELS
    config = MODELS[name]
    repo_id = config['repo_id']
    cache_dir = os.path.join(cache_dir, name)
    for filename in config['filenames']:
        # BUG FIX: the hub URL must reference each individual file; the
        # original interpolated a fixed placeholder instead of the loop
        # variable, so every file request pointed at a nonexistent path.
        config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{name}/{filename}')
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
    ruclip = CLIPModel.from_pretrained(cache_dir)
    ruclip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
    print('ruclip --> ready')
    return ruclip, ruclip_processor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# -*- coding: utf-8 -*- | ||
import os | ||
import json | ||
import torch | ||
import youtokentome as yttm | ||
import torchvision.transforms as T | ||
from torch.nn.utils.rnn import pad_sequence | ||
|
||
|
||
class RuCLIPProcessor:
    """Turns texts and images into model-ready tensors for ru-clip.

    ``__call__`` mirrors the HuggingFace processor interface: it returns a
    dict with ``input_ids``/``attention_mask`` for texts and ``pixel_values``
    for images.
    """

    # Special token ids of the YouTokenToMe BPE vocabulary.
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer_path, image_size=224, text_seq_length=76, mean=None, std=None):
        """
        Args:
            tokenizer_path: path to the YouTokenToMe BPE model file.
            image_size: target square side for the image transform.
            text_seq_length: maximum token sequence length (incl. BOS/EOS).
            mean: per-channel normalization mean; ImageNet default when None.
            std: per-channel normalization std; ImageNet default when None.
        """
        self.tokenizer = yttm.BPE(tokenizer_path)
        self.mean = mean or [0.485, 0.456, 0.406]
        self.std = std or [0.229, 0.224, 0.225]
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            # scale/ratio pinned to 1.0 makes this a deterministic full-image
            # resize-crop rather than a random augmentation
            T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])
        self.text_seq_length = text_seq_length
        self.image_size = image_size

    def encode_text(self, text):
        """Tokenize one text; returns (token_ids, attention_mask) LongTensors."""
        text = text.lower()
        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
        tokens = [self.bos_id] + tokens + [self.eos_id]
        # hard truncation; NOTE(review): EOS can be cut off for long texts
        tokens = tokens[:self.text_seq_length]
        mask = [1] * len(tokens)
        return torch.tensor(tokens).long(), torch.tensor(mask).long()

    def decode_text(self, encoded):
        """Decode a 1-D tensor of token ids to text, dropping special tokens."""
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
            self.eos_id, self.bos_id, self.unk_id, self.pad_id
        ])[0]

    def __call__(self, text=None, images=None, **kwargs):
        """Build a batch dict from text(s) and/or images.

        Args:
            text: a single string or an iterable of strings.
            images: iterable of PIL images.

        Returns:
            dict with padded 'input_ids'/'attention_mask' and/or stacked
            'pixel_values' tensors (only the keys for supplied inputs).
        """
        inputs = {}
        if text is not None:
            input_ids, masks = [], []
            texts = [text] if isinstance(text, str) else text
            # loop variable renamed: the original shadowed the `text` argument
            for item in texts:
                tokens, mask = self.encode_text(item)
                input_ids.append(tokens)
                masks.append(mask)
            inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
            inputs['attention_mask'] = pad_sequence(masks, batch_first=True)
        if images is not None:
            # unused enumerate index removed
            pixel_values = [self.image_transform(image) for image in images]
            inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
        return inputs

    @classmethod
    def from_pretrained(cls, folder):
        """Construct a processor from a downloaded checkpoint directory.

        Reads ``bpe.model`` and ``config.json`` from *folder*; image size and
        text length come from the vision/text sub-configs.
        """
        tokenizer_path = os.path.join(folder, 'bpe.model')
        # BUG FIX: close the config file instead of leaking the handle
        with open(os.path.join(folder, 'config.json')) as f:
            config = json.load(f)
        image_size = config['vision_config']['image_size']
        # one position reserved (max_position_embeddings - 1), per original
        text_seq_length = config['text_config']['max_position_embeddings'] - 1
        mean, std = config.get('mean'), config.get('std')
        return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# -*- coding: utf-8 -*- | ||
import os | ||
import re | ||
from setuptools import setup | ||
|
||
|
||
def read(filename):
    """Return the full text of *filename*, resolved relative to this file."""
    path = os.path.join(os.path.dirname(__file__), filename)
    with open(path) as fh:
        return fh.read()
|
||
|
||
def get_requirements():
    """Parse requirements.txt into a list of pip requirement specifiers.

    VCS links (git+/svn+/hg+) are converted to ``package==version`` pins via
    their ``#egg=`` fragment; plain requirement lines pass through unchanged.

    Returns:
        list[str] suitable for setuptools ``install_requires``.
    """
    requirements = []
    for requirement in read('requirements.txt').splitlines():
        if not requirement.strip():
            # skip blank lines so they don't become empty install_requires entries
            continue
        if requirement.startswith(('git+', 'svn+', 'hg+')):
            parsed_requires = re.findall(r'#egg=([\w\d\.]+)-([\d\.]+)$', requirement)
            if parsed_requires:
                package, version = parsed_requires[0]
                requirements.append(f'{package}=={version}')
            else:
                # BUG FIX: the two literals previously concatenated without a
                # separating space ("...versionsuch as...")
                print('WARNING! For correct matching dependency links need to specify package name and version '
                      'such as <dependency url>#egg=<package_name>-<version>')
        else:
            requirements.append(requirement)
    return requirements
|
||
|
||
def get_links():
    """Return VCS dependency links (git+/svn+/hg+) from requirements.txt."""
    vcs_prefixes = ('git+', 'svn+', 'hg+')
    links = []
    for line in read('requirements.txt').splitlines():
        if line.startswith(vcs_prefixes):
            links.append(line)
    return links
|
||
|
||
def get_version():
    """Get version from the package without actually importing it.

    Scans rudalle/__init__.py for a ``__version__ = '...'`` line and returns
    the quoted value, or None if no such line exists.
    """
    import ast  # local import: only needed at build time

    init = read('rudalle/__init__.py')
    for line in init.split('\n'):
        if line.startswith('__version__'):
            # literal_eval safely parses the quoted version string where the
            # original ran eval() on file content
            return ast.literal_eval(line.split('=')[1].strip())
|
||
|
||
# Build-time metadata for the rudalle distribution; requirement pins and
# dependency links are derived from requirements.txt at build/install time.
_metadata = dict(
    name='rudalle',
    version=get_version(),
    author='SberAI, SberDevices',
    author_email='',
    description='',
    packages=['rudalle', 'rudalle/dalle', 'rudalle/realesrgan', 'rudalle/ruclip', 'rudalle/vae'],
    package_data={'rudalle/vae': ['*.yml']},
    install_requires=get_requirements(),
    dependency_links=get_links(),
    long_description=read('README.md'),
    long_description_content_type='text/markdown',
)

setup(**_metadata)