Skip to content

Commit

Permalink
Fix windows dtype bug of neural search (PaddlePaddle#2911)
Browse files Browse the repository at this point in the history
  • Loading branch information
w5688414 authored Jul 28, 2022
1 parent 842954c commit 112830b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
5 changes: 3 additions & 2 deletions applications/neural_search/recall/milvus/feature_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ def predict(self, data, tokenizer):
"""

batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # input
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"
), # segment
): fn(samples)

all_embeddings = []
Expand Down
6 changes: 4 additions & 2 deletions applications/neural_search/recall/milvus/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ def search_in_milvus(text_embedding):
max_seq_length=max_seq_length)

batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # text_input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # text_segment
Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"
), # text_input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"
), # text_segment
): [data for data in fn(samples)]

pretrained_model = AutoModel.from_pretrained("ernie-3.0-medium-zh")
Expand Down

0 comments on commit 112830b

Please sign in to comment.