Skip to content
114 changes: 114 additions & 0 deletions sql/expression/function/json/json_common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package json

import (
"fmt"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

// getMutableJSONVal returns a JSONValue from the given row and expression. The underling value is deeply copied so that
// you are free to use the mutation functions on the returned value.
// nil will be returned only if the inputs are nil. This will not return an error, so callers must check.
func getMutableJSONVal(ctx *sql.Context, row sql.Row, json sql.Expression) (types.MutableJSONValue, error) {
doc, err := getJSONDocumentFromRow(ctx, row, json)
if err != nil || doc == nil || doc.Val == nil {
return nil, err
}

mutable := types.DeepCopyJson(doc.Val)
return types.JSONDocument{Val: mutable}, nil
}

// getSearchableJSONVal returns a SearchableJSONValue from the given row and expression. The underling value is not copied
// so it is intended to be used for read-only operations.
// nil will be returned only if the inputs are nil. This will not return an error, so callers must check.
func getSearchableJSONVal(ctx *sql.Context, row sql.Row, json sql.Expression) (types.SearchableJSONValue, error) {
doc, err := getJSONDocumentFromRow(ctx, row, json)
if err != nil || doc == nil || doc.Val == nil {
return nil, err
}

return doc, nil
}

// getJSONDocumentFromRow returns a JSONDocument from the given row and expression. Helper function only intended to be
// used by functions in this file.
func getJSONDocumentFromRow(ctx *sql.Context, row sql.Row, json sql.Expression) (*types.JSONDocument, error) {
js, err := json.Eval(ctx, row)
if err != nil || js == nil {
return nil, err
}

var converted interface{}
switch js.(type) {
case string, []interface{}, map[string]interface{}, types.JSONValue:
converted, _, err = types.JSON.Convert(js)
if err != nil {
return nil, sql.ErrInvalidJSONText.New(js)
}
default:
return nil, sql.ErrInvalidArgument.New(fmt.Sprintf("%v", js))
}

doc, ok := converted.(types.JSONDocument)
if !ok {
// This should never happen, but just in case.
doc, err = js.(types.JSONValue).Unmarshall(ctx)
if err != nil {
return nil, err
}
}

return &doc, nil
}

// pathValPair is a helper struct for use by functions which take json paths paired with a json value. eg. JSON_SET, JSON_INSERT, etc.
type pathValPair struct {
path string
val types.JSONValue
}

// buildPathValue builds a pathValPair from the given row and expressions. This is a common pattern in json methods to have
// pairs of arguments, and this ensures they are of the right type, non-nil, and they wrapped in a struct as a unit.
func buildPathValue(ctx *sql.Context, pathExp sql.Expression, valExp sql.Expression, row sql.Row) (*pathValPair, error) {
path, err := pathExp.Eval(ctx, row)
if err != nil {
return nil, err
}

if path == nil {
// MySQL documented behavior is to return null, not error, if any path is null.
return nil, nil
}

// make sure path is string
if _, ok := path.(string); !ok {
return nil, fmt.Errorf("Invalid JSON path expression")
}

val, err := valExp.Eval(ctx, row)
if err != nil {
return nil, err
}
jsonVal, ok := val.(types.JSONValue)
if !ok {
jsonVal = types.JSONDocument{val}
}

return &pathValPair{path.(string), jsonVal}, nil
}
31 changes: 0 additions & 31 deletions sql/expression/function/json/json_contains.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,37 +165,6 @@ func (j *JSONContains) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
return target.Contains(ctx, candidate)
}

func getSearchableJSONVal(ctx *sql.Context, row sql.Row, json sql.Expression) (types.SearchableJSONValue, error) {
js, err := json.Eval(ctx, row)
if err != nil {
return nil, err
}
if js == nil {
return nil, nil
}

var converted interface{}
switch js.(type) {
case string, []interface{}, map[string]interface{}, types.JSONValue:
converted, _, err = types.JSON.Convert(js)
if err != nil {
return nil, sql.ErrInvalidJSONText.New(js)
}
default:
return nil, sql.ErrInvalidArgument.New(fmt.Sprintf("%v", js))
}

searchable, ok := converted.(types.SearchableJSONValue)
if !ok {
searchable, err = js.(types.JSONValue).Unmarshall(ctx)
if err != nil {
return nil, err
}
}

return searchable, nil
}

func (j *JSONContains) Children() []sql.Expression {
if j.Path != nil {
return []sql.Expression{j.JSONTarget, j.JSONCandidate, j.Path}
Expand Down
136 changes: 136 additions & 0 deletions sql/expression/function/json/json_insert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package json

import (
"fmt"
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

// JSON_INSERT(json_doc, path, val[, path, val] ...)
//
// JSONInsert Inserts data into a JSON document and returns the result. Returns NULL if any argument is NULL. An error
// occurs if the json_doc argument is not a valid JSON document or any path argument is not a valid path expression or
// contains a * or ** wildcard. The path-value pairs are evaluated left to right. The document produced by evaluating
// one pair becomes the new value against which the next pair is evaluated. A path-value pair for an existing path in
// the document is ignored and does not overwrite the existing document value. A path-value pair for a nonexisting path
// in the document adds the value to the document if the path identifies one of these types of values:
// - A member not present in an existing object. The member is added to the object and associated with the new value.
// - A position past the end of an existing array. The array is extended with the new value. If the existing value is
// not an array, it is autowrapped as an array, then extended with the new value.
//
// Otherwise, a path-value pair for a nonexisting path in the document is ignored and has no effect.
//
// https://dev.mysql.com/doc/refman/8.0/en/json-modification-functions.html#function_json-insert
type JSONInsert struct {
doc sql.Expression
pathVals []sql.Expression
}

var _ sql.FunctionExpression = JSONInsert{}

func (j JSONInsert) Resolved() bool {
for _, child := range j.Children() {
if child != nil && !child.Resolved() {
return false
}
}
return true
}

func (j JSONInsert) String() string {
children := j.Children()
var parts = make([]string, len(children))

for i, c := range children {
parts[i] = c.String()
}

return fmt.Sprintf("%s(%s)", j.FunctionName(), strings.Join(parts, ","))
}

func (j JSONInsert) Type() sql.Type {
return types.JSON
}

func (j JSONInsert) IsNullable() bool {
for _, arg := range j.pathVals {
if arg.IsNullable() {
return true
}
}
return j.doc.IsNullable()
}

func (j JSONInsert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
doc, err := getMutableJSONVal(ctx, row, j.doc)
if err != nil || doc == nil {
return nil, err
}

pairs := make([]pathValPair, 0, len(j.pathVals)/2)
for i := 0; i < len(j.pathVals); i += 2 {
argPair, err := buildPathValue(ctx, j.pathVals[i], j.pathVals[i+1], row)
if argPair == nil || err != nil {
return nil, err
}
pairs = append(pairs, *argPair)
}

// Apply the path-value pairs to the document.
for _, pair := range pairs {
doc, _, err = doc.Insert(ctx, pair.path, pair.val)
if err != nil {
return nil, err
}
}

return doc, nil
}

func (j JSONInsert) Children() []sql.Expression {
return append([]sql.Expression{j.doc}, j.pathVals...)
}

func (j JSONInsert) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(j.Children()) != len(children) {
return nil, fmt.Errorf("json_replace did not receive the correct amount of args")
}
return NewJSONInsert(children...)
}

// NewJSONInsert creates a new JSONInsert function.
func NewJSONInsert(args ...sql.Expression) (sql.Expression, error) {
if len(args) <= 1 {
return nil, sql.ErrInvalidArgumentNumber.New("JSON_INSERT", "more than 1", len(args))
} else if (len(args)-1)%2 == 1 {
return nil, sql.ErrInvalidArgumentNumber.New("JSON_INSERT", "even number of path/val", len(args)-1)
}

return JSONInsert{args[0], args[1:]}, nil
}

// FunctionName implements sql.FunctionExpression
func (j JSONInsert) FunctionName() string {
return "json_insert"
}

// Description implements sql.FunctionExpression
func (j JSONInsert) Description() string {
return "inserts data into JSON document"
}
108 changes: 108 additions & 0 deletions sql/expression/function/json/json_insert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package json

import (
"fmt"
"strings"
"testing"

"github.com/stretchr/testify/require"
"gopkg.in/src-d/go-errors.v1"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

func TestInsert(t *testing.T) {
_, err := NewJSONInsert()
require.True(t, errors.Is(err, sql.ErrInvalidArgumentNumber))

f1 := buildGetFieldExpressions(t, NewJSONInsert, 3)
f2 := buildGetFieldExpressions(t, NewJSONInsert, 5)

json := `{"a": 1, "b": [2, 3], "c": {"d": "foo"}}`

testCases := []struct {
f sql.Expression
row sql.Row
expected interface{}
err error
}{
{f1, sql.Row{json, "$.a", 10.1}, json, nil}, // insert existing does nothing
{f1, sql.Row{json, "$.e", "new"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"},"e":"new"}`, nil}, // insert new
{f1, sql.Row{json, "$.c.d", "test"}, json, nil}, // insert existing nested does nothing
{f2, sql.Row{json, "$.a", 10.1, "$.e", "new"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"},"e":"new"}`, nil}, // insert multiple, one change.
{f1, sql.Row{json, "$.a.e", "test"}, json, nil}, // insert nested does nothing
{f1, sql.Row{json, "$.c.e", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo","e":"test"}}`, nil}, // insert nested in existing struct
{f1, sql.Row{json, "$.c[5]", 4.1}, `{"a": 1, "b": [2, 3], "c": [{"d": "foo"}, 4.1]}`, nil}, // insert struct with indexing out of range
{f1, sql.Row{json, "$.b[0]", 4.1}, json, nil}, // insert element in array does nothing
{f1, sql.Row{json, "$.b[5]", 4.1}, `{"a": 1, "b": [2, 3, 4.1], "c": {"d": "foo"}}`, nil}, // insert element in array out of range
{f1, sql.Row{json, "$.b.c", 4}, json, nil}, // insert nested in array does nothing
{f1, sql.Row{json, "$.a[0]", 4.1}, json, nil}, // struct as array does nothing
{f1, sql.Row{json, "$[0]", 4.1}, json, nil}, // struct does nothing.
{f1, sql.Row{json, "$.[0]", 4.1}, nil, fmt.Errorf("Invalid JSON path expression")}, // improper struct indexing
{f1, sql.Row{json, "foo", "test"}, nil, fmt.Errorf("Invalid JSON path expression")}, // invalid path
{f1, sql.Row{json, "$.c.*", "test"}, nil, fmt.Errorf("Path expressions may not contain the * and ** tokens")}, // path contains * wildcard
{f1, sql.Row{json, "$.c.**", "test"}, nil, fmt.Errorf("Path expressions may not contain the * and ** tokens")}, // path contains ** wildcard
{f1, sql.Row{json, "$", 10.1}, json, nil}, // whole document no opt
{f1, sql.Row{nil, "$", 42.7}, nil, nil}, // null document returns null
{f1, sql.Row{json, nil, 10}, nil, nil}, // if any path is null, return null

// mysql> select JSON_INSERT(JSON_ARRAY(), "$[2]", 1 , "$[2]", 2 ,"$[2]", 3 ,"$[2]", 4);
// +------------------------------------------------------------------------+
// | JSON_INSERT(JSON_ARRAY(), "$[2]", 1 , "$[2]", 2 ,"$[2]", 3 ,"$[2]", 4) |
// +------------------------------------------------------------------------+
// | [1, 2, 3] |
// +------------------------------------------------------------------------+
{buildGetFieldExpressions(t, NewJSONInsert, 9),
sql.Row{`[]`,
"$[2]", 1.1, // [] -> [1.1]
"$[2]", 2.2, // [1.1] -> [1.1,2.2]
"$[2]", 3.3, // [1.1, 2.2] -> [1.1, 2.2, 3.3]
"$[2]", 4.4}, // [1.1, 2.2, 3.3] -> [1.1, 2.2, 3.3]
`[1.1, 2.2, 3.3]`, nil},
}

for _, tstC := range testCases {
var paths []string
for _, path := range tstC.row[1:] {
if _, ok := path.(string); ok {
paths = append(paths, path.(string))
}
}

t.Run(tstC.f.String()+"."+strings.Join(paths, ","), func(t *testing.T) {
req := require.New(t)
result, err := tstC.f.Eval(sql.NewEmptyContext(), tstC.row)
if tstC.err == nil {
req.NoError(err)

var expect interface{}
if tstC.expected != nil {
expect, _, err = types.JSON.Convert(tstC.expected)
if err != nil {
panic("Bad test string. Can't convert string to JSONDocument: " + tstC.expected.(string))
}
}

req.Equal(expect, result)
} else {
req.Error(tstC.err, err)
}
})
}

}
Loading