import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import math
import logging
from typing import Optional, List, Tuple, Dict, Any
from dataclasses import dataclass

# --- Configure Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Model Configuration ---
@dataclass
class ModelConfig:
    vocab_size: int
    d_model: int = 512
    num_heads: int = 8
    num_layers: int = 6
    d_ff: int = 2048
    dropout: float = 0.1
    max_seq_length: int = 5000
    learning_rate: float = 1e-4
    warmup_steps: int = 4000
    label_smoothing: float = 0.1
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.98)
    eps: float = 1e-9
    pad_token_id: int = 0
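# Example (illustrative): ModelConfig(vocab_size=32000) uses the defaults above;
# any field can be overridden per experiment, e.g.
#   config = ModelConfig(vocab_size=32000, num_layers=12, dropout=0.2)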
# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

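    # For position p and even index 2i, pe[p, 2i] = sin(p / 10000^(2i / d_model))
    # and pe[p, 2i + 1] = cos(p / 10000^(2i / d_model)). The buffer has shape
    # (1, max_len, d_model), so it broadcasts over the batch dimension in forward().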
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(x + self.pe[:, :x.size(1)])

# --- Multi-Head Attention ---
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, _ = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.view(B, T, -1, self.d_k).transpose(1, 2), qkv)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, T, -1)
        return self.out(out)

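# Shape sketch (illustrative): for x of shape (B, T, d_model), q/k/v are reshaped
# to (B, num_heads, T, d_k) and the attention scores to (B, num_heads, T, T).
# A boolean causal mask of shape (T, T) broadcasts over batch and heads, e.g.:
#   mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
#   out = MultiHeadAttention(d_model=512, num_heads=8)(x, mask)  # (B, T, 512)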
# --- Feed Forward Network ---
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# --- Transformer Block (Pre-Norm) ---
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = x + self.dropout(self.attn(self.norm1(x), mask))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x

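# Design note: this is the pre-norm variant (LayerNorm applied before each
# sub-layer, with the residual added to the un-normalized input), which is
# generally more stable to train at depth than the original post-norm layout.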
# --- Main NLP Model ---
class NLPModel(nn.Module):  # Renamed from EnhancedNLPModel
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(
            config.vocab_size, config.d_model, padding_idx=config.pad_token_id
        )
        self.pos_enc = PositionalEncoding(config.d_model, config.max_seq_length, config.dropout)
        # ModuleList rather than Sequential: the blocks are iterated manually in
        # forward() so the attention mask can be passed to each block.
        self.blocks = nn.ModuleList(
            [TransformerBlock(config.d_model, config.num_heads, config.d_ff, config.dropout) for _ in range(config.num_layers)]
        )
        self.final_layer = nn.Linear(config.d_model, config.vocab_size)
        self.final_layer.weight = self.embedding.weight  # Weight tying
        self._init_weights()

    def _init_weights(self):
        for name, p in self.named_parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            elif "bias" in name:
                nn.init.zeros_(p)
            # 1-D non-bias parameters (LayerNorm weights) keep their default
            # initialization of ones; zeroing them would silence every layer.
        # Restore the zero vector for the padding embedding, which the Xavier
        # initialization above overwrote.
        with torch.no_grad():
            self.embedding.weight[self.config.pad_token_id].zero_()

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.embedding(x) * math.sqrt(self.config.d_model)
        x = self.pos_enc(x)
        for block in self.blocks:
            x = block(x, mask)
        return self.final_layer(x)

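# Forward-pass sketch (illustrative sizes): with config = ModelConfig(vocab_size=10000),
#   model = NLPModel(config)
#   tokens = torch.randint(0, 10000, (2, 16))   # (batch, seq_len)
#   logits = model(tokens)                      # (2, 16, 10000)
# Passing no mask gives unmasked (bidirectional) attention; the trainer below
# supplies a causal mask for language modeling.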
# --- Enhanced Trainer ---
class EnhancedTrainer:
    def __init__(self, model: NLPModel, config: ModelConfig):
        self.model = model
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
            betas=config.betas,
            eps=config.eps,
        )
        self.scheduler = self._create_scheduler()
        self.criterion = nn.CrossEntropyLoss(
            label_smoothing=config.label_smoothing, ignore_index=config.pad_token_id
        )
        # Gradient scaling is only meaningful with CUDA; disable it on CPU.
        self.scaler = GradScaler(enabled=torch.cuda.is_available())
    def _create_scheduler(self):
        # Noam-style schedule: linear warmup to warmup_steps, then inverse
        # square-root decay, rescaled so the factor peaks at 1.0 and
        # config.learning_rate is the peak learning rate.
        def lr_lambda(step):
            step = max(1, step)
            scale = self.config.warmup_steps ** 0.5
            return scale * min(step ** -0.5, step * self.config.warmup_steps ** -1.5)
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def train_epoch(self, dataloader: DataLoader, accumulation_steps: int = 4):
        self.model.train()
        total_loss = 0.0
        for i, batch in enumerate(dataloader):
            src, tgt = batch[:, :-1], batch[:, 1:]
            src, tgt = src.to(self.device), tgt.to(self.device)
            # Causal mask: position t may only attend to positions <= t.
            mask = torch.tril(torch.ones(src.size(1), src.size(1), device=self.device)).bool()

            with autocast(enabled=torch.cuda.is_available()):
                logits = self.model(src, mask)
                loss = self.criterion(logits.view(-1, self.config.vocab_size), tgt.view(-1)) / accumulation_steps

            self.scaler.scale(loss).backward()
            # Step on every accumulation boundary and on the final batch, so no
            # gradients are left unapplied at the end of the epoch.
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(dataloader):
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                self.optimizer.zero_grad()

            total_loss += loss.item() * accumulation_steps
            if i % 100 == 0:
                logger.info(f"Batch {i}, Loss: {loss.item() * accumulation_steps:.4f}")
        return total_loss / len(dataloader)

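# Note on accumulation: with accumulation_steps=4 and a DataLoader batch size of B,
# each optimizer step sees an effective batch of 4 * B. The loss is divided by
# accumulation_steps so the accumulated gradient matches a single large batch.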
# --- Helper Function ---
def create_model(vocab_size: int) -> Tuple[NLPModel, EnhancedTrainer]:
    config = ModelConfig(vocab_size=vocab_size)
    model = NLPModel(config)
    trainer = EnhancedTrainer(model, config)
    return model, trainer
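
# Usage sketch (illustrative; DummyDataset and all sizes are made-up placeholders,
# not part of the module above). Exercises one training epoch on random token ids.
if __name__ == "__main__":
    class DummyDataset(Dataset):
        def __init__(self, num_samples: int = 256, seq_len: int = 32, vocab_size: int = 1000):
            # Token ids in [1, vocab_size); 0 is reserved for padding.
            self.data = torch.randint(1, vocab_size, (num_samples, seq_len))

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, idx: int) -> torch.Tensor:
            return self.data[idx]

    vocab_size = 1000
    model, trainer = create_model(vocab_size)
    loader = DataLoader(DummyDataset(vocab_size=vocab_size), batch_size=8, shuffle=True)
    avg_loss = trainer.train_epoch(loader)
    logger.info(f"Average epoch loss: {avg_loss:.4f}")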