diff --git a/internal/core/adt/adt.go b/internal/core/adt/adt.go index 09d4cbf130c..80cccd16e26 100644 --- a/internal/core/adt/adt.go +++ b/internal/core/adt/adt.go @@ -313,12 +313,8 @@ func (*DisjunctionExpr) elemNode() {} // Decl, Elem, and Yielder -func (*ForClause) declNode() {} -func (*ForClause) elemNode() {} -func (*IfClause) declNode() {} -func (*IfClause) elemNode() {} - -// Yielder only: ValueClause +func (*Comprehension) declNode() {} +func (*Comprehension) elemNode() {} // Node @@ -359,6 +355,7 @@ func (*OptionalField) node() {} func (*BulkOptionalField) node() {} func (*DynamicField) node() {} func (*Ellipsis) node() {} +func (*Comprehension) node() {} func (*ForClause) node() {} func (*IfClause) node() {} func (*LetClause) node() {} diff --git a/internal/core/adt/binop.go b/internal/core/adt/binop.go index e1d22c18d77..e2410a1e8fb 100644 --- a/internal/core/adt/binop.go +++ b/internal/core/adt/binop.go @@ -182,13 +182,17 @@ func BinOp(c *OpContext, op Op, left, right Value) Value { x := MakeIdentLabel(c, "x", "") - forClause := func(src Expr) *ForClause { - return &ForClause{ - Value: x, - Src: src, - Dst: &ValueClause{&StructLit{Decls: []Decl{ - &FieldReference{UpCount: 1, Label: x}, - }}}, + forClause := func(src Expr) *Comprehension { + s := &StructLit{Decls: []Decl{ + &FieldReference{UpCount: 1, Label: x}, + }} + return &Comprehension{ + Clauses: &ForClause{ + Value: x, + Src: src, + Dst: &ValueClause{s}, + }, + Value: s, } } @@ -242,13 +246,17 @@ func BinOp(c *OpContext, op Op, left, right Value) Value { x := MakeIdentLabel(c, "x", "") for i := c.uint64(left, "list multiplier"); i > 0; i-- { + st := &StructLit{Decls: []Decl{ + &FieldReference{UpCount: 1, Label: x}, + }} list.Elems = append(list.Elems, - &ForClause{ - Value: x, - Src: right, - Dst: &ValueClause{&StructLit{Decls: []Decl{ - &FieldReference{UpCount: 1, Label: x}, - }}}, + &Comprehension{ + Clauses: &ForClause{ + Value: x, + Src: right, + Dst: &ValueClause{st}, + }, + Value: st, }, ) } diff --git a/internal/core/adt/comprehension.go b/internal/core/adt/comprehension.go index 35014d3f917..59e395ef730 100644 --- a/internal/core/adt/comprehension.go +++ b/internal/core/adt/comprehension.go @@ -21,12 +21,16 @@ type envYield struct { err *Bottom } +func (n *nodeContext) insertComprehension(env *Environment, x Yielder, ci CloseInfo) { + n.comprehensions = append(n.comprehensions, envYield{env, x, ci, nil}) +} + // injectComprehensions evaluates and inserts comprehensions. func (n *nodeContext) injectComprehensions(all *[]envYield) (progress bool) { ctx := n.ctx type envStruct struct { env *Environment - s *StructLit + s *StructLit // always the same. } var sa []envStruct f := func(env *Environment, st *StructLit) { diff --git a/internal/core/adt/eval.go b/internal/core/adt/eval.go index e1422208bd5..d4357785e72 100644 --- a/internal/core/adt/eval.go +++ b/internal/core/adt/eval.go @@ -1740,13 +1740,8 @@ func (n *nodeContext) addStruct( n.aStructID = closeInfo n.dynamicFields = append(n.dynamicFields, envDynamic{childEnv, x, closeInfo, nil}) - case *ForClause: - // Why is this not an embedding? - n.comprehensions = append(n.comprehensions, envYield{childEnv, x, closeInfo, nil}) - - case Yielder: - // Why is this not an embedding? - n.comprehensions = append(n.comprehensions, envYield{childEnv, x, closeInfo, nil}) + case *Comprehension: + n.insertComprehension(childEnv, x.Clauses, closeInfo) case Expr: // add embedding to optional @@ -1992,9 +1987,10 @@ outer: hasComprehension := false for j, elem := range l.list.Elems { switch x := elem.(type) { - case Yielder: - err := c.Yield(l.env, x, func(e *Environment, st *StructLit) { - label, err := MakeLabel(x.Source(), index, IntLabel) + case *Comprehension: + xx := x.Clauses + err := c.Yield(l.env, xx, func(e *Environment, st *StructLit) { + label, err := MakeLabel(xx.Source(), index, IntLabel) n.addErr(err) index++ c := MakeConjunct(e, st, l.id) diff --git a/internal/core/adt/expr.go b/internal/core/adt/expr.go index e67cd34fbcb..efcbddcf20f 100644 --- a/internal/core/adt/expr.go +++ b/internal/core/adt/expr.go @@ -119,7 +119,10 @@ func (o *StructLit) Init() { case Expr: o.HasEmbed = true - case *ForClause, Yielder: + case *Comprehension: + o.HasEmbed = true + + case *LetClause: o.HasEmbed = true case *BulkOptionalField: @@ -1601,6 +1604,18 @@ func (x *Disjunction) Kind() Kind { return k } +type Comprehension struct { + Clauses Yielder + Value *StructLit // TODO: changes this to Expr? +} + +func (x *Comprehension) Source() ast.Node { + if x.Clauses == nil { + return nil + } + return x.Clauses.Source() +} + // A ForClause represents a for clause of a comprehension. It can be used // as a struct or list element. // diff --git a/internal/core/adt/expr_test.go b/internal/core/adt/expr_test.go index 3d7a936b85c..fddfe057c0f 100644 --- a/internal/core/adt/expr_test.go +++ b/internal/core/adt/expr_test.go @@ -32,6 +32,7 @@ func TestNilSource(t *testing.T) { &BulkOptionalField{}, &Bytes{}, &CallExpr{}, + &Comprehension{}, &Conjunction{}, &Disjunction{}, &DisjunctionExpr{}, diff --git a/internal/core/compile/compile.go b/internal/core/compile/compile.go index 4155a6d5238..18fa53be468 100644 --- a/internal/core/compile/compile.go +++ b/internal/core/compile/compile.go @@ -702,7 +702,7 @@ func (c *compiler) elem(n ast.Expr) adt.Elem { func (c *compiler) comprehension(x *ast.Comprehension) adt.Elem { var cur adt.Yielder - var first adt.Elem + var first adt.Yielder var prev, next *adt.Yielder for _, v := range x.Clauses { switch x := v.(type) { @@ -745,8 +745,8 @@ func (c *compiler) comprehension(x *ast.Comprehension) adt.Elem { if prev != nil { *prev = cur } else { - var ok bool - if first, ok = cur.(adt.Elem); !ok { + first = cur + if _, ok := cur.(*adt.LetClause); ok { return c.errf(x, "first comprehension clause must be 'if' or 'for'") } @@ -774,7 +774,10 @@ func (c *compiler) comprehension(x *ast.Comprehension) adt.Elem { return c.errf(x, "comprehension value without clauses") } - return first + return &adt.Comprehension{ + Clauses: first, + Value: st, + } } func (c *compiler) labeledExpr(f *ast.Field, lab labeler, expr ast.Expr) adt.Expr { diff --git a/internal/core/debug/compact.go b/internal/core/debug/compact.go index 0a49d942e82..c0069f8b52a 100644 --- a/internal/core/debug/compact.go +++ b/internal/core/debug/compact.go @@ -301,6 +301,10 @@ func (w *compactPrinter) node(n adt.Node) { w.node(c) } + case *adt.Comprehension: + w.node(x.Clauses) + w.node(x.Value) + case *adt.ForClause: w.string("for ") w.ident(x.Key) @@ -326,7 +330,6 @@ func (w *compactPrinter) node(n adt.Node) { w.node(x.Dst) case *adt.ValueClause: - w.node(x.StructLit) default: panic(fmt.Sprintf("unknown type %T", x)) diff --git a/internal/core/debug/debug.go b/internal/core/debug/debug.go index 865ff07fc54..8394f83eee0 100644 --- a/internal/core/debug/debug.go +++ b/internal/core/debug/debug.go @@ -497,6 +497,10 @@ func (w *printer) node(n adt.Node) { } w.string(")") + case *adt.Comprehension: + w.node(x.Clauses) + w.node(x.Value) + case *adt.ForClause: w.string("for ") w.ident(x.Key) @@ -522,7 +526,6 @@ func (w *printer) node(n adt.Node) { w.node(x.Dst) case *adt.ValueClause: - w.node(x.StructLit) default: panic(fmt.Sprintf("unknown type %T", x)) diff --git a/internal/core/dep/dep.go b/internal/core/dep/dep.go index 95dcf81a2a5..2a0d1a33910 100644 --- a/internal/core/dep/dep.go +++ b/internal/core/dep/dep.go @@ -188,8 +188,8 @@ func (c *visitor) markExpr(env *adt.Environment, expr adt.Elem) { env := &adt.Environment{Up: env, Vertex: empty} for _, e := range x.Elems { switch x := e.(type) { - case adt.Yielder: - c.markYielder(env, x) + case *adt.Comprehension: + c.markComprehension(env, x) case adt.Expr: c.markSubExpr(env, x) @@ -287,8 +287,8 @@ func (c *visitor) markDecl(env *adt.Environment, d adt.Decl) { // a matching field in the parallel actual evaluation. c.markSubExpr(env, x.Value) - case adt.Yielder: - c.markYielder(env, x) + case *adt.Comprehension: + c.markComprehension(env, x) case adt.Expr: c.markExpr(env, x) @@ -300,26 +300,29 @@ func (c *visitor) markDecl(env *adt.Environment, d adt.Decl) { } } -func (c *visitor) markYielder(env *adt.Environment, y adt.Yielder) { +func (c *visitor) markComprehension(env *adt.Environment, y *adt.Comprehension) { + env = c.markYielder(env, y.Clauses) + c.markExpr(env, y.Value) +} + +func (c *visitor) markYielder(env *adt.Environment, y adt.Yielder) *adt.Environment { switch x := y.(type) { case *adt.ForClause: c.markExpr(env, x.Src) - env := &adt.Environment{Up: env, Vertex: empty} - c.markYielder(env, x.Dst) + env = &adt.Environment{Up: env, Vertex: empty} + env = c.markYielder(env, x.Dst) // In dynamic mode, iterate over all actual value and // evaluate. case *adt.LetClause: c.markExpr(env, x.Expr) - env := &adt.Environment{Up: env, Vertex: empty} - c.markYielder(env, x.Dst) + env = &adt.Environment{Up: env, Vertex: empty} + env = c.markYielder(env, x.Dst) case *adt.IfClause: c.markExpr(env, x.Condition) // In dynamic mode, only continue if condition is true. - c.markYielder(env, x.Dst) - - case *adt.ValueClause: - c.markExpr(env, x.StructLit) + env = c.markYielder(env, x.Dst) } + return env } diff --git a/internal/core/dep/mixed.go b/internal/core/dep/mixed.go index 601bbfd2bba..073c67499e5 100644 --- a/internal/core/dep/mixed.go +++ b/internal/core/dep/mixed.go @@ -94,8 +94,8 @@ func (m marked) markExpr(x adt.Expr) { case adt.Expr: m.markExpr(x) - case adt.Yielder: - m.markYielder(x) + case *adt.Comprehension: + m.markComprehension(x) default: panic(fmt.Sprintf("unreachable %T", x)) @@ -108,8 +108,8 @@ func (m marked) markExpr(x adt.Expr) { case adt.Expr: m.markExpr(x) - case adt.Yielder: - m.markYielder(x) + case *adt.Comprehension: + m.markComprehension(x) case *adt.Ellipsis: m.markExpr(x.Value) @@ -123,12 +123,14 @@ func (m marked) markExpr(x adt.Expr) { for _, d := range x.Values { m.markExpr(d.Val) } - - case adt.Yielder: - m.markYielder(x) } } +func (m marked) markComprehension(y *adt.Comprehension) { + m.markYielder(y.Clauses) + m.markExpr(y.Value) +} + func (m marked) markYielder(y adt.Yielder) { switch x := y.(type) { case *adt.ForClause: @@ -139,8 +141,5 @@ func (m marked) markYielder(y adt.Yielder) { case *adt.LetClause: m.markYielder(x.Dst) - - case *adt.ValueClause: - m.markExpr(x.StructLit) } } diff --git a/internal/core/export/adt.go b/internal/core/export/adt.go index d9fa583b52d..8e860d609e9 100644 --- a/internal/core/export/adt.go +++ b/internal/core/export/adt.go @@ -420,7 +420,7 @@ func (e *exporter) elem(d adt.Elem) ast.Expr { } return t - case adt.Yielder: + case *adt.Comprehension: return e.comprehension(x) default: @@ -428,9 +428,12 @@ func (e *exporter) elem(d adt.Elem) ast.Expr { } } -func (e *exporter) comprehension(y adt.Yielder) ast.Expr { +func (e *exporter) comprehension(comp *adt.Comprehension) *ast.Comprehension { c := &ast.Comprehension{} + y := comp.Clauses + +loop: for { switch x := y.(type) { case *adt.ForClause: @@ -474,16 +477,17 @@ func (e *exporter) comprehension(y adt.Yielder) ast.Expr { y = x.Dst case *adt.ValueClause: - v := e.expr(x.StructLit) - if _, ok := v.(*ast.StructLit); !ok { - v = ast.NewStruct(ast.Embed(v)) - } - c.Value = v - return c + break loop default: panic(fmt.Sprintf("unknown field %T", x)) } } + v := e.expr(comp.Value) + if _, ok := v.(*ast.StructLit); !ok { + v = ast.NewStruct(ast.Embed(v)) + } + c.Value = v + return c } diff --git a/internal/core/subsume/structural.go b/internal/core/subsume/structural.go index a4e29aee953..450a03e1863 100644 --- a/internal/core/subsume/structural.go +++ b/internal/core/subsume/structural.go @@ -268,11 +268,8 @@ func (c *collatedDecls) collate(env *adt.Environment, s *adt.StructLit) { c.isOpen = true c.additional = append(c.additional, x) - case *adt.ForClause: - c.yielders = append(c.yielders, x) - - case *adt.IfClause: - c.yielders = append(c.yielders, x) + case *adt.Comprehension: + c.yielders = append(c.yielders, x.Clauses) case *adt.LetClause: c.yielders = append(c.yielders, x) diff --git a/internal/core/walk/walk.go b/internal/core/walk/walk.go index 4c09aefa552..3a12139f971 100644 --- a/internal/core/walk/walk.go +++ b/internal/core/walk/walk.go @@ -165,6 +165,10 @@ func (w *Visitor) node(n adt.Node) { // Yielders + case *adt.Comprehension: + w.node(x.Clauses) + w.node(x.Value) + case *adt.ForClause: w.feature(x.Key, x) w.feature(x.Value, x) @@ -180,7 +184,6 @@ func (w *Visitor) node(n adt.Node) { w.node(x.Dst) case *adt.ValueClause: - w.node(x.StructLit) default: panic(fmt.Sprintf("unknown field %T", x)) diff --git a/tools/trim/trim.go b/tools/trim/trim.go index 28eb918b129..b863f92524d 100644 --- a/tools/trim/trim.go +++ b/tools/trim/trim.go @@ -94,7 +94,7 @@ func Files(files []*ast.File, inst cue.InstanceOrValue, cfg *Config) error { // Structs with comprehensions may never be removed. for _, d := range x.Decls { switch d.(type) { - case *adt.IfClause, *adt.ForClause: + case *adt.Comprehension: t.markKeep(x) } }