Skip to content

Commit

Permalink
Workaround for fparser behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
reuterbal committed Apr 12, 2023
1 parent 0bb0a7b commit 29a5a80
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
13 changes: 11 additions & 2 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2100,8 +2100,17 @@ def visit_If_Construct(self, o, **kwargs):
else_if_stmt_index, else_if_stmts = zip(*else_if_stmts)
else:
else_if_stmt_index = ()
else_stmt = get_child(o, Fortran2003.Else_Stmt)
else_stmt_index = o.children.index(else_stmt) if else_stmt else end_if_stmt_index

# Note: we need to use here the same method as for else-if because finding Else_Stmt
# directly and checking its position via o.children.index may give the wrong result.
# This is because Else_Stmt may erronously compare equal to other node types.
# See https://github.com/stfc/fparser/issues/400
else_stmt = tuple((i, c) for i, c in enumerate(o.children) if isinstance(c, Fortran2003.Else_Stmt))
if else_stmt:
assert len(else_stmt) == 1
else_stmt_index, else_stmt = else_stmt[0]
else:
else_stmt_index = end_if_stmt_index
conditions = as_tuple(self.visit(c, **kwargs) for c in (if_then_stmt,) + else_if_stmts)
bodies = tuple(
tuple(flatten(as_tuple(self.visit(c, **kwargs) for c in o.children[start+1:stop])))
Expand Down
44 changes: 43 additions & 1 deletion tests/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from conftest import jit_compile, clean_test, available_frontends
from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node
from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node, Intrinsic


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -455,3 +455,45 @@ def test_conditional_bodies(frontend):
c.else_body and isinstance(c.else_body, tuple) and all(isinstance(n, Node) for n in c.else_body)
for c in conditionals
)


@pytest.mark.parametrize('frontend', available_frontends())
def test_conditional_else_body_return(frontend):
fcode = """
FUNCTION FUNC(PX,KN)
IMPLICIT NONE
INTEGER,INTENT(INOUT) :: KN
REAL,INTENT(IN) :: PX
REAL :: FUNC
INTEGER :: J
REAL :: Z0, Z1, Z2
Z0= 1.0
Z1= PX
IF (KN == 0) THEN
FUNC= Z0
RETURN
ELSEIF (KN == 1) THEN
FUNC= Z1
RETURN
ELSE
DO J=2,KN
Z2= Z0+Z1
Z0= Z1
Z1= Z2
ENDDO
FUNC= Z2
RETURN
ENDIF
END FUNCTION FUNC
""".strip()

routine = Subroutine.from_source(fcode, frontend=frontend)
conditionals = FindNodes(Conditional).visit(routine.body)
assert len(conditionals) == 2
assert isinstance(conditionals[0].body[-1], Intrinsic)
assert conditionals[0].body[-1].text.upper() == 'RETURN'
assert conditionals[0].else_body == (conditionals[1],)
assert isinstance(conditionals[1].body[-1], Intrinsic)
assert conditionals[1].body[-1].text.upper() == 'RETURN'
assert isinstance(conditionals[1].else_body[-1], Intrinsic)
assert conditionals[1].else_body[-1].text.upper() == 'RETURN'

0 comments on commit 29a5a80

Please sign in to comment.