Skip to content

Commit

Permalink
ast: Don't allow calls in function signatures
Browse files Browse the repository at this point in the history
Rather than causing a panic calls in function declaration args will
now just raise a parse error.

In the future we could potentially support them but we would need to
sort out some additional details/ambiguity around behavior.

Its unclear that any users need this behavior so for now we'll just
correct the panic.

Fixes: #2081
Signed-off-by: Patrick East <east.patrick@gmail.com>
  • Loading branch information
patrick-east authored and tsandall committed Apr 2, 2020
1 parent 6436300 commit aee2bd9
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 37 deletions.
5 changes: 4 additions & 1 deletion ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -1207,8 +1207,11 @@ func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor {
// Scalars are no-ops. Comprehensions are handled above. Sets must not
// contain variables.
return nil
case Call:
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "rule arguments cannot contain calls"))
return nil
default:
// Recurse on refs, arrays, and calls. Any embedded
// Recurse on refs and arrays. Any embedded
// variables can be rewritten.
return vis
}
Expand Down
131 changes: 95 additions & 36 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -930,47 +930,106 @@ func TestCompilerExprExpansion(t *testing.T) {
}

func TestCompilerRewriteExprTerms(t *testing.T) {
module := `
package test
p { x = a + b * y }
q[[data.test.f(x)]] { x = 1 }
r = [data.test.f(x)] { x = 1 }
f(x) = data.test.g(x)
pi = 3 + .14
with_value { 1 with input as f(1) }
`

compiler := NewCompiler()
compiler.Modules = map[string]*Module{
"test": MustParseModule(module),
cases := []struct {
note string
module string
expected interface{}
}{
{
note: "base",
module: `
package test
p { x = a + b * y }
q[[data.test.f(x)]] { x = 1 }
r = [data.test.f(x)] { x = 1 }
f(x) = data.test.g(x)
pi = 3 + .14
with_value { 1 with input as f(1) }
`,
expected: `
package test
p { mul(b, y, __local1__); plus(a, __local1__, __local2__); eq(x, __local2__) }
q[[__local3__]] { x = 1; data.test.f(x, __local3__) }
r = [__local4__] { x = 1; data.test.f(x, __local4__) }
f(__local0__) = __local5__ { true; data.test.g(__local0__, __local5__) }
pi = __local6__ { true; plus(3, 0.14, __local6__) }
with_value { data.test.f(1, __local7__); 1 with input as __local7__ }
`,
},
{
note: "builtin calls in head",
module: `
package test
f(1+1) = 7
`,
expected: Errors{&Error{Message: "rule arguments cannot contain calls"}},
},
{
note: "builtin calls in head",
module: `
package test
f(object.get(x)) { object := {"a": 1}; object.a == x }
`,
expected: Errors{&Error{Message: "rule arguments cannot contain calls"}},
},
}
compileStages(compiler, compiler.rewriteExprTerms)
assertNotFailed(t, compiler)

expected := MustParseModule(`
package test
p { mul(b, y, __local1__); plus(a, __local1__, __local2__); eq(x, __local2__) }
q[[__local3__]] { x = 1; data.test.f(x, __local3__) }
r = [__local4__] { x = 1; data.test.f(x, __local4__) }
for _, tc := range cases {
t.Run(tc.note, func(t *testing.T) {
compiler := NewCompiler()
compiler.Modules = map[string]*Module{
"test": MustParseModule(tc.module),
}
compileStages(compiler, compiler.rewriteExprTerms)

f(__local0__) = __local5__ { true; data.test.g(__local0__, __local5__) }
switch exp := tc.expected.(type) {
case string:
assertNotFailed(t, compiler)

pi = __local6__ { true; plus(3, 0.14, __local6__) }
expected := MustParseModule(exp)

with_value { data.test.f(1, __local7__); 1 with input as __local7__ }
`)
if !expected.Equal(compiler.Modules["test"]) {
t.Fatalf("Expected modules to be equal. Expected:\n\n%v\n\nGot:\n\n%v", expected, compiler.Modules["test"])
}
case Errors:
if len(exp) != len(compiler.Errors) {
t.Fatalf("Expected %d errors, got %d:\n\n%s\n", len(exp), len(compiler.Errors), compiler.Errors.Error())
}
incorrectErrs := false
for _, e := range exp {
found := false
for _, actual := range compiler.Errors {
if e.Message == actual.Message {
found = true
break
}
}
if !found {
incorrectErrs = true
}
}
if incorrectErrs {
t.Fatalf("Expected errors:\n\n%s\n\nGot:\n\n%s\n", exp.Error(), compiler.Errors.Error())
}
default:
t.Fatalf("Unsupported value type for test case 'expected' field: %v", exp)
}

if !expected.Equal(compiler.Modules["test"]) {
t.Fatalf("Expected modules to be equal. Expected:\n\n%v\n\nGot:\n\n%v", expected, compiler.Modules["test"])
})
}
}

Expand Down Expand Up @@ -3336,7 +3395,7 @@ func assertCompilerErrorStrings(t *testing.T, compiler *Compiler, expected []str

func assertNotFailed(t *testing.T, c *Compiler) {
if c.Failed() {
t.Errorf("Unexpected compilation error: %v", c.Errors)
t.Fatalf("Unexpected compilation error: %v", c.Errors)
}
}

Expand Down
4 changes: 4 additions & 0 deletions ast/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ func (env *TypeEnv) Get(x interface{}) types.Type {
}
return nil

// Calls.
case Call:
return nil

default:
panic("unreachable")
}
Expand Down

0 comments on commit aee2bd9

Please sign in to comment.