AssertionError assert(len(tmp_ids) == 1) when using RoBERTa #5

steveazzolin opened this issue Jun 9, 2023 · 0 comments
@steveazzolin

Hi,

I was trying to run scripts/run_optiprompt.sh with the model roberta-base and encountered the following error:

Traceback (most recent call last):
  File "/nfs/data_chaos/sazzolin/OptiPrompt/code/run_optiprompt.py", line 176, in <module>
    best_result, result_rel = evaluate(model, valid_samples_batches, valid_sentences_batches, filter_indices, index_list)
  File "/nfs/data_chaos/sazzolin/OptiPrompt/code/utils.py", line 110, in evaluate
    log_probs, cor_b, tot_b, pred_b, topk_preds, loss, common_vocab_loss = model.run_batch(sentences_b, samples_b, training=False, filter_indices=filter_indices, index_list=index_list, vocab_to_common_vocab=vocab_to_common_vocab)
  File "/nfs/data_chaos/sazzolin/OptiPrompt/code/models.py", line 345, in run_batch
    tokens_tensor, segments_tensor, attention_mask_tensor, masked_indices_list, tokenized_text_list, mlm_labels_tensor, mlm_label_ids = self._get_input_tensors_batch_train(sentences_list, samples_list)
  File "/nfs/data_chaos/sazzolin/OptiPrompt/code/models.py", line 151, in _get_input_tensors_batch_train
    tokens_tensor, segments_tensor, masked_indices, tokenized_text, mlm_labels_tensor, mlm_label_id = self.__get_input_tensors(sentences, mlm_label=samples['obj_label'])
  File "/nfs/data_chaos/sazzolin/OptiPrompt/code/models.py", line 298, in __get_input_tensors
    assert(len(tmp_ids) == 1)
AssertionError

By contrast, the code works fine when using bert-base-cased.

This is the complete Python command:

python code/run_optiprompt.py \
            --relation_profile relation_metainfo/LAMA_relations.jsonl \
            --relation ${REL} \
            --common_vocab_filename common_vocabs/common_vocab_cased.txt \
            --model_name roberta_base \
            --do_train \
            --train_data data/autoprompt_data/${REL}/train.jsonl \
            --dev_data data/autoprompt_data/${REL}/dev.jsonl \
            --do_eval \
            --test_data data/LAMA-TREx/${REL}.jsonl \
            --output_dir ${DIR} \
            --random_init none \
            --seed ${SEED} \
            --output_predictions 
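
For context, my reading of the failing check (based only on the traceback, not on the OptiPrompt source) is that `assert(len(tmp_ids) == 1)` requires the object label to map to exactly one token id. The toy sketch below illustrates that condition. The vocabularies, token ids, and the functions `greedy_longest_match` and `label_to_single_id` are hypothetical, and the `Ġ` prefix only mimics how RoBERTa's byte-level BPE marks a word preceded by a space. It shows how a label that is a single token in one vocabulary can split into several pieces in another; it does not reproduce either real tokenizer.

```python
from typing import Dict, List

# Hypothetical toy vocabularies (NOT the real BERT or RoBERTa vocabularies).
# "Ġ" imitates the byte-level BPE marker for "this word follows a space".
BERT_LIKE_VOCAB: Dict[str, int] = {"Paris": 0, "Par": 1, "is": 2}
ROBERTA_LIKE_VOCAB: Dict[str, int] = {"ĠParis": 10, "Par": 11, "is": 12}


def greedy_longest_match(word: str, vocab: Dict[str, int]) -> List[int]:
    """Split `word` into vocab pieces, always taking the longest matching prefix."""
    ids: List[int] = []
    start = 0
    while start < len(word):
        end = len(word)
        # Shrink the candidate piece until it is found in the vocabulary.
        while end > start and word[start:end] not in vocab:
            end -= 1
        if end == start:
            raise ValueError(f"cannot tokenize {word!r} with this vocabulary")
        ids.append(vocab[word[start:end]])
        start = end
    return ids


def label_to_single_id(label: str, vocab: Dict[str, int]) -> int:
    """Hypothetical analogue of the failing check: the label must be one token."""
    tmp_ids = greedy_longest_match(label, vocab)
    assert len(tmp_ids) == 1, f"{label!r} maps to {len(tmp_ids)} tokens: {tmp_ids}"
    return tmp_ids[0]


if __name__ == "__main__":
    print(label_to_single_id("Paris", BERT_LIKE_VOCAB))        # single token: 0
    print(greedy_longest_match("Paris", ROBERTA_LIKE_VOCAB))   # two pieces: [11, 12]
    print(label_to_single_id("ĠParis", ROBERTA_LIKE_VOCAB))    # with space marker: 10
```

With the toy RoBERTa-like vocabulary, `label_to_single_id("Paris", ROBERTA_LIKE_VOCAB)` raises `AssertionError`, the same kind of failure as in the traceback. Whether the real cause is a missing space marker or labels from `common_vocab_cased.txt` that RoBERTa splits into several tokens is a guess I have not checked.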