Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenAI embedding to text benchmark script #367

Merged
merged 4 commits into from
Feb 28, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add OpenAI embedding to text benchmark script
  • Loading branch information
zechengz committed Feb 28, 2024
commit 820eee465a3b1ad246755c1f3d956c48c512233b
31 changes: 30 additions & 1 deletion benchmark/data_frame_text_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,13 @@
"google/electra-large-discriminator",
"sentence-transformers/all-distilroberta-v1",
"sentence-transformers/average_word_embeddings_glove.6B.300d",
"sentence-transformers/all-roberta-large-v1",
"text-embedding-3-large",
],
)
parser.add_argument("--finetune", action="store_true")
parser.add_argument('--result_path', type=str, default='')
parser.add_argument("--api_key", type=str, default=None)
args = parser.parse_args()

model_out_channels = {
Expand Down Expand Up @@ -188,6 +191,26 @@ def tokenize(self, sentences: list[str]) -> TextTokenizationOutputs:
return_tensors="pt")


class OpenAIEmbedding:
    r"""Sentence encoder backed by the OpenAI embeddings API.

    Drop-in alternative to the local text encoders: callable that maps a
    list of sentences to a single stacked embedding tensor.

    Args:
        model (str): OpenAI embedding model name, e.g.
            ``"text-embedding-3-large"``.
        api_key (str): OpenAI API key used to authenticate the client.
    """
    def __init__(self, model: str, api_key: str):
        # Please run `pip install openai` to install the package
        from openai import OpenAI

        self.client = OpenAI(api_key=api_key)
        self.model = model

    def __call__(self, sentences: list[str]) -> Tensor:
        r"""Embed ``sentences`` and return a ``(len(sentences), dim)`` tensor.

        Makes one API request for the whole batch. Raises whatever the
        OpenAI client raises on network/auth failure.
        """
        items = self.client.embeddings.create(
            input=sentences, model=self.model).data
        assert len(items) == len(sentences)
        # The API attaches an `index` per item; sort so that row i of the
        # output always corresponds to sentences[i], regardless of the
        # order the service returned the items in.
        items = sorted(items, key=lambda item: item.index)
        embeddings = [
            torch.tensor(item.embedding, dtype=torch.float).view(1, -1)
            for item in items
        ]
        return torch.cat(embeddings, dim=0)


def mean_pooling(last_hidden_state: Tensor, attention_mask: Tensor) -> Tensor:
input_mask_expanded = (attention_mask.unsqueeze(-1).expand(
last_hidden_state.size()).float())
Expand Down Expand Up @@ -360,7 +383,13 @@ def main_torch(
path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data")

if not args.finetune:
text_encoder = TextToEmbedding(model=args.text_model, device=device)
if args.text_model == "text-embedding-3-large":
assert isinstance(args.api_key, str)
text_encoder = OpenAIEmbedding(model=args.text_model,
api_key=args.api_key)
else:
text_encoder = TextToEmbedding(model=args.text_model,
device=device)
text_stype = torch_frame.text_embedded
kwargs = {
"text_stype":
Expand Down
Loading