|
293 | 293 | " padding_mask = tf.cast(\n", |
294 | 294 | " mask[:, tf.newaxis, :], dtype=\"int32\")\n", |
295 | 295 | " padding_mask = tf.minimum(padding_mask, causal_mask)\n", |
| 296 | + " else:\n", |
| 297 | + " padding_mask = mask\n", |
296 | 298 | " attention_output_1 = self.attention_1(\n", |
297 | 299 | " query=inputs,\n", |
298 | 300 | " value=inputs,\n", |
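The added `else` branch matters: previously `padding_mask` was bound only inside the `if mask is not None:` branch, so calling the decoder without a mask left it undefined and broke the later attention call. A minimal sketch of the fixed control flow, assuming the surrounding `TransformerDecoder.call()` context from the notebook (`causal_mask` already computed; the helper name is illustrative, not from the notebook):

```python
import tensorflow as tf

def build_padding_mask(mask, causal_mask):
    # When a padding mask is given, broadcast it to (batch, 1, seq_len) and
    # combine it with the causal mask so each position attends only to
    # earlier, non-padded tokens.
    if mask is not None:
        padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")
        padding_mask = tf.minimum(padding_mask, causal_mask)
    else:
        # The fix: keep padding_mask defined (as None) when no mask is
        # passed, so the later attention call simply skips padding masking.
        padding_mask = mask
    return padding_mask
```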
|
391 | 393 | " self.model_input_length = model_input_length\n", |
392 | 394 | " self.temperatures = temperatures\n", |
393 | 395 | " self.print_freq = print_freq\n", |
| 396 | + " vectorized_prompt = text_vectorization([prompt])[0].numpy()\n", |
| 397 | + " self.prompt_length = np.nonzero(vectorized_prompt == 0)[0][0]\n", |
394 | 398 | "\n", |
395 | 399 | " def on_epoch_end(self, epoch, logs=None):\n", |
396 | 400 | " if (epoch + 1) % self.print_freq != 0:\n", |
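The two added lines compute how many real tokens the prompt contains: `TextVectorization` pads sequences with 0s, so the index of the first 0 in the vectorized prompt is the prompt length. A self-contained sketch of the same trick; the vocabulary and sequence length below are toy assumptions, not the notebook's settings:

```python
import numpy as np
import tensorflow as tf

# Toy vectorizer; vocabulary and sequence length are illustrative only.
text_vectorization = tf.keras.layers.TextVectorization(
    output_mode="int", output_sequence_length=10)
text_vectorization.adapt(["this movie was great", "this movie was bad"])

vectorized_prompt = text_vectorization(["this movie"])[0].numpy()
# TextVectorization pads with 0s, so the first 0 marks where the real
# prompt tokens end.
prompt_length = np.nonzero(vectorized_prompt == 0)[0][0]
print(prompt_length)  # 2 tokens: "this", "movie"
```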
|
401 | 405 | " for i in range(self.generate_length):\n", |
402 | 406 | " tokenized_sentence = text_vectorization([sentence])\n", |
403 | 407 | " predictions = self.model(tokenized_sentence)\n", |
404 | | - " next_token = sample_next(predictions[0, i, :])\n", |
| 408 | + " next_token = sample_next(\n", |
| 409 | + " predictions[0, self.prompt_length - 1 + i, :]\n", |
| 410 | + " )\n", |
405 | 411 | " sampled_token = tokens_index[next_token]\n", |
406 | 412 | " sentence += \" \" + sampled_token\n", |
407 | 413 | " print(sentence)\n", |
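The replaced index fixes an off-by-prompt bug: for a causal language model, `predictions[0, k, :]` is the distribution over the token at position `k + 1`, so the first generated token (`i = 0`) must be sampled from the output at the last prompt position, `prompt_length - 1`, not at index `i`. A self-contained sketch with toy values; the `sample_next` variant below follows the notebook's temperature-sampling approach, but its exact body here is an assumption:

```python
import numpy as np

def sample_next(predictions, temperature=1.0):
    # Temperature sampling in the spirit of the notebook's sample_next().
    predictions = np.asarray(predictions).astype("float64")
    predictions = np.log(predictions + 1e-9) / temperature
    exp_preds = np.exp(predictions)
    probs = exp_preds / np.sum(exp_preds)
    return int(np.argmax(np.random.multinomial(1, probs, 1)))

# Toy model output: batch 1, sequence length 5, vocabulary of 4 tokens.
# For a causal LM, predictions[0, k, :] scores the token at position k + 1.
predictions = np.random.dirichlet(np.ones(4), size=5)[np.newaxis, ...]

prompt_length = 2  # the prompt occupies positions 0 and 1
i = 0              # first generated token
# Sample from the output at the last prompt position, not at index i:
next_token = sample_next(predictions[0, prompt_length - 1 + i, :])
```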
|