Skip to content

Commit

Permalink
Corrects oversights in device handling for GPU support
Browse files Browse the repository at this point in the history
  • Loading branch information
cgpotts committed May 23, 2022
1 parent d3d9665 commit d570932
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions hw_openqa.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@
" \n",
" \"\"\"\n",
" prompt_ids = eleuther_tokenizer(\n",
" prompts, return_tensors=\"pt\", padding=True).input_ids\n",
" prompts, return_tensors=\"pt\", padding=True).input_ids.to(device)\n",
" \n",
" with torch.inference_mode():\n",
" # Automatic mixed precision if possible.\n",
Expand Down Expand Up @@ -645,7 +645,7 @@
" for prompt, gen_id, gen_text, gen_prob in iterator: \n",
" gen_tokens = eleuther_tokenizer.convert_ids_to_tokens(gen_id)\n",
" generated_text = gen_text[len(prompt): ]\n",
" gen_prob = [float(x) for x in gen_prob.numpy()] # float for JSON storage\n",
" gen_prob = [float(x) for x in gen_prob.cpu().numpy()] # float for JSON storage\n",
" ans_indices = _find_generated_answer(gen_tokens, newline=\"Ċ\")\n",
" answer_tokens = [gen_tokens[i] for i in ans_indices]\n",
" answer_probs = [gen_prob[i] for i in ans_indices]\n",
Expand Down

0 comments on commit d570932

Please sign in to comment.