Skip to content

Commit

Permalink
d2ir: Fix and add test for glob-edge-glob-index
Browse files Browse the repository at this point in the history
  • Loading branch information
nhooyr committed Jul 29, 2023
1 parent b29a8ef commit 137b909
Show file tree
Hide file tree
Showing 7 changed files with 4,766 additions and 25 deletions.
10 changes: 5 additions & 5 deletions d2ir/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ func (c *compiler) compileKey(refctx *RefContext) {
}

func (c *compiler) compileField(dst *Map, kp *d2ast.KeyPath, refctx *RefContext) {
fa, err := dst.EnsureField(kp, refctx)
fa, err := dst.EnsureField(kp, refctx, true)
if err != nil {
c.err.Errors = append(c.err.Errors, err.(d2ast.Error))
return
Expand Down Expand Up @@ -614,7 +614,7 @@ func (c *compiler) compileEdges(refctx *RefContext) {
return
}

fa, err := refctx.ScopeMap.EnsureField(refctx.Key.Key, refctx)
fa, err := refctx.ScopeMap.EnsureField(refctx.Key.Key, refctx, true)
if err != nil {
c.err.Errors = append(c.err.Errors, err.(d2ast.Error))
return
Expand Down Expand Up @@ -648,7 +648,7 @@ func (c *compiler) _compileEdges(refctx *RefContext) {

var ea []*Edge
if eid.Index != nil || eid.Glob {
ea = refctx.ScopeMap.GetEdges(eid)
ea = refctx.ScopeMap.GetEdges(eid, refctx)
if len(ea) == 0 {
c.errorf(refctx.Edge, "indexed edge does not exist")
continue
Expand All @@ -661,12 +661,12 @@ func (c *compiler) _compileEdges(refctx *RefContext) {
refctx.ScopeMap.appendFieldReferences(0, refctx.Edge.Dst, refctx)
}
} else {
_, err := refctx.ScopeMap.EnsureField(refctx.Edge.Src, refctx)
_, err := refctx.ScopeMap.EnsureField(refctx.Edge.Src, refctx, true)
if err != nil {
c.err.Errors = append(c.err.Errors, err.(d2ast.Error))
continue
}
_, err = refctx.ScopeMap.EnsureField(refctx.Edge.Dst, refctx)
_, err = refctx.ScopeMap.EnsureField(refctx.Edge.Dst, refctx, true)
if err != nil {
c.err.Errors = append(c.err.Errors, err.(d2ast.Error))
continue
Expand Down
136 changes: 119 additions & 17 deletions d2ir/d2ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ func (m *Map) getField(ida []string) *Field {
return nil
}

func (m *Map) EnsureField(kp *d2ast.KeyPath, refctx *RefContext) ([]*Field, error) {
func (m *Map) EnsureField(kp *d2ast.KeyPath, refctx *RefContext, create bool) ([]*Field, error) {
i := 0
for kp.Path[i].Unbox().ScalarString() == "_" {
m = ParentMap(m)
Expand All @@ -667,11 +667,11 @@ func (m *Map) EnsureField(kp *d2ast.KeyPath, refctx *RefContext) ([]*Field, erro
}

var fa []*Field
err := m.ensureField(i, kp, refctx, &fa)
err := m.ensureField(i, kp, refctx, create, &fa)
return fa, err
}

func (m *Map) ensureField(i int, kp *d2ast.KeyPath, refctx *RefContext, fa *[]*Field) error {
func (m *Map) ensureField(i int, kp *d2ast.KeyPath, refctx *RefContext, create bool, fa *[]*Field) error {
us, ok := kp.Path[i].Unbox().(*d2ast.UnquotedString)
if ok && us.Pattern != nil {
fa2, ok := m.doubleGlob(us.Pattern)
Expand All @@ -685,7 +685,7 @@ func (m *Map) ensureField(i int, kp *d2ast.KeyPath, refctx *RefContext, fa *[]*F
parent: f,
}
}
err := f.Map().ensureField(i+1, kp, refctx, fa)
err := f.Map().ensureField(i+1, kp, refctx, create, fa)
if err != nil {
return err
}
Expand All @@ -703,7 +703,7 @@ func (m *Map) ensureField(i int, kp *d2ast.KeyPath, refctx *RefContext, fa *[]*F
parent: f,
}
}
err := f.Map().ensureField(i+1, kp, refctx, fa)
err := f.Map().ensureField(i+1, kp, refctx, create, fa)
if err != nil {
return err
}
Expand Down Expand Up @@ -760,9 +760,12 @@ func (m *Map) ensureField(i int, kp *d2ast.KeyPath, refctx *RefContext, fa *[]*F
parent: f,
}
}
return f.Map().ensureField(i+1, kp, refctx, fa)
return f.Map().ensureField(i+1, kp, refctx, create, fa)
}

if !create {
return nil
}
f := &Field{
parent: m,
Name: head,
Expand All @@ -783,7 +786,7 @@ func (m *Map) ensureField(i int, kp *d2ast.KeyPath, refctx *RefContext, fa *[]*F
f.Composite = &Map{
parent: f,
}
return f.Map().ensureField(i+1, kp, refctx, fa)
return f.Map().ensureField(i+1, kp, refctx, create, fa)
}

func (m *Map) DeleteEdge(eid *EdgeID) *Edge {
Expand Down Expand Up @@ -848,7 +851,13 @@ func (m *Map) DeleteField(ida ...string) *Field {
return nil
}

func (m *Map) GetEdges(eid *EdgeID) []*Edge {
func (m *Map) GetEdges(eid *EdgeID, refctx *RefContext) []*Edge {
if refctx != nil {
var ea []*Edge
m.getEdges(eid, refctx, &ea)
return ea
}

eid, m, common, err := eid.resolve(m)
if err != nil {
return nil
Expand All @@ -859,7 +868,7 @@ func (m *Map) GetEdges(eid *EdgeID) []*Edge {
return nil
}
if f.Map() != nil {
return f.Map().GetEdges(eid)
return f.Map().GetEdges(eid, nil)
}
return nil
}
Expand All @@ -873,6 +882,90 @@ func (m *Map) GetEdges(eid *EdgeID) []*Edge {
return ea
}

func (m *Map) getEdges(eid *EdgeID, refctx *RefContext, ea *[]*Edge) error {
eid, m, common, err := eid.resolve(m)
if err != nil {
return err
}

if len(common) > 0 {
commonKP := d2ast.MakeKeyPath(common)
lastMatch := 0
for i, el := range commonKP.Path {
for j := lastMatch; j < len(refctx.Edge.Src.Path); j++ {
realEl := refctx.Edge.Src.Path[j]
if el.ScalarString() == realEl.ScalarString() {
commonKP.Path[i] = realEl
lastMatch += j + 1
}
}
}
fa, err := m.EnsureField(commonKP, nil, false)
if err != nil {
return nil
}
for _, f := range fa {
if _, ok := f.Composite.(*Array); ok {
return d2parser.Errorf(refctx.Edge.Src, "cannot index into array")
}
if f.Map() == nil {
f.Composite = &Map{
parent: f,
}
}
err = f.Map().getEdges(eid, refctx, ea)
if err != nil {
return err
}
}
return nil
}

srcKP := d2ast.MakeKeyPath(eid.SrcPath)
lastMatch := 0
for i, el := range srcKP.Path {
for j := lastMatch; j < len(refctx.Edge.Src.Path); j++ {
realEl := refctx.Edge.Src.Path[j]
if el.ScalarString() == realEl.ScalarString() {
srcKP.Path[i] = realEl
lastMatch += j + 1
}
}
}
dstKP := d2ast.MakeKeyPath(eid.DstPath)
lastMatch = 0
for i, el := range dstKP.Path {
for j := lastMatch; j < len(refctx.Edge.Dst.Path); j++ {
realEl := refctx.Edge.Dst.Path[j]
if el.ScalarString() == realEl.ScalarString() {
dstKP.Path[i] = realEl
lastMatch += j + 1
}
}
}

srcFA, err := m.EnsureField(srcKP, nil, false)
if err != nil {
return err
}
dstFA, err := m.EnsureField(dstKP, nil, false)
if err != nil {
return err
}

for _, src := range srcFA {
for _, dst := range dstFA {
eid2 := eid.Copy()
eid2.SrcPath = RelIDA(m, src)
eid2.DstPath = RelIDA(m, dst)

ea2 := m.GetEdges(eid2, nil)
*ea = append(*ea, ea2...)
}
}
return nil
}

func (m *Map) CreateEdge(eid *EdgeID, refctx *RefContext) ([]*Edge, error) {
var ea []*Edge
return ea, m.createEdge(eid, refctx, &ea)
Expand All @@ -888,11 +981,18 @@ func (m *Map) createEdge(eid *EdgeID, refctx *RefContext, ea *[]*Edge) error {
return d2parser.Errorf(refctx.Edge, err.Error())
}
if len(common) > 0 {
commonEnd := len(refctx.Edge.Src.Path) - len(eid.SrcPath)
commonStart := commonEnd - len(common)
commonKP := refctx.Edge.Src.Copy()
commonKP.Path = commonKP.Path[commonStart:commonEnd]
fa, err := m.EnsureField(commonKP, nil)
commonKP := d2ast.MakeKeyPath(common)
lastMatch := 0
for i, el := range commonKP.Path {
for j := lastMatch; j < len(refctx.Edge.Src.Path); j++ {
realEl := refctx.Edge.Src.Path[j]
if el.ScalarString() == realEl.ScalarString() {
commonKP.Path[i] = realEl
lastMatch += j + 1
}
}
}
fa, err := m.EnsureField(commonKP, nil, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -954,11 +1054,11 @@ func (m *Map) createEdge(eid *EdgeID, refctx *RefContext, ea *[]*Edge) error {
}
}

srcFA, err := m.EnsureField(srcKP, nil)
srcFA, err := m.EnsureField(srcKP, nil, true)
if err != nil {
return err
}
dstFA, err := m.EnsureField(dstKP, nil)
dstFA, err := m.EnsureField(dstKP, nil, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -990,9 +1090,11 @@ func (m *Map) createEdge2(eid *EdgeID, refctx *RefContext, src, dst *Field) (*Ed
}

eid.Index = nil
ea := m.GetEdges(eid)
eid.Glob = true
ea := m.GetEdges(eid, refctx)
index := len(ea)
eid.Index = &index
eid.Glob = false
e := &Edge{
parent: m,
ID: eid,
Expand Down
2 changes: 1 addition & 1 deletion d2ir/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func OverlayMap(base, overlay *Map) {
}

for _, oe := range overlay.Edges {
bea := base.GetEdges(oe.ID)
bea := base.GetEdges(oe.ID, nil)
if len(bea) == 0 {
base.Edges = append(base.Edges, oe.Copy(base).(*Edge))
continue
Expand Down
29 changes: 28 additions & 1 deletion d2ir/pattern_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,23 @@ a -> b
assertQuery(t, m, 0, 0, "red", "(a -> b)[2].style.fill")
},
},
{
name: "glob-edge-glob-index",
run: func(t testing.TB) {
m, err := compile(t, `a -> b
a -> b
a -> b
c -> b
(* -> b)[*].style.fill: red
`)
assert.Success(t, err)
assertQuery(t, m, 11, 4, nil, "")
assertQuery(t, m, 0, 0, "red", "(a -> b)[0].style.fill")
assertQuery(t, m, 0, 0, "red", "(a -> b)[1].style.fill")
assertQuery(t, m, 0, 0, "red", "(a -> b)[2].style.fill")
assertQuery(t, m, 0, 0, "red", "(c -> b)[0].style.fill")
},
},
{
name: "double-glob/1",
run: func(t testing.TB) {
Expand All @@ -169,7 +186,17 @@ shared.animal
runa(t, tca)

t.Run("errors", func(t *testing.T) {
tca := []testCase{}
tca := []testCase{
{
name: "glob-edge-glob-index",
run: func(t testing.TB) {
m, err := compile(t, `(* -> b)[*].style.fill: red
`)
assert.ErrorString(t, err, `TestCompile/patterns/errors/glob-edge-glob-index.d2:1:2: indexed edge does not exist`)
assertQuery(t, m, 0, 0, nil, "")
},
},
}
runa(t, tca)
})
}
2 changes: 1 addition & 1 deletion d2ir/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (m *Map) QueryAll(idStr string) (na []Node, _ error) {

eida := NewEdgeIDs(k)
for _, eid := range eida {
ea := m.GetEdges(eid)
ea := m.GetEdges(eid, nil)
for _, e := range ea {
if k.EdgeKey == nil {
na = append(na, e)
Expand Down
Loading

0 comments on commit 137b909

Please sign in to comment.