Skip to content

Commit

Permalink
hotfix llama2 typo
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut committed Jul 19, 2023
1 parent 6e1c8f1 commit 08686ff
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions example/LLM/llama2/llama/generation.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import os
import sys
import json
import time
from typing import List, Tuple, Literal, Optional, TypedDict
from pathlib import Path
from typing import List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)

from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

Role = Literal["system", "user", "assistant"]


Expand Down Expand Up @@ -167,6 +166,7 @@ def generate(
# cut to max gen len
start = 0 if echo else len(prompt_tokens[i])
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
probs = None
if logprobs:
probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
# cut to eos tok if any
Expand Down

0 comments on commit 08686ff

Please sign in to comment.