diff --git a/legacy/examples/few_shot/p-tuning/data.py b/legacy/examples/few_shot/p-tuning/data.py index 6f96ac02cdc8..15dd653a29a9 100644 --- a/legacy/examples/few_shot/p-tuning/data.py +++ b/legacy/examples/few_shot/p-tuning/data.py @@ -139,7 +139,8 @@ def convert_ids_to_words(example, token_ids): the length of which should coincide with that of `mask` in prompt. """ if "label_ids" in example: - labels = paddle.index_select(token_ids, paddle.to_tensor(example.pop("label_ids")), axis=0).squeeze(0) + label_ids_tensor = paddle.to_tensor([example.pop("label_ids")], dtype='int64') + labels = paddle.index_select(token_ids, label_ids_tensor, axis=0).squeeze(0) example["labels"] = labels return example