From def96882f9d009cd704e3db9e6e119e89e00352b Mon Sep 17 00:00:00 2001
From: ZHUO ZHANG
Date: Mon, 29 Apr 2024 19:48:36 +0800
Subject: [PATCH] add a demo (scifi-train)

---
 llama-3/README.md             |   2 +
 scifi-demo/README.md          |  13 ++++
 scifi-demo/data/scifi.txt     |   0
 scifi-demo/finetune-scifi.py  | 101 +++++++++++++++++++++++++
 scifi-demo/inference-scifi.py |  42 +++++++++++
 scifi-demo/model.py           | 138 ++++++++++++++++++++++++++++++++++
 scifi-demo/show-parameters.py |  11 +++
 scifi-demo/tools_download.py  |  35 +++++++++
 scifi-demo/train-scifi.py     | 117 ++++++++++++++++++++++++++++
 9 files changed, 459 insertions(+)
 create mode 100644 scifi-demo/README.md
 create mode 100644 scifi-demo/data/scifi.txt
 create mode 100644 scifi-demo/finetune-scifi.py
 create mode 100644 scifi-demo/inference-scifi.py
 create mode 100644 scifi-demo/model.py
 create mode 100644 scifi-demo/show-parameters.py
 create mode 100644 scifi-demo/tools_download.py
 create mode 100644 scifi-demo/train-scifi.py

diff --git a/llama-3/README.md b/llama-3/README.md
index 90bf5e3..187f13f 100644
--- a/llama-3/README.md
+++ b/llama-3/README.md
@@ -1,5 +1,7 @@
 # Fine-tuning llama-3 on Chinese data
 
+This directory holds the code for fine-tuning llama-3 on Chinese data. It is unrelated to the other directories.
+
 Currently the fastest and most GPU-efficient way to fine-tune llama-3 is the unsloth approach: starting from llama-3, they pre-quantize the weights to 4 bit, which reduces the memory needed during fine-tuning.
 
 The advantage of this approach is that you do not need to retrain the model; just download the pre-trained model and fine-tune it.

diff --git a/scifi-demo/README.md b/scifi-demo/README.md
new file mode 100644
index 0000000..b578ea2
--- /dev/null
+++ b/scifi-demo/README.md
@@ -0,0 +1,13 @@
+## About this repo
+
+This directory contains a demo from my videos on Bilibili and Douyin.
+
+It uses the same hand-written model as the repository root to show how fine-tuning on sci-fi novels works. It is for teaching purposes only; the code is not suitable for production use.
+
+### Preparing the datasets
+
+The datasets need to be preprocessed and are stored in the `/data` subfolder.
+
+ - Training dataset: lines 18~35 of the `tools_download` helper script. First unzip the archive, then merge all the novel txt files into a single txt file.
+
+ - Fine-tuning dataset: lines 6~14 of the `tools_download` helper script. Downloads an instruction fine-tuning dataset from huggingface.

diff --git a/scifi-demo/data/scifi.txt b/scifi-demo/data/scifi.txt
new file mode 100644
index 0000000..e69de29

diff --git a/scifi-demo/finetune-scifi.py b/scifi-demo/finetune-scifi.py
new file mode 100644
index 0000000..c61aacb
--- /dev/null
+++ b/scifi-demo/finetune-scifi.py
@@ -0,0 +1,101 @@
+"""
+Fine-tune a model
+"""
+import json
+import torch
+import tiktoken
+from model import Model
+
+
+# Hyperparameters
+batch_size = 8  # How many batches per training step
+context_length = 128  # Length of the token chunk in each batch
+max_iters = 1000  # Total number of training iterations <- change to a smaller number for testing
+learning_rate = 1e-4  # 0.0001
+eval_interval = 10  # How often to evaluate
+eval_iters = 10  # Number of iterations to average for evaluation
+device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Use the GPU if it is available
+TORCH_SEED = 1337
+torch.manual_seed(TORCH_SEED)
+
+
+# Prepare the fine-tuning data: a slice of alpaca-style records.
+# Note that str(text) below tokenizes the Python repr of the record list.
+with open('data/scifi-finetune.json', 'r', encoding='utf-8') as file:
+    alpaca = json.load(file)
+text = alpaca[1000:5001]
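+
+# A cleaner alternative (a sketch, not used in this demo): render each record
+# into plain prompt text instead of tokenizing the list's Python repr. This
+# assumes the usual alpaca layout with 'instruction' and 'output' keys; the
+# helper name is hypothetical.
+def render_records(records):
+    # Join instruction/output pairs into one training string
+    return '\n\n'.join(
+        f"### Instruction:\n{r['instruction']}\n### Response:\n{r['output']}"
+        for r in records
+    )
+# text = render_records(alpaca[1000:5001]) would then replace the slice above.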
+
+# Tokenize with tiktoken's cl100k_base encoding (the encoding used by
+# GPT-3.5/GPT-4)
+encoding = tiktoken.get_encoding("cl100k_base")
+tokenized_text = encoding.encode(str(text))
+tokenized_text = torch.tensor(tokenized_text, dtype=torch.long, device=device)  # Convert the tokens into a PyTorch tensor
+
+total_tokens = encoding.encode_ordinary(str(text))
+print(f"The dataset has {len(total_tokens):,} tokens in total")
+
+
+# Split train and validation
+train_size = int(len(tokenized_text) * 0.9)
+train_data = tokenized_text[:train_size]
+val_data = tokenized_text[train_size:]
+
+
+# Initialize the model from the pre-trained checkpoint
+model = Model()
+model.load_state_dict(torch.load('model/model-scifi.pt', map_location=device))
+model.to(device)
+
+# Get a random batch of (input, target) token chunks from the given split
+def get_batch(split: str):
+    data = train_data if split == 'train' else val_data
+    idxs = torch.randint(low=0, high=len(data) - context_length, size=(batch_size,))
+    x = torch.stack([data[idx:idx + context_length] for idx in idxs]).to(device)
+    y = torch.stack([data[idx + 1:idx + context_length + 1] for idx in idxs]).to(device)
+    return x, y
+
+
+# Estimate the loss, averaged over eval_iters batches per split
+@torch.no_grad()
+def estimate_loss():
+    out = {}
+    model.eval()
+    for split in ['train', 'valid']:
+        losses = torch.zeros(eval_iters)
+        for k in range(eval_iters):
+            x_batch, y_batch = get_batch(split)
+            logits, loss = model(x_batch, y_batch)
+            losses[k] = loss.item()
+        out[split] = losses.mean()
+    model.train()
+    return out
+
+
+# Create the optimizer and run the training loop
+optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+tracked_losses = list()
+for step in range(max_iters):
+    if step % eval_interval == 0 or step == max_iters - 1:
+        losses = estimate_loss()
+        tracked_losses.append(losses)
+        print('Step:', step, 'Training Loss:', round(losses['train'].item(), 3), 'Validation Loss:', round(losses['valid'].item(), 3))
+
+    xb, yb = get_batch('train')
+    logits, loss = model(xb, yb)
+    optimizer.zero_grad(set_to_none=True)
+    loss.backward()
+    optimizer.step()
+
+# Save the fine-tuned model
+torch.save(model.state_dict(), 'model/model-scifi-finetune.pt')

diff --git a/scifi-demo/inference-scifi.py b/scifi-demo/inference-scifi.py
new file mode 100644
index 0000000..2d63e2a
--- /dev/null
+++ b/scifi-demo/inference-scifi.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+"""
+Sample from a trained model
+"""
+import torch
+import tiktoken
+from model import Model
+
+
+# Hyperparameters
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+TORCH_SEED = 1337
+torch.manual_seed(TORCH_SEED)
+torch.cuda.manual_seed(TORCH_SEED)
+
+
+encoding = tiktoken.get_encoding("cl100k_base")
+
+
+# Initialize from the trained model
+model = Model()
+model.load_state_dict(torch.load('model/model-scifi.pt', map_location=device))
+model.eval()
+model.to(device)
+
+# start = 'Write a short story about Sam Altman.'
+start = 'Sam Altman was born in'
+start_ids = encoding.encode(start)
+x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
+
+# Run generation
+with torch.no_grad():
+    y = model.generate(x, max_new_tokens=500)
+    print('---------------')
+    print(encoding.decode(y[0].tolist()))
+    print('---------------')
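+
+
+# model.generate() samples from the raw softmax distribution. A common
+# variation -- a sketch, not part of the original demo -- adds temperature and
+# top-k filtering. Call it under torch.no_grad(), e.g.
+# y = generate_topk(model, x, temperature=0.8, top_k=50).
+def generate_topk(model, idx, max_new_tokens=100, temperature=0.8, top_k=50):
+    for _ in range(max_new_tokens):
+        idx_crop = idx[:, -128:]  # 128 = context_length in model.py
+        logits, _ = model(idx_crop)
+        logits = logits[:, -1, :] / temperature  # sharpen (<1) or flatten (>1) the distribution
+        v, _ = torch.topk(logits, top_k)
+        logits[logits < v[:, [-1]]] = float('-inf')  # keep only the top_k candidates
+        probs = torch.softmax(logits, dim=-1)
+        idx = torch.cat((idx, torch.multinomial(probs, num_samples=1)), dim=1)
+    return idx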
diff --git a/scifi-demo/model.py b/scifi-demo/model.py
new file mode 100644
index 0000000..4e67863
--- /dev/null
+++ b/scifi-demo/model.py
@@ -0,0 +1,138 @@
+import math
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+# Hyperparameters
+context_length = 128  # Length of the token chunk in each batch
+d_model = 512  # The size of our model's token embeddings
+num_blocks = 12  # Number of transformer blocks
+num_heads = 8  # Number of heads in multi-head attention
+dropout = 0.1  # Dropout rate
+device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Use the GPU if it is available
+TORCH_SEED = 1337
+torch.manual_seed(TORCH_SEED)
+
+# Define the feed-forward network
+class FeedForwardNetwork(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ffn = nn.Sequential(
+            nn.Linear(d_model, d_model * 4),
+            nn.ReLU(),
+            nn.Linear(d_model * 4, d_model),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.ffn(x)
+
+
+# Define scaled dot-product attention (a single head)
+class Attention(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.Wq = nn.Linear(d_model, d_model // num_heads, bias=False)
+        self.Wk = nn.Linear(d_model, d_model // num_heads, bias=False)
+        self.Wv = nn.Linear(d_model, d_model // num_heads, bias=False)
+        # Lower-triangular causal mask: a position may attend only to itself and earlier positions
+        self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        B, T, C = x.shape
+        q = self.Wq(x)
+        k = self.Wk(x)
+        v = self.Wv(x)
+
+        weights = (q @ k.transpose(-2, -1)) / math.sqrt(d_model // num_heads)
+        weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
+        weights = F.softmax(weights, dim=-1)
+        weights = self.dropout(weights)
+
+        output = weights @ v
+
+        return output
+
+
+# Define multi-head attention
+class MultiHeadAttention(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.heads = nn.ModuleList([Attention() for _ in range(num_heads)])
+        self.projection_layer = nn.Linear(d_model, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        head_outputs = [head(x) for head in self.heads]
+        head_outputs = torch.cat(head_outputs, dim=-1)
+        out = self.dropout(self.projection_layer(head_outputs))
+        return out
+
+
+# Define the transformer block (pre-LayerNorm, with residual connections)
+class TransformerBlock(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(d_model)
+        self.ln2 = nn.LayerNorm(d_model)
+        self.mha = MultiHeadAttention()
+        self.ffn = FeedForwardNetwork()
+
+    def forward(self, x):
+        x = x + self.mha(self.ln1(x))
+        x = x + self.ffn(self.ln2(x))
+        return x
+
+
+# Define the model
+class Model(nn.Module):
+    def __init__(self, max_token_value=100256):  # defaults to the tiktoken cl100k_base vocabulary size
+        super().__init__()
+        self.token_embedding_lookup_table = nn.Embedding(max_token_value, d_model)
+        self.transformer_blocks = nn.Sequential(*(
+                [TransformerBlock() for _ in range(num_blocks)] +
+                [nn.LayerNorm(d_model)]
+        ))
+        self.model_out_linear_layer = nn.Linear(d_model, max_token_value)
+
+    def forward(self, idx, targets=None):
+        B, T = idx.shape
+        # Sinusoidal positional encoding, recomputed on every forward pass
+        position_encoding_lookup_table = torch.zeros(context_length, d_model, device=device)
+        position = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
+        position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)
+        # Slice position_encoding_lookup_table from (context_length, d_model) down to (T, d_model)
+        position_embedding = position_encoding_lookup_table[:T, :].to(device)
+        x = self.token_embedding_lookup_table(idx) + position_embedding
+        x = self.transformer_blocks(x)
+        # Get the final logits
+        logits = self.model_out_linear_layer(x)
+
+        if targets is not None:
+            B, T, C = logits.shape
+            logits_reshaped = logits.view(B * T, C)
+            targets_reshaped = targets.view(B * T)
+            loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
+        else:
+            loss = None
+        return logits, loss
+
+    def generate(self, idx, max_new_tokens=100):
+        # idx is a (B, T) array of indices in the current context
+        for _ in range(max_new_tokens):
+            # Crop idx to the max size of our positional embedding table
+            idx_crop = idx[:, -context_length:]
+            # Get predictions
+            logits, loss = self.forward(idx_crop)
+            # Get the last time step from the logits, whose dimensions are (B, T, C)
+            logits_last_timestep = logits[:, -1, :]
+            # Apply softmax to get probabilities
+            probs = F.softmax(input=logits_last_timestep, dim=-1)
+            # Sample from the probability distribution
+            idx_next = torch.multinomial(input=probs, num_samples=1)
+            # Append the sampled index idx_next to idx
+            idx = torch.cat((idx, idx_next), dim=1)
+        return idx
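+
+
+if __name__ == '__main__':
+    # Quick shape check -- a sketch for illustration; the training scripts do
+    # not use this. Run a random batch through the model and print the result.
+    m = Model().to(device)
+    dummy = torch.randint(0, 100256, (2, context_length), device=device)
+    logits, loss = m(dummy, targets=dummy)
+    # Expect logits of shape (2, 128, 100256) and, for an untrained model,
+    # a loss near ln(100256) ≈ 11.5
+    print(logits.shape, loss.item())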
diff --git a/scifi-demo/show-parameters.py b/scifi-demo/show-parameters.py
new file mode 100644
index 0000000..3d3e55e
--- /dev/null
+++ b/scifi-demo/show-parameters.py
@@ -0,0 +1,11 @@
+import torch
+from model import Model
+
+model = Model()
+state_dict = torch.load('model/model-scifi.pt', map_location='cpu')
+model.load_state_dict(state_dict)
+
+# Calculate the number of trainable parameters
+total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+print(f"Total number of model parameters: {total_params:,}")
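+
+# Cross-check: derive the same count from the hyperparameters in model.py
+# (a sketch assuming the defaults: d_model=512, num_blocks=12, num_heads=8,
+# vocab size 100256).
+d, n_blocks, vocab = 512, 12, 100256
+embedding = vocab * d                                # token embedding table
+attention = 3 * d * (d // 8) * 8 + (d * d + d)       # Wq/Wk/Wv over 8 heads + output projection
+ffn = (d * 4 * d + 4 * d) + (4 * d * d + d)          # two linear layers with biases
+block = attention + ffn + 2 * (2 * d)                # plus two LayerNorms per block
+derived = embedding + n_blocks * block + 2 * d + (d * vocab + vocab)  # + final LayerNorm + output head
+print(f"Derived from hyperparameters: {derived:,}")  # 140,573,600 -- should match total_params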
diff --git a/scifi-demo/tools_download.py b/scifi-demo/tools_download.py
new file mode 100644
index 0000000..689c8a4
--- /dev/null
+++ b/scifi-demo/tools_download.py
@@ -0,0 +1,35 @@
+import os
+import sys
+import requests
+import glob
+
+def download_file(url, save_path):
+    response = requests.get(url, stream=True)
+    with open(save_path, 'wb') as file:
+        for chunk in response.iter_content(chunk_size=8192):
+            file.write(chunk)
+
+url = "https://huggingface.co/datasets/zxbsmk/webnovel_cn/resolve/main/novel_cn_token512_50k.json?download=true"
+save_path = "data/scifi-finetune.json"
+download_file(url, save_path)
+
+sys.exit(0)  # Stop here; run the merge step below separately, after unpacking the novel archive into data/
+
+# Novel training dataset download link: https://pan.baidu.com/s/1bC8fH8hyt28L9pV3fjOHIQ  extraction code: 9i9g
+def find_txt_files(directory):
+    return glob.glob(os.path.join(directory, '**', '*.txt'), recursive=True)
+
+def concatenate_txt_files(files, output_file):
+    with open(output_file, 'w', encoding='utf-8') as outfile:
+        for file in files:
+            with open(file, 'r', encoding='utf-8') as infile:
+                outfile.write(infile.read() + '\n')  # Add a newline between files
+
+directory = 'data'
+output_file = 'data/scifi.txt'
+
+# Find all .txt files, excluding the output file itself
+txt_files = [f for f in find_txt_files(directory) if f != output_file]
+
+# Concatenate all found .txt files into one
+concatenate_txt_files(txt_files, output_file)

diff --git a/scifi-demo/train-scifi.py b/scifi-demo/train-scifi.py
new file mode 100644
index 0000000..2be1276
--- /dev/null
+++ b/scifi-demo/train-scifi.py
@@ -0,0 +1,117 @@
+"""
+Train a model
+"""
+import os
+import torch
+import tiktoken
+from aim import Run
+from model import Model
+
+
+# Hyperparameters
+batch_size = 12  # How many batches per training step
+context_length = 128  # Length of the token chunk in each batch
+max_iters = 20000  # Total number of training iterations <- change to a smaller number for testing
+learning_rate = 1e-3  # 0.001
+eval_interval = 50  # How often to evaluate
+eval_iters = 20  # Number of iterations to average for evaluation
+device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Use the GPU if it is available
+TORCH_SEED = 1337
+torch.manual_seed(TORCH_SEED)
+
+
+# AIM logs
+run = Run()
+run["hparams"] = {
+    "learning_rate": learning_rate,
+    "max_iters": max_iters,
+    "batch_size": batch_size,
+    "context_length": context_length
+}
+
+# Prepare the training data
+with open('data/scifi.txt', 'r', encoding="utf-8") as file:
+    text = file.read()
+
+
+# Tokenize with tiktoken's cl100k_base encoding (the encoding used by
+# GPT-3.5/GPT-4)
+encoding = tiktoken.get_encoding("cl100k_base")
+tokenized_text = encoding.encode(text)
+tokenized_text = torch.tensor(tokenized_text, dtype=torch.long, device=device)  # Convert the tokens (77,919 in the original run) into a PyTorch tensor
+
+total_tokens = encoding.encode_ordinary(text)
+print(f"The dataset has {len(total_tokens):,} tokens in total")
+
+
+# Split train and validation
+train_size = int(len(tokenized_text) * 0.9)
+train_data = tokenized_text[:train_size]
+val_data = tokenized_text[train_size:]
+
+
+# Initialize the model
+model = Model().to(device)
+
+# Get a random batch of (input, target) token chunks from the given split
+def get_batch(split: str):
+    data = train_data if split == 'train' else val_data
+    idxs = torch.randint(low=0, high=len(data) - context_length, size=(batch_size,))
+    x = torch.stack([data[idx:idx + context_length] for idx in idxs]).to(device)
+    y = torch.stack([data[idx + 1:idx + context_length + 1] for idx in idxs]).to(device)
+    return x, y
+
+
+# Estimate the loss, averaged over eval_iters batches per split
+@torch.no_grad()
+def estimate_loss():
+    out = {}
+    model.eval()
+    for split in ['train', 'valid']:
+        losses = torch.zeros(eval_iters)
+        for k in range(eval_iters):
+            x_batch, y_batch = get_batch(split)
+            logits, loss = model(x_batch, y_batch)
+            losses[k] = loss.item()
+        out[split] = losses.mean()
+    model.train()
+    return out
+
+
+# Create the optimizer and run the training loop
+optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+tracked_losses = list()
+for step in range(max_iters):
+    if step % eval_interval == 0 or step == max_iters - 1:
+        losses = estimate_loss()
+        tracked_losses.append(losses)
+        print('Step:', step, 'Training Loss:', round(losses['train'].item(), 3), 'Validation Loss:', round(losses['valid'].item(), 3))
+        run.track(round(losses['train'].item(), 3), name='Training Loss')
+        run.track(round(losses['valid'].item(), 3), name='Validation Loss')
+
+    xb, yb = get_batch('train')
+    logits, loss = model(xb, yb)
+    optimizer.zero_grad(set_to_none=True)
+    loss.backward()
+    optimizer.step()
+
+# Save the model
+os.makedirs('model', exist_ok=True)
+torch.save(model.state_dict(), 'model/model-scifi.pt')
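+
+# Optional smoke test -- a sketch, not part of the original script: generate a
+# short sample from the freshly trained model, mirroring inference-scifi.py.
+# The prompt is a hypothetical example ("In the distant future").
+model.eval()
+start_ids = encoding.encode('在遥远的未来')
+x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
+with torch.no_grad():
+    y = model.generate(x, max_new_tokens=100)
+print(encoding.decode(y[0].tolist()))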