Skip to content

Commit

Permalink
Make Fortran-to-CL with more recent pymbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed May 17, 2015
1 parent dcad574 commit b64346e
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions contrib/fortran-to-opencl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import re
from pymbolic.parser import Parser as ExpressionParserBase
from pymbolic.mapper import CombineMapper
import pymbolic.primitives
import pymbolic.primitives as p
from pymbolic.mapper.c_code import CCodeMapper as CCodeMapperBase

from warnings import warn
Expand Down Expand Up @@ -99,7 +99,7 @@ def get_decl_pair(self):
_or = intern("or")


class TypedLiteral(pymbolic.primitives.Leaf):
class TypedLiteral(p.Leaf):
def __init__(self, value, dtype):
self.value = value
self.dtype = np.dtype(dtype)
Expand All @@ -110,6 +110,18 @@ def __getinitargs__(self):
mapper_method = intern("map_literal")


def simplify_typed_literal(expr):
if (isinstance(expr, p.Product)
and len(expr.children) == 2
and isinstance(expr.children[1], TypedLiteral)
and p.is_constant(expr.children[0])
and expr.children[0] == -1):
tl = expr.children[1]
return TypedLiteral("-"+tl.value, tl.dtype)
else:
return expr


class FortranExpressionParser(ExpressionParserBase):
# FIXME double/single prec literals

Expand All @@ -134,7 +146,6 @@ def __init__(self, tree_walker):
def parse_terminal(self, pstate):
scope = self.tree_walker.scope_stack[-1]

from pymbolic.primitives import Subscript, Call, Variable
from pymbolic.parser import (
_identifier, _openpar, _closepar, _float)

Expand Down Expand Up @@ -164,17 +175,17 @@ def parse_terminal(self, pstate):
# not a subscript
scope.use_name(name)

return Variable(name)
return p.Variable(name)

left_exp = Variable(name)
left_exp = p.Variable(name)

pstate.advance()
pstate.expect_not_end()

if scope.is_known(name):
cls = Subscript
cls = p.Subscript
else:
cls = Call
cls = p.Call

if pstate.next_tag is _closepar:
pstate.advance()
Expand Down Expand Up @@ -219,14 +230,14 @@ def parse_postfix(self, pstate, min_precedence, left_exp):
_PREC_CALL, _PREC_COMPARISON, _openpar,
_PREC_LOGICAL_OR, _PREC_LOGICAL_AND)
from pymbolic.primitives import (
ComparisonOperator, LogicalAnd, LogicalOr)
Comparison, LogicalAnd, LogicalOr)

next_tag = pstate.next_tag()
if next_tag is _openpar and _PREC_CALL > min_precedence:
raise TranslationError("parenthesis operator only works on names")
elif next_tag in self.COMP_MAP and _PREC_COMPARISON > min_precedence:
pstate.advance()
left_exp = ComparisonOperator(
left_exp = Comparison(
left_exp,
self.COMP_MAP[next_tag],
self.parse_expression(pstate, _PREC_COMPARISON))
Expand All @@ -250,7 +261,10 @@ def parse_postfix(self, pstate, min_precedence, left_exp):
assert len(left_exp) == 2
r, i = left_exp

dtype = (r.dtype.type(0) + i.dtype.type(0))
r = simplify_typed_literal(r)
i = simplify_typed_literal(i)

dtype = (r.dtype.type(0) + i.dtype.type(0)).dtype
if dtype == np.float32:
dtype = np.complex64
else:
Expand Down Expand Up @@ -758,10 +772,9 @@ def map_Assignment(self, node):

lhs = self.parse_expr(node.variable)

from pymbolic.primitives import Subscript, Call
if isinstance(lhs, Subscript):
if isinstance(lhs, p.Subscript):
lhs_name = lhs.aggregate.name
elif isinstance(lhs, Call):
elif isinstance(lhs, p.Call):
# in absence of dim info, subscripts get parsed as calls
lhs_name = lhs.function.name
else:
Expand Down Expand Up @@ -797,11 +810,10 @@ def map_Goto(self, node):
def map_Call(self, node):
scope = self.scope_stack[-1]

from pymbolic.primitives import Subscript, Variable
for i, arg_str in enumerate(node.items):
arg = self.parse_expr(arg_str)
if isinstance(arg, (Variable, Subscript)):
if isinstance(arg, Subscript):
if isinstance(arg, (p.Variable, p.Subscript)):
if isinstance(arg, p.Subscript):
arg_name = arg.aggregate.name
else:
arg_name = arg.name
Expand Down Expand Up @@ -926,9 +938,9 @@ def gen_shape(start_end):
if shape is not None:
result.append(cgen.Statement(
"%s %s[nitemsof(%s)]"
% (
dtype_to_ctype(scope.get_type(name)),
name, name)))
% (
dtype_to_ctype(scope.get_type(name)),
name, name)))
else:
result.append(self.get_declarator(name))

Expand Down

0 comments on commit b64346e

Please sign in to comment.