Skip to content

Commit 6df608d

Browse files
Add GPT2
1 parent a1bd3f2 commit 6df608d

File tree

1 file changed

+244
-0
lines changed

1 file changed

+244
-0
lines changed

src/Torch/Compose/NN.hs

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ module Torch.Compose.NN where
2121

2222
import Torch
2323
import Torch.Compose
24+
import qualified Torch.Functional.Internal as T
2425
import System.IO.Unsafe (unsafePerformIO)
2526
import GHC.Generics hiding ((:+:))
2627

@@ -289,3 +290,246 @@ instanceForwardAssocs
289290
]
290291
[t| Tensor |] [t| Tensor |]
291292

293+
-------------------------------------------------------------------------------
294+
-- 1. LayerNorm
295+
-------------------------------------------------------------------------------
296+
297+
-- | Configuration for a 'LayerNorm' layer.
data LayerNormSpec = LayerNormSpec
  { lnDim :: Int
    -- ^ Length of the learnable scale and bias vectors (e.g. the embedding
    -- width).
  , lnEps :: Float
    -- ^ Small constant added to the variance before taking the square root.
  }
  deriving (Eq, Show)
302+
303+
-- | A layer-normalisation layer: the spec it was built from plus its two
-- learnable affine parameters.
data LayerNorm = LayerNorm
  { spec :: LayerNormSpec
    -- ^ Configuration the layer was sampled from (supplies the epsilon).
  , gamma :: Parameter
    -- ^ Multiplicative scale applied after normalisation.
  , beta :: Parameter
    -- ^ Additive bias applied after scaling.
  }
  deriving (Show)
308+
309+
-- | Build a 'LayerNorm' whose scale starts at all ones and whose bias starts
-- at all zeros, both of length 'lnDim'.
instance Randomizable LayerNormSpec LayerNorm where
  sample layerSpec =
    LayerNorm layerSpec
      <$> makeIndependent (ones' [width])   -- gamma
      <*> makeIndependent (zeros' [width])  -- beta
    where
      width = lnDim layerSpec
320+
321+
--------------------------------------------------------------------------------
322+
-- LayerNorm (fixed mean/var)
323+
--------------------------------------------------------------------------------
324+
325+
-- | Layer normalisation over the last axis followed by the learned affine map.
--
-- Computes @(x - mean) / sqrt (var + eps) * gamma + beta@. The mean and
-- variance are reduced along dimension @-1@ with the reduced axis kept, so
-- they broadcast against the input.
instance HasForward LayerNorm Tensor Tensor where
  forward LayerNorm{..} input = normalised * toDependent gamma + toDependent beta
    where
      mean' = meanDim (Dim (-1)) KeepDim Float input
      -- 'T.varDim' arguments: self, dim = -1, unbiased = False, keepdim = True.
      -- Layer norm is defined with the biased (population) variance; the
      -- Bessel-corrected estimate would divide by (n - 1) and shrink every
      -- output by a factor of sqrt ((n - 1) / n).
      var' = T.varDim input (-1) False True
      -- Plain selector application, so this does not depend on
      -- OverloadedRecordDot being enabled.
      epsilon = asTensor (lnEps spec)
      normalised = (input - mean') / Torch.sqrt (var' + epsilon)

  -- The layer is deterministic, so the stochastic pass is the pure one.
  forwardStoch ln = pure . forward ln
337+
338+
-------------------------------------------------------------------------------
339+
-- 2. Simple Feed-Forward Network
340+
-------------------------------------------------------------------------------
341+
342+
-- | Configuration for the two-layer 'FeedForward' network.
data FeedForwardSpec = FeedForwardSpec
  { ffInDim :: Int
    -- ^ Width of the network's input and of its output.
  , ffHidden :: Int
    -- ^ Width of the hidden layer between the two linear maps.
  }
  deriving (Eq, Show)
347+
348+
-- | Position-wise feed-forward network made of two linear layers.
data FeedForward = FeedForward
  { l1 :: Linear
    -- ^ First linear layer (input width to hidden width).
  , l2 :: Linear
    -- ^ Second linear layer (hidden width back to input width).
  }
  deriving (Show)
353+
354+
-- | Sample both linear layers: @ffInDim -> ffHidden@ and then
-- @ffHidden -> ffInDim@.
instance Randomizable FeedForwardSpec FeedForward where
  sample ffSpec =
    FeedForward
      <$> sample (LinearSpec inputWidth hiddenWidth)
      <*> sample (LinearSpec hiddenWidth inputWidth)
    where
      inputWidth = ffInDim ffSpec
      hiddenWidth = ffHidden ffSpec
359+
360+
-- | Apply the first linear layer, a ReLU, then the second linear layer.
instance HasForward FeedForward Tensor Tensor where
  forward (FeedForward firstLayer secondLayer) =
    linear secondLayer . relu . linear firstLayer

  -- No randomness is involved, so this just lifts the pure pass.
  forwardStoch ff = pure . forward ff
367+
368+
-------------------------------------------------------------------------------
369+
-- 3. Causal Masking Utility
370+
-------------------------------------------------------------------------------
371+
372+
-- | Build a causal (lower-triangular) attention mask.
--
-- The result has shape @[seqLen, seqLen]@. Entry @(i, j)@ is @1.0@ when
-- @j <= i@ (position @i@ may look at position @j@) and @0.0@ when @j > i@
-- (a "future" position).
createCausalMask :: Int -> Tensor
createCausalMask seqLen =
  T.where' keepBool (onesLike template) (zerosLike template)
  where
    range = arange' 0 (fromIntegral seqLen) 1  -- [seqLen]
    rowIdx = unsqueeze (Dim (-1)) range        -- [seqLen, 1]
    colIdx = unsqueeze (Dim 0) range           -- [1, seqLen]
    -- True on and below the diagonal.
    keepBool = rowIdx `ge` colIdx
    -- A [seqLen, seqLen] tensor with the numeric dtype of 'range'. Taking
    -- 'onesLike' / 'zerosLike' of this, rather than of the Bool comparison
    -- result, makes the mask hold numeric 1.0 / 0.0 as documented instead
    -- of Bool values.
    template = rowIdx - colIdx
383+
384+
-------------------------------------------------------------------------------
385+
-- 4. GPT-2 Decoder Block
386+
-------------------------------------------------------------------------------
387+
388+
-- | Configuration for a single 'GPT2Block'.
data GPT2BlockSpec = GPT2BlockSpec
  { blockEmbedDim :: Int
    -- ^ Embedding width used by the layer norms, attention and feed-forward.
  , blockNumHeads :: Int
    -- ^ Number of attention heads.
  , blockFfHidden :: Int
    -- ^ Hidden width of the feed-forward sub-layer.
  , blockLnEps :: Float
    -- ^ Epsilon handed to both layer norms.
  }
  deriving (Eq, Show)
395+
396+
-- | One decoder block: two layer norms, a multi-head attention layer and a
-- feed-forward network.
data GPT2Block = GPT2Block
  { ln1 :: LayerNorm
    -- ^ Layer norm applied before attention.
  , attn :: MultiHeadAttention
    -- ^ Self-attention sub-layer.
  , ln2 :: LayerNorm
    -- ^ Layer norm applied before the feed-forward network.
  , ff :: FeedForward
    -- ^ Feed-forward sub-layer.
  }
  deriving (Show)
403+
404+
-- | Sample every sub-layer of a block, in field order. Both layer norms are
-- drawn from the same 'LayerNormSpec'.
instance Randomizable GPT2BlockSpec GPT2Block where
  sample blockSpec = do
    preAttnNorm <- sample normSpec
    selfAttn <- sample attnSpec
    preFfNorm <- sample normSpec
    feedFwd <- sample ffnSpec
    pure (GPT2Block preAttnNorm selfAttn preFfNorm feedFwd)
    where
      width = blockEmbedDim blockSpec
      normSpec = LayerNormSpec width (blockLnEps blockSpec)
      ffnSpec = FeedForwardSpec width (blockFfHidden blockSpec)
      attnSpec = MultiHeadAttentionSpec width (blockNumHeads blockSpec)
414+
415+
-- | Forward pass of one decoder block, in pre-norm residual form:
--
--   1. @x1 = x + attention (ln1 x)@
--   2. @x2 = x1 + feedForward (ln2 x1)@
--
-- The input is a pair @(x, mask)@ and the result is the new hidden state.
--
-- NOTE: the mask is accepted but not used. 'multiHeadAttention' is called
-- without a mask argument, so attention in this block is NOT causally
-- masked. Applying the mask requires an attention function that takes one.
instance HasForward GPT2Block (Tensor, Tensor) Tensor where
  forward GPT2Block{..} (x, _mask) = afterAttn + ffOut
    where
      normedIn = forward ln1 x
      -- Self-attention: the normalised input is passed for all three
      -- arguments.
      afterAttn = x + multiHeadAttention attn normedIn normedIn normedIn
      ffOut = forward ff (forward ln2 afterAttn)

  -- Nothing here is stochastic; lift the pure pass.
  forwardStoch block input@(_, _) = pure (forward block input)
439+
440+
-------------------------------------------------------------------------------
441+
-- 5. The Full GPT2 Model
442+
-------------------------------------------------------------------------------
443+
444+
-- | Configuration for the full 'GPT2' model.
data GPT2Spec = GPT2Spec
  { vocabSize :: Int
    -- ^ Number of rows in the token-embedding table.
  , maxPos :: Int
    -- ^ Number of rows in the position-embedding table.
  , numLayers :: Int
    -- ^ Number of stacked 'GPT2Block's.
  , embedDim :: Int
    -- ^ Embedding width shared by every layer.
  , numHeads :: Int
    -- ^ Attention heads per block.
  , ffHiddenDim :: Int
    -- ^ Hidden width of each block's feed-forward network.
  , lnEpsVal :: Float
    -- ^ Epsilon used by every layer norm.
  }
  deriving (Eq, Show)
454+
455+
-- | The full model: two embedding tables, a stack of decoder blocks and a
-- final layer norm.
data GPT2 = GPT2
  { tokenEmbed :: Parameter
    -- ^ Token-embedding table, @[vocabSize, embedDim]@.
  , positionEmbed :: Parameter
    -- ^ Position-embedding table, @[maxPos, embedDim]@.
  , blocks :: [GPT2Block]
    -- ^ Decoder blocks, applied in list order.
  , lnFinal :: LayerNorm
    -- ^ Layer norm applied after the last block.
  }
  deriving (Show)
462+
463+
-- | Sample a full model: both embedding tables are drawn with 'randnIO'',
-- then 'numLayers' blocks are sampled from one shared block spec, then the
-- final layer norm.
instance Randomizable GPT2Spec GPT2 where
  sample GPT2Spec{..} = do
    tokenParam <- randnIO' [vocabSize, embedDim] >>= makeIndependent
    posParam <- randnIO' [maxPos, embedDim] >>= makeIndependent
    -- One independent sample per layer; the list elements are only a counter.
    gpt2Blocks <- traverse (\_ -> sample blockSpec) [1 .. numLayers]
    finalNorm <- sample (LayerNormSpec embedDim lnEpsVal)
    pure (GPT2 tokenParam posParam gpt2Blocks finalNorm)
    where
      blockSpec = GPT2BlockSpec embedDim numHeads ffHiddenDim lnEpsVal
481+
482+
-- | Run the full model on a batch of token ids.
--
-- Input: token ids of shape @[batchSize, seqLen]@.
-- Output: logits of shape @[batchSize, seqLen, vocabSize]@, obtained by
-- multiplying the final hidden states with the transposed token-embedding
-- table (weight tying).
--
-- Calls 'error' when the input is not rank 2.
instance HasForward GPT2 Tensor Tensor where
  forward GPT2{..} inputIds = logits
    where
      seqLen = case shape inputIds of
        [_, s] -> s
        _ -> error "GPT2 forward: expected [batchSize, seqLen]"

      -- 1) Token embeddings: [batchSize, seqLen, embedDim].
      xToken = embedding' (toDependent tokenEmbed) inputIds

      -- 2) Position embeddings: [seqLen, embedDim]. Embedding lookup needs
      --    integer indices, while 'arange'' builds its tensor with the
      --    default (floating) options, so cast the positions to Int64.
      positions = toType Int64 (arange' 0 (fromIntegral seqLen) 1)
      posEmbs = embedding' (toDependent positionEmbed) positions

      -- A leading singleton axis lets '+' broadcast the position embeddings
      -- over the batch, so no explicit 'expand' (or list indexing into the
      -- shape) is needed.
      x = xToken + unsqueeze (Dim 0) posEmbs

      -- 3) Causal mask, [1, seqLen, seqLen], handed to every block together
      --    with the hidden state.
      mask = unsqueeze (Dim 0) (createCausalMask seqLen)

      -- 4) Thread the hidden state through the blocks in order.
      xOut = foldl (\acc block -> forward block (acc, mask)) x blocks

      -- 5) Final layer norm.
      xNorm = forward lnFinal xOut

      -- 6) Project to the vocabulary with the tied embedding weights:
      --    [.., embedDim] x [embedDim, vocabSize].
      logits = xNorm `matmul` transpose2D (toDependent tokenEmbed)

  -- The model is deterministic; lift the pure pass.
  forwardStoch net = pure . forward net
517+
518+
-------------------------------------------------------------------------------
519+
-- 6. Add HasForwardAssoc (Optional)
520+
-------------------------------------------------------------------------------
521+
522+
-- If you are using `instanceForwardAssocs` to auto-generate associated type families,
523+
-- you can include GPT2, GPT2Block, and so on. For example:
524+
{-
525+
instanceForwardAssocs
526+
  [ [t| GPT2Block |]
  ]
529+
[t| (Tensor, Tensor) |] -- For GPT2Block we used (x,mask) as input
530+
[t| Tensor |]
531+
532+
instanceForwardAssocs
533+
[ [t| GPT2 |] ]
534+
[t| Tensor |] [t| Tensor |]
535+
-}

0 commit comments

Comments
 (0)