Skip to content

Commit 0f322b2

Browse files
authored
Fix MonadFail-related errors to support ghc 8.8
1 parent 739f661 commit 0f322b2

File tree

3 files changed

+40
-12
lines changed

3 files changed

+40
-12
lines changed

tensorflow-ops/src/TensorFlow/Variable.hs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,10 @@ initializedVariable' :: forall a m v . (MonadBuild m, TensorType a)
9090
=> OpParams -> Tensor v a -> m (Variable a)
9191
initializedVariable' params initializer = do
9292
-- The shape is not known initially.
93-
(Variable h Nothing :: Variable a) <- variableInternal params Nothing
93+
variables <- variableInternal params Nothing
94+
h <- pure $ case variables of
95+
(Variable h Nothing :: Variable a) -> h
96+
_ -> error "variableInternal is empty"
9497
initializer' <- renderValue initializer
9598
i <- CoreOps.assignVariableOp h initializer'
9699
addInitializer =<< group i

tensorflow-ops/tests/EmbeddingOpsTest.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ instance Arbitrary a => Arbitrary (LookupExample a) where
161161
let maxDim = fromIntegral (ceiling doubleMaxDim :: Int64)
162162
doubleMaxDim :: Double
163163
doubleMaxDim = 100 ** (1 / fromIntegral rank)
164-
shape@(firstDim : _) <- vectorOf rank (choose (1, maxDim))
164+
shape <- vectorOf rank (choose (1, maxDim))
165+
let firstDim = head shape
165166
values <- vectorOf (fromIntegral $ product shape) arbitrary
166167
numParts <- choose (2, 15)
167168
indSize <- choose (0, fromIntegral $ firstDim - 1)

tensorflow-ops/tests/GradientTest.hs

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import Test.Framework (defaultMain, Test)
2626
import Lens.Family2 ((^..), (.~))
2727

2828
import Test.Framework.Providers.HUnit (testCase)
29-
import Test.HUnit ((@=?), assertEqual)
29+
import Test.HUnit ((@=?), assertEqual, assertFailure)
3030
import qualified Data.Vector as V
3131
import System.Random (randomIO, randomRIO)
3232
import Control.Monad(forM_, replicateM, zipWithM)
@@ -557,13 +557,19 @@ matMulGradGrad = testCase "matMulGradGrad" $ do
557557
x <- TF.render $ TF.zeros $ TF.Shape [batch, 1]
558558
w <- TF.zeroInitializedVariable $ TF.Shape [1, width]
559559
let f = x `TF.matMul` TF.readValue w
560-
[dfdx] <- TF.gradients f [x]
560+
l1 <- TF.gradients f [x]
561+
let dfdx = head l1 -- avoid MonadFail
561562
let f'x = TF.reduceSum dfdx
562-
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
563+
l2 <- TF.gradients f'x [w] -- take gradient again (this time over w)
564+
let dfdw = head l2
563565
return [TF.readValue w, TF.expr dfdw]
564566

565567
TF.runSession $ do
566-
[w, dfdw] <- TF.build tower
568+
l <- TF.build tower
569+
(w, dfdw) <-
570+
case l of
571+
[w, dfdw] -> pure (w, dfdw)
572+
_ -> liftIO $ assertFailure "pattern-match failure in matMulGradMad"
567573
(wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
568574
liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
569575

@@ -589,7 +595,11 @@ matMulTransposeGradient txw = testCase ("matMulTransposeGradients " ++ show txw)
589595
return (x, w, ds)
590596

591597
TF.runSession $ do
592-
(x, w, [dx, dw]) <- TF.build dfBuild
598+
(x, w, d) <- TF.build dfBuild
599+
(dx, dw) <-
600+
case d of
601+
[dx, dw] -> pure (dx, dw)
602+
_ -> liftIO $ assertFailure "pattern-match failure in matMulTransposeGradient"
593603
xShape <- TF.run $ TF.shape x
594604
dxShape <- TF.run $ TF.shape dx
595605
liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
@@ -616,7 +626,11 @@ batchMatMulGradient = testCase "batchMatMulGradients" $ do
616626
return (x, dfs)
617627

618628
(xShape, dxShape) <- TF.runSession $ do
619-
(x, [dx]) <- TF.build dfBuild
629+
(x, dl) <- TF.build dfBuild
630+
dx <-
631+
case dl of
632+
[dx] -> pure dx
633+
_ -> liftIO $ assertFailure "pattern-match failure in batchMatMulGradient"
620634
TF.run (TF.shape x, TF.shape dx)
621635

622636
assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
@@ -633,13 +647,19 @@ batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do
633647
x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1]
634648
w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width]
635649
let f = x `TF.batchMatMul` TF.readValue w
636-
[dfdx] <- TF.gradients f [x]
650+
l1 <- TF.gradients f [x]
651+
let dfdx = head l1
637652
let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32])
638-
[dfdw] <- TF.gradients f'x [w] -- take gradient again (this time over w)
653+
l2 <- TF.gradients f'x [w] -- take gradient again (this time over w)
654+
let dfdw = head l2
639655
return [TF.readValue w, TF.expr dfdw]
640656

641657
TF.runSession $ do
642-
[w, dfdw] <- TF.build tower
658+
l <- TF.build tower
659+
(w, dfdw) <-
660+
case l of
661+
[w, dfdw] -> pure (w, dfdw)
662+
_ -> liftIO $ assertFailure "pattern-match failure in batchMatMulGradGrad"
643663
(wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
644664
liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
645665

@@ -665,7 +685,11 @@ batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ sho
665685
return (x, w, ds)
666686

667687
TF.runSession $ do
668-
(x, w, [dx, dw]) <- TF.build dfBuild
688+
(x, w, d) <- TF.build dfBuild
689+
(dx, dw) <-
690+
case d of
691+
[dx, dw] -> pure (dx, dw)
692+
_ -> liftIO $ assertFailure "pattern-match failure in batchMatMulAdjointGradient"
669693
xShape <- TF.run $ TF.shape x
670694
dxShape <- TF.run $ TF.shape dx
671695
liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)

0 commit comments

Comments
 (0)