Skip to content

Commit

Permalink
feature: support IN query clauses for nested JSON columns (#272)
Browse files Browse the repository at this point in the history
* feat: support IN for JSONArrayQuery

* chore: restrict json in for other dialects

* chore: use builder.WriteQuoted for JSON_EXTRACT

* fix: remove superfluous quoted

* fix: allow in query for nested values only

* fix: resolve JSON type and IN operator incompatibility with older MySQL versions

* feat: support in queries for non-nested columns as well
  • Loading branch information
jiazheng2 authored Oct 15, 2024
1 parent 610acc2 commit 1399d3c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 14 deletions.
60 changes: 47 additions & 13 deletions json.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,39 +450,73 @@ func JSONArrayQuery(column string) *JSONArrayExpression {
}

type JSONArrayExpression struct {
contains bool
in bool
column string
keys []string
equalsValue interface{}
}

// Contains checks if the column[keys] has contains the value given. The keys parameter is only supported for MySQL.
// Contains checks if the column[keys] contains the value given. The keys parameter is only supported for MySQL.
func (json *JSONArrayExpression) Contains(value interface{}, keys ...string) *JSONArrayExpression {
json.contains = true
json.equalsValue = value
json.keys = keys
return json
}

// In checks if columns[keys] is in the array value given. This method is only supported for MySQL.
func (json *JSONArrayExpression) In(value interface{}, keys ...string) *JSONArrayExpression {
json.in = true
json.keys = keys
json.equalsValue = value
return json
}

// Build implements clause.Expression
func (json *JSONArrayExpression) Build(builder clause.Builder) {
if stmt, ok := builder.(*gorm.Statement); ok {
switch stmt.Dialector.Name() {
case "mysql":
builder.WriteString("JSON_CONTAINS(" + stmt.Quote(json.column) + ",JSON_ARRAY(")
builder.AddVar(stmt, json.equalsValue)
builder.WriteByte(')')
if len(json.keys) > 0 {
switch {
case json.contains:
builder.WriteString("JSON_CONTAINS(" + stmt.Quote(json.column) + ",JSON_ARRAY(")
builder.AddVar(stmt, json.equalsValue)
builder.WriteByte(')')
if len(json.keys) > 0 {
builder.WriteByte(',')
builder.AddVar(stmt, jsonQueryJoin(json.keys))
}
builder.WriteByte(')')
case json.in:
builder.WriteString("JSON_CONTAINS(JSON_ARRAY")
builder.AddVar(stmt, json.equalsValue)
builder.WriteByte(',')
builder.AddVar(stmt, jsonQueryJoin(json.keys))
if len(json.keys) > 0 {
builder.WriteString("JSON_EXTRACT(")
}
builder.WriteQuoted(json.column)
if len(json.keys) > 0 {
builder.WriteByte(',')
builder.AddVar(stmt, jsonQueryJoin(json.keys))
builder.WriteByte(')')
}
builder.WriteByte(')')
}
builder.WriteByte(')')
case "sqlite":
builder.WriteString("exists(SELECT 1 FROM json_each(" + stmt.Quote(json.column) + ") WHERE value = ")
builder.AddVar(stmt, json.equalsValue)
builder.WriteString(")")
switch {
case json.contains:
builder.WriteString("exists(SELECT 1 FROM json_each(" + stmt.Quote(json.column) + ") WHERE value = ")
builder.AddVar(stmt, json.equalsValue)
builder.WriteString(")")
}
case "postgres":
builder.WriteString(stmt.Quote(json.column))
builder.WriteString(" ? ")
builder.AddVar(stmt, json.equalsValue)
switch {
case json.contains:
builder.WriteString(stmt.Quote(json.column))
builder.WriteString(" ? ")
builder.AddVar(stmt, json.equalsValue)
}
}
}
}
18 changes: 17 additions & 1 deletion json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,6 @@ func TestJSONArrayQuery(t *testing.T) {
DisplayName: "JSONArray-1",
Config: datatypes.JSON("[\"a\", \"b\"]"),
}

cmp2 := Param{
DisplayName: "JSONArray-2",
Config: datatypes.JSON("[\"c\", \"a\"]"),
Expand All @@ -472,6 +471,10 @@ func TestJSONArrayQuery(t *testing.T) {
DisplayName: "JSONArray-3",
Config: datatypes.JSON("{\"test\": [\"a\", \"b\"]}"),
}
cmp4 := Param{
DisplayName: "JSONArray-4",
Config: datatypes.JSON("{\"test\": \"c\"}"),
}

if err := DB.Create(&cmp1).Error; err != nil {
t.Errorf("Failed to create param %v", err)
Expand All @@ -482,6 +485,9 @@ func TestJSONArrayQuery(t *testing.T) {
if err := DB.Create(&cmp3).Error; err != nil {
t.Errorf("Failed to create param %v", err)
}
if err := DB.Create(&cmp4).Error; err != nil {
t.Errorf("Failed to create param %v", err)
}

var retSingle1 Param
if err := DB.Where("id = ?", cmp2.ID).First(&retSingle1).Error; err != nil {
Expand All @@ -507,5 +513,15 @@ func TestJSONArrayQuery(t *testing.T) {
t.Fatalf("failed to find params with json value and keys, got error %v", err)
}
AssertEqual(t, len(retMultiple), 1)

if err := DB.Where(datatypes.JSONArrayQuery("config").In([]string{"c", "a"})).Find(&retMultiple).Error; err != nil {
t.Fatalf("failed to find params with json value, got error %v", err)
}
AssertEqual(t, len(retMultiple), 1)

if err := DB.Where(datatypes.JSONArrayQuery("config").In([]string{"c", "d"}, "test")).Find(&retMultiple).Error; err != nil {
t.Fatalf("failed to find params with json value and keys, got error %v", err)
}
AssertEqual(t, len(retMultiple), 1)
}
}

0 comments on commit 1399d3c

Please sign in to comment.