
Commit a08a739

add clip model for encoder
1 parent 1d34f35 commit a08a739

File tree

9 files changed: +163 -6 lines changed

.gitignore

Lines changed: 4 additions & 1 deletion

@@ -139,4 +139,7 @@ dmypy.json
 
 *.ini
 
-**/multicache_serving.py
+**/multicache_serving.py
+**/modelcache_serving.py
+
+**/model/

model/clip_zh/__init__.py

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""
+Alipay.com Inc.
+Copyright (c) 2004-2023 All Rights Reserved.
+------------------------------------------------------
+File Name : __init__.py
+Author : fuhui.phe
+Create Time : 2024/5/7 14:05
+Description : description what the main function of this file
+Change Activity:
+    version0 : 2024/5/7 14:05 by fuhui.phe init
+"""

modelcache/adapter/adapter_query.py

Lines changed: 5 additions & 2 deletions

@@ -30,17 +30,20 @@ def adapt_query(cache_data_convert, *args, **kwargs):
         report_func=chat_cache.report.embedding,
     )(pre_embedding_data)
 
+    # print('embedding_data: {}'.format(embedding_data))
+
     if cache_enable:
         cache_data_list = time_cal(
             chat_cache.data_manager.search,
-            func_name="milvus_search",
+            func_name="vector_search",
             report_func=chat_cache.report.search,
         )(
             embedding_data,
             extra_param=context.get("search_func", None),
             top_k=kwargs.pop("top_k", -1),
             model=model
         )
+        print('cache_data_list: {}'.format(cache_data_list))
         cache_answers = []
         cache_questions = []
         cache_ids = []
@@ -78,8 +81,8 @@ def adapt_query(cache_data_convert, *args, **kwargs):
             return
 
         for cache_data in cache_data_list:
+            print('cache_data: {}'.format(cache_data))
             primary_id = cache_data[1]
-            start_time = time.time()
             ret = chat_cache.data_manager.get_scalar_data(
                 cache_data, extra_param=context.get("get_scalar_data", None)
            )

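Note on the renamed timing label ("milvus_search" to "vector_search"): the search is no longer assumed to be Milvus-specific. The time_cal helper itself is not part of this diff; the sketch below is only a guess at its shape, inferred from how it is called above (wrap a function, time it, hand the elapsed time to report_func), and may differ from the real modelcache implementation.

import time

def time_cal(func, func_name=None, report_func=None):
    # Hypothetical reconstruction: wrap func, measure wall-clock latency,
    # and pass the elapsed time to report_func; func_name would typically
    # feed a log line or a metrics label.
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        if report_func is not None:
            report_func(time.time() - start)
        return result
    return wrapper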
modelcache/core.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 from modelcache.processor.post import first
 from modelcache.similarity_evaluation import ExactMatchEvaluation
 from modelcache.similarity_evaluation import SimilarityEvaluation
-from modelcache.embedding.string import to_embeddings as string_embedding
+from modelcache.embedding.string_text import to_embeddings as string_embedding
 from modelcache.report import Report
 from modelcache.config import Config
 from modelcache.utils.cache_func import cache_all

modelcache/embedding/clip.py

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+import os
+import torch
+from modelcache.embedding.base import BaseEmbedding
+from modelscope.utils.constant import Tasks
+from modelscope.pipelines import pipeline
+from modelscope.preprocessors.image import load_image
+
+
+# def mean_pooling(model_output, attention_mask):
+#     token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+#     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+#     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+class ClipAudio(BaseEmbedding):
+    def __init__(self, model: str = "sentence-transformers/all-MiniLM-L6-v2"):
+        # current_dir = os.path.dirname(os.path.abspath(__file__))
+        # parent_dir = os.path.dirname(current_dir)
+        # model_dir = os.path.dirname(parent_dir)
+        # model = os.path.join(model_dir, 'model/text2vec-base-chinese/')
+
+        self.clip_pipeline = pipeline(task=Tasks.multi_modal_embedding,
+                                      model='damo/multi-modal_clip-vit-base-patch16_zh', model_revision='v1.0.1')
+
+        self.__dimension = 1024
+
+    def to_embeddings(self, data_dict, **_):
+        text_list = data_dict['text']
+        image_data = data_dict['image']
+
+        img_data = None
+        txt_data = None
+
+        if image_data:
+            input_img = load_image(image_data)
+            # 2D tensor, [number of images, feature dimension]
+            img_embedding = self.clip_pipeline.forward({'img': input_img})['img_embedding'].tolist()[0] if input_img else []
+            print('img_embedding: {}'.format(img_embedding))
+        else:
+            raise ValueError('image_data is None, please check!')
+
+        if text_list and len(text_list) > 0:
+            # 2D tensor, [number of texts, feature dimension]
+            text_embedding = self.clip_pipeline.forward({'text': text_list})['text_embedding'].tolist()[0] if text_list else []
+            print('text_embedding: {}'.format(text_embedding))
+        else:
+            raise ValueError('text_list is None, please check!')
+
+        return {'image_embedding': img_embedding, 'text_embeddings': text_embedding}
+
+        # return {'image_embedding': img_feats, 'text_embeddings': txt_feats}
+        # input_texts = ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]
+        # input_img = load_image(
+        #     'https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg')
+
+        # img_embedding = self.clip_pipeline.forward({'img': input_img})['img_embedding']  # 2D tensor, [number of images, feature dimension]
+        # print('img_embedding: {}'.format(img_embedding))
+        # text_embedding = self.clip_pipeline.forward({'text': input_texts})['text_embedding']  # 2D tensor, [number of texts, feature dimension]
+
+
+        # return embedding_array
+
+    def post_proc(self, token_embeddings, inputs):
+        attention_mask = inputs["attention_mask"]
+        input_mask_expanded = (
+            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        )
+        sentence_embs = torch.sum(
+            token_embeddings * input_mask_expanded, 1
+        ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+        return sentence_embs
+
+    @property
+    def dimension(self):
+        """Embedding dimension.
+
+        :return: embedding dimension
+        """
+        return self.__dimension
+
+
+# if __name__ == '__main__':
+#     clip_vec = ClipAudio()
+#     text_list = ['hello', '你好']
+#     text = ['###'.join(text_list)]
+#     image = 'https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg'
+#     data_dict = {'text': text, 'image': image}
+#     resp = clip_vec.to_embeddings(data_dict)
+#     print('resp: {}'.format(resp))

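The commented-out __main__ block hints at how the new embedding class is meant to be driven. A minimal usage sketch, assuming the ModelScope weights for damo/multi-modal_clip-vit-base-patch16_zh can be downloaded at construction time and that callers pass both a text list and an image (the method raises ValueError if either is missing):

from modelcache.embedding.clip import ClipAudio

clip_vec = ClipAudio()
data_dict = {
    'text': ['###'.join(['hello', '你好'])],  # texts joined with '###', as in the commented-out example
    'image': 'https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg',
}
resp = clip_vec.to_embeddings(data_dict)
# resp is a dict with one image vector and one text vector (lists of floats)
print(len(resp['image_embedding']), len(resp['text_embeddings']))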
modelcache/embedding/clip_demo.py

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+"""
+Alipay.com Inc.
+Copyright (c) 2004-2023 All Rights Reserved.
+------------------------------------------------------
+File Name : clip_demo.py
+Author : fuhui.phe
+Create Time : 2024/5/7 11:58
+Description : description what the main function of this file
+Change Activity:
+    version0 : 2024/5/7 11:58 by fuhui.phe init
+"""
+import torch
+from modelscope.utils.constant import Tasks
+from modelscope.pipelines import pipeline
+from modelscope.preprocessors.image import load_image
+
+
+pipeline = pipeline(task=Tasks.multi_modal_embedding,
+                    model='damo/multi-modal_clip-vit-base-patch16_zh', model_revision='v1.0.1')
+
+# pipeline = pipeline(task=Tasks.multi_modal_embedding,
+#                     model='/Users/penghongen/PycharmProjects/CodeFuse-ModelCache/model/clip_zh', model_revision='v1.0.1')
+
+# pipeline = pipeline(task=Tasks.multi_modal_embedding, model='/Users/penghongen/PycharmProjects/CodeFuse-ModelCache/model/clip_zh')
+
+
+input_img = load_image('https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg')  # accepts the sample Pikachu image URL or a local path; returns a PIL.Image
+
+
+input_texts = ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]
+
+# Accepts one image (PIL.Image) or a list of images (List[PIL.Image]); outputs normalized feature vectors
+img_embedding = pipeline.forward({'img': input_img})['img_embedding']  # 2D tensor, [number of images, feature dimension]
+print('img_embedding: {}'.format(img_embedding))
+
+# Accepts one text (str) or a list of texts (List[str]); outputs normalized feature vectors
+text_embedding = pipeline.forward({'text': input_texts})['text_embedding']  # 2D tensor, [number of texts, feature dimension]
+
+# Compute image-text similarity
+with torch.no_grad():
+    # Inner product gives the logits, scaled by the model temperature
+    logits_per_image = (img_embedding / pipeline.model.temperature) @ text_embedding.t()
+    # Turn the logits into a probability distribution
+    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
+
+print("Image-text matching probabilities:", probs)
+
+
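The probs printed at the end of the demo is a [number of images, number of texts] array, so the single input image can be matched to its most likely caption with an argmax. A small, hypothetical extension of the demo:

best_idx = probs[0].argmax()
# Pick the caption with the highest image-text matching probability
print('best match: {} ({:.3f})'.format(input_texts[best_idx], probs[0][best_idx]))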
File renamed without changes.
File renamed without changes.

requirements.txt

Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,7 @@ pymilvus==2.3.1
 PyMySQL==1.1.0
 Requests==2.31.0
 torch==2.1.0
-transformers==4.34.1
+transformers==4.38.2
 faiss-cpu==1.7.4
 redis==5.0.1
-
+modelscope==1.14.0
