
Commit afd2239

Update and rename enhanced_nlp_model.py to NLPModel.py
Refactor NLPModel with pre-norm, weight tying, causal masking, and PyTorch 2.2+ support.
- Added ModelConfig with label smoothing, AdamW, and gradient clipping
- Implemented mixed precision and gradient accumulation
- Improved positional encoding and attention logic
- Enhanced logging and device management
- Fully type-annotated and Python 3.11+ compatible
1 parent 198921f commit afd2239

File tree

2 files changed: +186 −255 lines

NLPModel.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import math
import logging
from typing import Optional, List, Tuple, Dict, Any
from dataclasses import dataclass

# --- Configure Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Model Configuration ---
@dataclass
class ModelConfig:
    vocab_size: int
    d_model: int = 512
    num_heads: int = 8
    num_layers: int = 6
    d_ff: int = 2048
    dropout: float = 0.1
    max_seq_length: int = 5000
    learning_rate: float = 1e-4
    warmup_steps: int = 4000
    label_smoothing: float = 0.1
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.98)
    eps: float = 1e-9
    pad_token_id: int = 0

# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(x + self.pe[:, :x.size(1)])

# --- Multi-Head Attention ---
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, _ = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.view(B, T, -1, self.d_k).transpose(1, 2), qkv)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, T, -1)
        return self.out(out)

# --- Feed Forward Network ---
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# --- Transformer Block (Pre-Norm) ---
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = x + self.dropout(self.attn(self.norm1(x), mask))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x

# --- Main NLP Model ---
class NLPModel(nn.Module):  # Renamed from EnhancedNLPModel
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(
            config.vocab_size, config.d_model, padding_idx=config.pad_token_id
        )
        self.pos_enc = PositionalEncoding(config.d_model, config.max_seq_length, config.dropout)
        # ModuleList rather than Sequential: each block is called explicitly with the mask.
        self.blocks = nn.ModuleList(
            [TransformerBlock(config.d_model, config.num_heads, config.d_ff, config.dropout) for _ in range(config.num_layers)]
        )
        self.final_layer = nn.Linear(config.d_model, config.vocab_size)
        self.final_layer.weight = self.embedding.weight  # Weight tying
        self._init_weights()

    def _init_weights(self) -> None:
        # Xavier-init weight matrices only; leave 1-D parameters (biases, LayerNorm
        # gains) at their PyTorch defaults so LayerNorm scales stay at 1.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # Re-zero the padding embedding row, which the Xavier init above overwrites.
        with torch.no_grad():
            self.embedding.weight[self.config.pad_token_id].zero_()

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.embedding(x) * math.sqrt(self.config.d_model)
        x = self.pos_enc(x)
        for block in self.blocks:
            x = block(x, mask)
        return self.final_layer(x)

# --- Enhanced Trainer ---
class EnhancedTrainer:
    def __init__(self, model: NLPModel, config: ModelConfig):
        self.model = model
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
            betas=config.betas,
            eps=config.eps,
        )
        self.scheduler = self._create_scheduler()
        self.criterion = nn.CrossEntropyLoss(
            label_smoothing=config.label_smoothing, ignore_index=config.pad_token_id
        )
        self.scaler = GradScaler()

    def _create_scheduler(self):
        # Noam-style warmup: linear ramp for warmup_steps, then inverse-sqrt decay.
        def lr_lambda(step: int) -> float:
            step = max(1, step)
            return min(step ** -0.5, step * self.config.warmup_steps ** -1.5)
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def train_epoch(self, dataloader: DataLoader, accumulation_steps: int = 4) -> float:
        self.model.train()
        total_loss = 0.0
        for i, batch in enumerate(dataloader):
            # Teacher forcing: predict token t+1 from tokens up to t.
            src, tgt = batch[:, :-1], batch[:, 1:]
            src, tgt = src.to(self.device), tgt.to(self.device)
            # Causal (lower-triangular) mask so positions cannot attend to the future.
            mask = torch.tril(torch.ones(src.size(1), src.size(1))).bool().to(src.device)

            with autocast():
                logits = self.model(src, mask)
                loss = self.criterion(logits.view(-1, self.config.vocab_size), tgt.view(-1)) / accumulation_steps

            self.scaler.scale(loss).backward()
            # Gradient accumulation: step the optimizer every accumulation_steps batches,
            # clipping gradients on the unscaled values.
            if (i + 1) % accumulation_steps == 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                self.optimizer.zero_grad()

            total_loss += loss.item() * accumulation_steps
            if i % 100 == 0:
                logger.info(f"Batch {i}, Loss: {loss.item() * accumulation_steps:.4f}")
        return total_loss / len(dataloader)

# --- Helper Function ---
def create_model(vocab_size: int) -> tuple[NLPModel, EnhancedTrainer]:
    config = ModelConfig(vocab_size=vocab_size)
    model = NLPModel(config)
    trainer = EnhancedTrainer(model, config)
    return model, trainer
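
Usage sketch: a minimal example of driving create_model and EnhancedTrainer.train_epoch from the file above. The vocabulary size, sequence length, and random token data are illustrative assumptions, not part of the commit; a real run would feed padded token-ID batches of shape (batch_size, seq_len) from a tokenized corpus.

# --- Usage Sketch (illustrative assumptions only) ---
import torch
from torch.utils.data import DataLoader

from NLPModel import create_model  # assumes NLPModel.py is on the import path

VOCAB_SIZE = 1000  # illustrative vocabulary size
model, trainer = create_model(VOCAB_SIZE)

# Random token IDs stand in for a tokenized corpus; IDs start at 1 so pad_token_id (0)
# never appears in this toy data. Each row has 65 tokens: train_epoch slices it into
# src = batch[:, :-1] and tgt = batch[:, 1:].
data = torch.randint(1, VOCAB_SIZE, (256, 65))
loader = DataLoader(data, batch_size=32, shuffle=True)  # yields (32, 65) LongTensors

avg_loss = trainer.train_epoch(loader, accumulation_steps=4)
print(f"Average epoch loss: {avg_loss:.4f}")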

0 commit comments
