Merge pull request nebuly-ai#184 from nebuly-ai/chat_llama
Add ChatLLaMA Implementation
Showing 14 changed files with 2,291 additions and 0 deletions.
Empty file.
Empty file.
Empty file.
apps/accelerate/chatllama/chatllama/langchain_modules/prompt_templates.py: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
REWARD_TEMPLATE = dict(
    template=(
        "Let's pretend that you are a lawyer and you have to "
        "evaluate the following completion task from a given "
        "assignment with a score between 0 and 5, where 0 represents "
        "a bad assignment completion and 5 a perfect completion.\n"
        "You MUST evaluate: text quality, content quality and "
        "coherence.\n"
        "You MUST return only the number that represents your "
        "judgment.\n"
        "The assignment is:\n{user_input}\n"
        "The completion is:\n{completion}\n"
    ),
    input_variables=["user_input", "completion"],
)


AI_CHATBOT_TEMPLATE = dict(
    template=(
        "Assistant is a large language model trained by Meta and Nebuly.ai\n"
        "Assistant is designed to be able to assist with a wide range of "
        "tasks, from answering simple questions to providing in-depth "
        "explanations and discussions on a wide range of topics. As a "
        "language model, Assistant is able to generate human-like text "
        "based on the input it receives, allowing it to engage in "
        "natural-sounding conversations and provide responses that are "
        "coherent and relevant to the topic at hand.\n\n"
        "Assistant is constantly learning and improving, and its capabilities "
        "are constantly evolving. It is able to process and understand large "
        "amounts of text, and can use this knowledge to provide accurate and "
        "informative responses to a wide range of questions. Additionally, "
        "Assistant is able to generate its own text based on the input it "
        "receives, allowing it to engage in discussions and provide "
        "explanations and descriptions on a wide range of topics.\n\n"
        "Overall, Assistant is a powerful tool that can help with a wide "
        "range of tasks and provide valuable insights and information on a "
        "wide range of topics. Whether you need help with a specific "
        "question or just want to have a conversation about a particular "
        "topic, Assistant is here to assist.\n\n{history}\n\n"
        "Human: {human_input}\n"
        "Assistant:"
    ),
    input_variables=["history", "human_input"],
)


PERSON_CHATBOT_TEMPLATE = dict(
    template=(
        "You are a human chatting with a chatbot. The chatbot is a large "
        "language model trained by Meta and Nebuly-ai\n"
        "The chatbot is designed to be able to assist you with a wide range "
        "of tasks, from answering simple questions to providing in-depth "
        "explanations and discussions on a wide range of topics. You are a "
        "human and you are testing the chatbot. Ask the chatbot questions "
        "and see how it responds. You can also ask the chatbot to tell you "
        "a story."
        "\n\n{history}\n\n"
        "Chatbot: {chatbot_input}\n"
        "Human:"
    ),
    input_variables=["history", "chatbot_input"],
)
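
For context (not part of this diff): a minimal sketch of how these template dicts could be consumed, assuming they are unpacked into LangChain's PromptTemplate. The import path below is inferred from the file path above, and the example inputs are made up.

from langchain.prompts import PromptTemplate

from chatllama.langchain_modules.prompt_templates import (
    AI_CHATBOT_TEMPLATE,
    REWARD_TEMPLATE,
)

# Build prompt objects by unpacking the dicts defined above.
reward_prompt = PromptTemplate(**REWARD_TEMPLATE)
print(
    reward_prompt.format(
        user_input="Summarize the plot of Hamlet in two sentences.",
        completion="Hamlet is a prince who seeks revenge for his father.",
    )
)

# The chatbot template additionally expects the running conversation history.
chat_prompt = PromptTemplate(**AI_CHATBOT_TEMPLATE)
print(
    chat_prompt.format(
        history="Human: Hi!\nAssistant: Hello, how can I help you today?",
        human_input="Tell me about LLaMA.",
    )
)

Keeping the templates as plain dicts leaves the choice of prompt class to the caller and avoids a hard langchain dependency in this module.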
@@ -0,0 +1,220 @@
import json
import os
from pathlib import Path
from typing import Tuple, List, Union

import torch
import torch.distributed
import torch.nn as nn
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from fairscale.nn.model_parallel.layers import (
    ParallelEmbedding,
    ColumnParallelLinear,
)
from llama import ModelArgs, Tokenizer
from llama.generation import sample_top_p
from llama.model import TransformerBlock, RMSNorm, precompute_freqs_cis


class HFLikeTokenizer:
    """Wrap the raw llama Tokenizer behind a Hugging Face-style callable interface."""

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, texts: Union[List[str], str], *args, **kwargs):
        if isinstance(texts, str):
            text = self.tokenizer.encode(texts, bos=True, eos=True)
            return torch.tensor(text).cuda().long()
        else:
            texts = [
                self.tokenizer.encode(text, bos=True, eos=True)
                for text in texts
            ]
            # Right-pad every sequence in the batch to the longest one.
            max_len = max(len(text) for text in texts)
            tokens = (
                torch.full((len(texts), max_len), self.tokenizer.pad_id)
                .cuda()
                .long()
            )
            for i, text in enumerate(texts):
                tokens[i, : len(text)] = torch.tensor(text).cuda().long()
            return tokens

    def decode(self, tokens):
        return self.tokenizer.decode(tokens)


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    def forward(self, tokens: torch.Tensor, attention_mask: torch.Tensor):
        start_pos = int(torch.argmax(attention_mask.detach(), dim=-1).item())
        logits = self._forward(tokens, start_pos)
        return logits

    def _forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]  # noqa E203

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
            )
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute last logits
        return output.float()

    @torch.no_grad()
    def generate(
        self,
        inputs: torch.Tensor,
        attention_mask: torch.Tensor,
        max_length: int,
        temperature: float,
        top_p: float = 1.0,
    ):
        prompt_size = inputs.shape[1]
        total_len = min(self.params.max_seq_len, max_length + prompt_size)
        start_pos = prompt_size  # We assume left padding
        prev_pos = 0
        tokens = inputs
        generated_tokens = []
        for cur_pos in range(start_pos, total_len):
            logits = self._forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            generated_tokens.append(next_token)
            # Append the sampled token so later steps can condition on it.
            tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
            prev_pos = cur_pos
        return torch.stack(generated_tokens, dim=1)


def setup_model_parallel() -> Tuple[int, int]:
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size


def load_checkpoints(
    ckpt_dir: str, local_rank: int, world_size: int
) -> Tuple[dict, dict]:
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(checkpoints), (
        f"Loading a checkpoint for MP={len(checkpoints)} but world "
        f"size is {world_size}"
    )
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())
    return checkpoint, params


def load_model(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_batch_size: int = 32,
) -> Tuple[Transformer, HFLikeTokenizer]:
    checkpoint, params = load_checkpoints(ckpt_dir, local_rank, world_size)
    model_args: ModelArgs = ModelArgs(
        max_seq_len=1024, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    # Build the model with fp16 parameters on GPU, then restore the default.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)
    tokenizer = HFLikeTokenizer(tokenizer)
    return model, tokenizer


def generate(
    model: Transformer,
    tokenizer: Tokenizer,
    prompts: List[str],
    max_gen_len: int,
    temperature: float = 0.8,
    top_p: float = 0.95,
) -> List[str]:
    bsz = len(prompts)
    params = model.params
    assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

    prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts]

    min_prompt_size = min([len(t) for t in prompt_tokens])
    max_prompt_size = max([len(t) for t in prompt_tokens])

    total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

    tokens = torch.full((bsz, total_len), tokenizer.pad_id).cuda().long()
    for k, t in enumerate(prompt_tokens):
        tokens[k, : len(t)] = torch.tensor(t).long()
    input_text_mask = tokens != tokenizer.pad_id
    start_pos = min_prompt_size
    prev_pos = 0
    for cur_pos in range(start_pos, total_len):
        logits = model._forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos

    decoded = []
    for i, t in enumerate(tokens.tolist()):
        # cut to max gen len
        t = t[: len(prompt_tokens[i]) + max_gen_len]
        # cut to eos tok if any
        try:
            t = t[: t.index(tokenizer.eos_id)]
        except ValueError:
            pass
        decoded.append(tokenizer.decode(t))
    return decoded
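
For context (not part of this diff): a minimal sketch of how the helpers above could be wired together, assuming it runs in the same module as the code above and is launched with torchrun so that LOCAL_RANK and WORLD_SIZE are set. Checkpoint and tokenizer paths as well as the prompt are placeholders.

local_rank, world_size = setup_model_parallel()
model, hf_tokenizer = load_model(
    ckpt_dir="path/to/llama/7B",               # placeholder checkpoint dir
    tokenizer_path="path/to/tokenizer.model",  # placeholder tokenizer path
    local_rank=local_rank,
    world_size=world_size,
)

# The module-level generate() expects the raw llama Tokenizer, which the
# HFLikeTokenizer wrapper keeps in its `tokenizer` attribute.
outputs = generate(
    model,
    hf_tokenizer.tokenizer,
    prompts=["Human: What is RLHF?\nAssistant:"],
    max_gen_len=128,
    temperature=0.8,
    top_p=0.95,
)
if local_rank == 0:
    print(outputs[0])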
@@ -0,0 +1 @@
"""RLHF implementation inspired to Lucidrains' implementation.""" |