@@ -21,6 +21,7 @@ module Torch.Compose.NN where
 
 import Torch
 import Torch.Compose
+import qualified Torch.Functional.Internal as T
 import System.IO.Unsafe (unsafePerformIO)
 import GHC.Generics hiding ((:+:))
 
@@ -289,3 +290,246 @@ instanceForwardAssocs
   ]
   [t| Tensor |] [t| Tensor |]
 
+-------------------------------------------------------------------------------
+-- 1. LayerNorm
+-------------------------------------------------------------------------------
+
+data LayerNormSpec = LayerNormSpec
+  { lnDim :: Int   -- ^ normalized dimension (e.g. embedDim)
+  , lnEps :: Float -- ^ small epsilon for numerical stability
+  }
+  deriving (Show, Eq)
+
+data LayerNorm = LayerNorm
+  { spec  :: LayerNormSpec
+  , gamma :: Parameter -- ^ scale
+  , beta  :: Parameter -- ^ bias
+  } deriving (Show)
+
+instance Randomizable LayerNormSpec LayerNorm where
+  sample s@LayerNormSpec{..} = do
+    let wInit = ones'  [lnDim]
+        bInit = zeros' [lnDim]
+    gammaParam <- makeIndependent wInit
+    betaParam  <- makeIndependent bInit
+    pure LayerNorm
+      { spec  = s
+      , gamma = gammaParam
+      , beta  = betaParam
+      }
+
+--------------------------------------------------------------------------------
+-- LayerNorm forward (fixed mean/var)
+--------------------------------------------------------------------------------
+
+instance HasForward LayerNorm Tensor Tensor where
+  forward LayerNorm{..} input =
+    let
+      -- Mean and variance over the last dimension, with keepDim = True so
+      -- they broadcast against the input; meanDim comes from Torch,
+      -- varDim from Torch.Functional.Internal.
+      mean' = meanDim (Dim (-1)) KeepDim Float input
+      var'  = T.varDim input (-1) True True -- dim = -1, unbiased = True, keepDim = True
+      xNorm = (input - mean') / Torch.sqrt (var' + asTensor spec.lnEps)
+      out   = xNorm * toDependent gamma + toDependent beta
+    in out
+
+  forwardStoch ln = pure . forward ln
+
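+-- A minimal usage sketch (GHCi), assuming only the definitions above: the
+-- shape is preserved and just the last dimension is normalized.
+--
+-- >>> ln <- sample (LayerNormSpec 8 1e-5)
+-- >>> x  <- randnIO' [2, 4, 8]
+-- >>> shape (forward ln x)
+-- [2,4,8]
+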
+-------------------------------------------------------------------------------
+-- 2. Simple Feed-Forward Network
+-------------------------------------------------------------------------------
+
+data FeedForwardSpec = FeedForwardSpec
+  { ffInDim  :: Int
+  , ffHidden :: Int
+  }
+  deriving (Show, Eq)
+
+data FeedForward = FeedForward
+  { l1 :: Linear
+  , l2 :: Linear
+  }
+  deriving (Show)
+
+instance Randomizable FeedForwardSpec FeedForward where
+  sample FeedForwardSpec{..} = do
+    fc1 <- sample $ LinearSpec ffInDim ffHidden
+    fc2 <- sample $ LinearSpec ffHidden ffInDim
+    pure FeedForward { l1 = fc1, l2 = fc2 }
+
+instance HasForward FeedForward Tensor Tensor where
+  forward FeedForward{..} input =
+    let x1 = relu (linear l1 input)
+        x2 = linear l2 x1
+    in x2
+
+  forwardStoch ff = pure . forward ff
+
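+-- A shape sketch with illustrative sizes: the network widens to ffHidden
+-- internally but maps [.., ffInDim] back to [.., ffInDim].
+--
+-- >>> ffn <- sample (FeedForwardSpec 8 32)
+-- >>> x   <- randnIO' [2, 4, 8]
+-- >>> shape (forward ffn x)
+-- [2,4,8]
+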
+-------------------------------------------------------------------------------
+-- 3. Causal Masking Utility
+-------------------------------------------------------------------------------
+
+-- | Create a causal mask that blocks attention to future positions (j > i).
+-- Shape: [seqLen, seqLen], with 1.0 = keep, 0.0 = block.
+createCausalMask :: Int -> Tensor
+createCausalMask seqLen =
+  let range  = arange' 0 (fromIntegral seqLen) 1 -- [seqLen]
+      rowIdx = unsqueeze (Dim (-1)) range        -- [seqLen, 1]
+      colIdx = unsqueeze (Dim 0) range           -- [1, seqLen]
+      -- rowIdx < colIdx means "future", so keep only rowIdx >= colIdx;
+      -- convert the Bool comparison to a Float 1.0/0.0 mask
+      keepBool = rowIdx `ge` colIdx
+  in toType Float keepBool
+
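+-- Worked example: for seqLen = 3 the mask keeps the diagonal and everything
+-- below it, blocking the strict upper triangle.
+--
+-- >>> asValue (createCausalMask 3) :: [[Float]]
+-- [[1.0,0.0,0.0],[1.0,1.0,0.0],[1.0,1.0,1.0]]
+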
+-------------------------------------------------------------------------------
+-- 4. GPT-2 Decoder Block
+-------------------------------------------------------------------------------
+
+data GPT2BlockSpec = GPT2BlockSpec
+  { blockEmbedDim :: Int
+  , blockNumHeads :: Int
+  , blockFfHidden :: Int
+  , blockLnEps    :: Float
+  }
+  deriving (Show, Eq)
+
+data GPT2Block = GPT2Block
+  { ln1  :: LayerNorm
+  , attn :: MultiHeadAttention
+  , ln2  :: LayerNorm
+  , ff   :: FeedForward
+  }
+  deriving (Show)
+
+instance Randomizable GPT2BlockSpec GPT2Block where
+  sample GPT2BlockSpec{..} = do
+    let lnSpec  = LayerNormSpec blockEmbedDim blockLnEps
+        ffSpec  = FeedForwardSpec blockEmbedDim blockFfHidden
+        mhaSpec = MultiHeadAttentionSpec blockEmbedDim blockNumHeads
+    GPT2Block
+      <$> sample lnSpec
+      <*> sample mhaSpec
+      <*> sample lnSpec
+      <*> sample ffSpec
+
+-- | GPT2Block forward:
+--   1) LN + masked self-attention
+--   2) Residual
+--   3) LN + feed-forward
+--   4) Residual
+instance HasForward GPT2Block (Tensor, Tensor) Tensor where
+  -- Takes `(x, mask)` as input and returns the new hidden states. The mask
+  -- has shape [1, seqLen, seqLen] or broadcasts to [batchSize, seqLen, seqLen].
+  forward GPT2Block{..} (x, _mask) =
+    let xNorm = forward ln1 x
+        -- 'multiHeadAttention' does not accept a mask yet, so the mask is
+        -- currently unused here. To enforce causality, extend the attention
+        -- to add the mask to its scores before the softmax (see the sketch
+        -- below this instance).
+        attnOut = multiHeadAttention attn xNorm xNorm xNorm
+        x1      = x + attnOut -- residual
+        x1Norm  = forward ln2 x1
+        ffOut   = forward ff x1Norm
+        x2      = x1 + ffOut  -- residual
+    in x2
+
+  forwardStoch block = pure . forward block
+
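+-- A sketch of how the mask could be wired into the attention, assuming the
+-- raw scores have shape [batchSize, numHeads, seqLen, seqLen] and the mask
+-- broadcasts to them (1.0 = keep, 0.0 = block). Adding a large negative
+-- number to blocked positions makes them vanish after the softmax. This
+-- helper is illustrative only and is not called by the code above.
+applyAttnMask :: Tensor -> Tensor -> Tensor
+applyAttnMask mask scores =
+  scores + (onesLike mask - mask) * asTensor (-1e9 :: Float)
+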
+-------------------------------------------------------------------------------
+-- 5. The Full GPT2 Model
+-------------------------------------------------------------------------------
+
+data GPT2Spec = GPT2Spec
+  { vocabSize   :: Int
+  , maxPos      :: Int
+  , numLayers   :: Int
+  , embedDim    :: Int
+  , numHeads    :: Int
+  , ffHiddenDim :: Int
+  , lnEpsVal    :: Float
+  }
+  deriving (Show, Eq)
+
+data GPT2 = GPT2
+  { tokenEmbed    :: Parameter -- ^ [vocabSize, embedDim]
+  , positionEmbed :: Parameter -- ^ [maxPos, embedDim]
+  , blocks        :: [GPT2Block]
+  , lnFinal       :: LayerNorm
+  }
+  deriving (Show)
+
+instance Randomizable GPT2Spec GPT2 where
+  sample GPT2Spec{..} = do
+    tokenParam <- makeIndependent =<< randnIO' [vocabSize, embedDim]
+    posParam   <- makeIndependent =<< randnIO' [maxPos, embedDim]
+    let blockSpec = GPT2BlockSpec
+          { blockEmbedDim = embedDim
+          , blockNumHeads = numHeads
+          , blockFfHidden = ffHiddenDim
+          , blockLnEps    = lnEpsVal
+          }
+    gpt2Blocks <- mapM (const $ sample blockSpec) [1 .. numLayers]
+    finalNorm  <- sample $ LayerNormSpec embedDim lnEpsVal
+    pure GPT2
+      { tokenEmbed    = tokenParam
+      , positionEmbed = posParam
+      , blocks        = gpt2Blocks
+      , lnFinal       = finalNorm
+      }
+
+-- | HasForward for GPT2 takes just the input token IDs, shape
+-- [batchSize, seqLen], and returns logits of shape [batchSize, seqLen, vocabSize].
+instance HasForward GPT2 Tensor Tensor where
+  forward GPT2{..} inputIds =
+    let (batchSize, seqLen) = case shape inputIds of
+          [b, s] -> (b, s)
+          _      -> error "GPT2 forward: expected [batchSize, seqLen]"
+        -- 1) Token embeddings: [batchSize, seqLen, embedDim]
+        xToken = embedding' (toDependent tokenEmbed) inputIds
+        -- 2) Position embeddings, broadcast over the batch
+        positions = arange' 0 (fromIntegral seqLen) 1                 -- [seqLen]
+        posEmbs   = embedding' (toDependent positionEmbed) positions  -- [seqLen, embedDim]
+        posEmbs3d = unsqueeze (Dim 0) posEmbs                         -- [1, seqLen, embedDim]
+        posEmbsB  = expand posEmbs3d False [batchSize, seqLen, shape posEmbs3d !! 2]
+
+        x = xToken + posEmbsB
+        -- 3) Build a causal mask, shape [1, seqLen, seqLen]; the blocks
+        -- currently ignore it (see GPT2Block above)
+        mask = unsqueeze (Dim 0) (createCausalMask seqLen)
+
+        -- 4) Pass through each GPT2Block
+        xOut = foldl (\acc block -> forward block (acc, mask)) x blocks
+        -- 5) Final layer norm
+        xNorm = forward lnFinal xOut
+        -- 6) Project to the vocabulary; weight tying reuses the token
+        -- embedding, so we multiply by its transpose
+        tokenWeightT = transpose2D (toDependent tokenEmbed) -- [embedDim, vocabSize]
+        logits = xNorm `matmul` tokenWeightT                -- [batchSize, seqLen, vocabSize]
+    in logits
+
+  forwardStoch net = pure . forward net
+
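+-- An end-to-end usage sketch with small, illustrative hyperparameters:
+-- vocab 100, context 64, 2 layers, embedDim 32, 4 heads, ff hidden 128.
+-- asTensor on [[Int]] yields the Int64 indices that embedding' expects.
+--
+-- >>> gpt2 <- sample (GPT2Spec 100 64 2 32 4 128 1e-5)
+-- >>> let ids = asTensor ([[0, 1, 2, 3], [4, 5, 6, 7]] :: [[Int]]) -- [2, 4]
+-- >>> shape (forward gpt2 ids)
+-- [2,4,100]
+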
+-------------------------------------------------------------------------------
+-- 6. Add HasForwardAssoc (Optional)
+-------------------------------------------------------------------------------
+
+-- If you are using `instanceForwardAssocs` to auto-generate associated type
+-- families, you can include GPT2Block and GPT2 with their respective input
+-- types. For example:
+{-
+instanceForwardAssocs
+  [ [t| GPT2Block |] ]
+  [t| (Tensor, Tensor) |] -- GPT2Block takes (x, mask) as input
+  [t| Tensor |]
+
+instanceForwardAssocs
+  [ [t| GPT2 |] ]
+  [t| Tensor |] [t| Tensor |]
+-}