@@ -26,7 +26,7 @@ import Test.Framework (defaultMain, Test)
26
26
import Lens.Family2 ((^..) , (.~) )
27
27
28
28
import Test.Framework.Providers.HUnit (testCase )
29
- import Test.HUnit ((@=?) , assertEqual )
29
+ import Test.HUnit ((@=?) , assertEqual , assertFailure )
30
30
import qualified Data.Vector as V
31
31
import System.Random (randomIO , randomRIO )
32
32
import Control.Monad (forM_ , replicateM , zipWithM )
@@ -557,13 +557,19 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
557
557
x <- TF. render $ TF. zeros $ TF. Shape [batch, 1 ]
558
558
w <- TF. zeroInitializedVariable $ TF. Shape [1 , width]
559
559
let f = x `TF.matMul` TF. readValue w
560
- [dfdx] <- TF. gradients f [x]
560
+ l1 <- TF. gradients f [x]
561
+ let dfdx = head l1 -- avoid MonadFail
561
562
let f'x = TF. reduceSum dfdx
562
- [dfdw] <- TF. gradients f'x [w] -- take gradient again (this time over w)
563
+ l2 <- TF. gradients f'x [w] -- take gradient again (this time over w)
564
+ let dfdw = head l2
563
565
return [TF. readValue w, TF. expr dfdw]
564
566
565
567
TF. runSession $ do
566
- [w, dfdw] <- TF. build tower
568
+ l <- TF. build tower
569
+ (w, dfdw) <-
570
+ case l of
571
+ [w, dfdw] -> pure (w, dfdw)
572
+ _ -> liftIO $ assertFailure " pattern-match failure in matMulGradMad"
567
573
(wShape, dfdwShape) <- TF. run (TF. shape w, TF. shape dfdw)
568
574
liftIO $ assertEqual " Shape of gradient must match input" wShape (dfdwShape :: V. Vector Int32 )
569
575
@@ -589,7 +595,11 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw)
589
595
return (x, w, ds)
590
596
591
597
TF. runSession $ do
592
- (x, w, [dx, dw]) <- TF. build dfBuild
598
+ (x, w, d) <- TF. build dfBuild
599
+ (dx, dw) <-
600
+ case d of
601
+ [dx, dw] -> pure (dx, dw)
602
+ _ -> liftIO $ assertFailure " pattern-match failure in matMulTransposeGradient"
593
603
xShape <- TF. run $ TF. shape x
594
604
dxShape <- TF. run $ TF. shape dx
595
605
liftIO $ assertEqual " xShape must match dxShape" xShape (dxShape :: V. Vector Int32 )
@@ -616,7 +626,11 @@ batchMatMulGradient = testCase "batchMatMulGradients" $ do
616
626
return (x, dfs)
617
627
618
628
(xShape, dxShape) <- TF. runSession $ do
619
- (x, [dx]) <- TF. build dfBuild
629
+ (x, dl) <- TF. build dfBuild
630
+ dx <-
631
+ case dl of
632
+ [dx] -> pure dx
633
+ _ -> liftIO $ assertFailure " pattern-match failure in batchMatMulGradient"
620
634
TF. run (TF. shape x, TF. shape dx)
621
635
622
636
assertEqual " Shape of gradient must match shape of input" xShape (dxShape :: V. Vector Int32 )
@@ -633,13 +647,19 @@ batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do
633
647
x <- TF. render $ TF. zeros $ TF. Shape [batch, height, 1 ]
634
648
w <- TF. zeroInitializedVariable $ TF. Shape [batch, 1 , width]
635
649
let f = x `TF.batchMatMul` TF. readValue w
636
- [dfdx] <- TF. gradients f [x]
650
+ l1 <- TF. gradients f [x]
651
+ let dfdx = head l1
637
652
let f'x = TF. sum dfdx (TF. vector [1 , 2 :: Int32 ])
638
- [dfdw] <- TF. gradients f'x [w] -- take gradient again (this time over w)
653
+ l2 <- TF. gradients f'x [w] -- take gradient again (this time over w)
654
+ let dfdw = head l2
639
655
return [TF. readValue w, TF. expr dfdw]
640
656
641
657
TF. runSession $ do
642
- [w, dfdw] <- TF. build tower
658
+ l <- TF. build tower
659
+ (w, dfdw) <-
660
+ case l of
661
+ [w, dfdw] -> pure (w, dfdw)
662
+ _ -> liftIO $ assertFailure " pattern-match failure in batchMatMulGradGrad"
643
663
(wShape, dfdwShape) <- TF. run (TF. shape w, TF. shape dfdw)
644
664
liftIO $ assertEqual " Shape of gradient must match input" wShape (dfdwShape :: V. Vector Int32 )
645
665
@@ -665,7 +685,11 @@ batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ sho
665
685
return (x, w, ds)
666
686
667
687
TF. runSession $ do
668
- (x, w, [dx, dw]) <- TF. build dfBuild
688
+ (x, w, d) <- TF. build dfBuild
689
+ (dx, dw) <-
690
+ case d of
691
+ [dx, dw] -> pure (dx, dw)
692
+ _ -> liftIO $ assertFailure " pattern-match failure in batchMatMulAdjointGradient"
669
693
xShape <- TF. run $ TF. shape x
670
694
dxShape <- TF. run $ TF. shape dx
671
695
liftIO $ assertEqual " xShape must match dxShape" xShape (dxShape :: V. Vector Int32 )
0 commit comments