diff --git a/json.go b/json.go index 13c8cac..6167477 100644 --- a/json.go +++ b/json.go @@ -450,39 +450,73 @@ func JSONArrayQuery(column string) *JSONArrayExpression { } type JSONArrayExpression struct { + contains bool + in bool column string keys []string equalsValue interface{} } -// Contains checks if the column[keys] has contains the value given. The keys parameter is only supported for MySQL. +// Contains checks if the column[keys] contains the value given. The keys parameter is only supported for MySQL. func (json *JSONArrayExpression) Contains(value interface{}, keys ...string) *JSONArrayExpression { + json.contains = true json.equalsValue = value json.keys = keys return json } +// In checks if columns[keys] is in the array value given. This method is only supported for MySQL. +func (json *JSONArrayExpression) In(value interface{}, keys ...string) *JSONArrayExpression { + json.in = true + json.keys = keys + json.equalsValue = value + return json +} + // Build implements clause.Expression func (json *JSONArrayExpression) Build(builder clause.Builder) { if stmt, ok := builder.(*gorm.Statement); ok { switch stmt.Dialector.Name() { case "mysql": - builder.WriteString("JSON_CONTAINS(" + stmt.Quote(json.column) + ",JSON_ARRAY(") - builder.AddVar(stmt, json.equalsValue) - builder.WriteByte(')') - if len(json.keys) > 0 { + switch { + case json.contains: + builder.WriteString("JSON_CONTAINS(" + stmt.Quote(json.column) + ",JSON_ARRAY(") + builder.AddVar(stmt, json.equalsValue) + builder.WriteByte(')') + if len(json.keys) > 0 { + builder.WriteByte(',') + builder.AddVar(stmt, jsonQueryJoin(json.keys)) + } + builder.WriteByte(')') + case json.in: + builder.WriteString("JSON_CONTAINS(JSON_ARRAY") + builder.AddVar(stmt, json.equalsValue) builder.WriteByte(',') - builder.AddVar(stmt, jsonQueryJoin(json.keys)) + if len(json.keys) > 0 { + builder.WriteString("JSON_EXTRACT(") + } + builder.WriteQuoted(json.column) + if len(json.keys) > 0 { + builder.WriteByte(',') + builder.AddVar(stmt, jsonQueryJoin(json.keys)) + builder.WriteByte(')') + } + builder.WriteByte(')') } - builder.WriteByte(')') case "sqlite": - builder.WriteString("exists(SELECT 1 FROM json_each(" + stmt.Quote(json.column) + ") WHERE value = ") - builder.AddVar(stmt, json.equalsValue) - builder.WriteString(")") + switch { + case json.contains: + builder.WriteString("exists(SELECT 1 FROM json_each(" + stmt.Quote(json.column) + ") WHERE value = ") + builder.AddVar(stmt, json.equalsValue) + builder.WriteString(")") + } case "postgres": - builder.WriteString(stmt.Quote(json.column)) - builder.WriteString(" ? ") - builder.AddVar(stmt, json.equalsValue) + switch { + case json.contains: + builder.WriteString(stmt.Quote(json.column)) + builder.WriteString(" ? ") + builder.AddVar(stmt, json.equalsValue) + } } } } diff --git a/json_test.go b/json_test.go index d8c2071..4e43432 100644 --- a/json_test.go +++ b/json_test.go @@ -463,7 +463,6 @@ func TestJSONArrayQuery(t *testing.T) { DisplayName: "JSONArray-1", Config: datatypes.JSON("[\"a\", \"b\"]"), } - cmp2 := Param{ DisplayName: "JSONArray-2", Config: datatypes.JSON("[\"c\", \"a\"]"), @@ -472,6 +471,10 @@ func TestJSONArrayQuery(t *testing.T) { DisplayName: "JSONArray-3", Config: datatypes.JSON("{\"test\": [\"a\", \"b\"]}"), } + cmp4 := Param{ + DisplayName: "JSONArray-4", + Config: datatypes.JSON("{\"test\": \"c\"}"), + } if err := DB.Create(&cmp1).Error; err != nil { t.Errorf("Failed to create param %v", err) @@ -482,6 +485,9 @@ func TestJSONArrayQuery(t *testing.T) { if err := DB.Create(&cmp3).Error; err != nil { t.Errorf("Failed to create param %v", err) } + if err := DB.Create(&cmp4).Error; err != nil { + t.Errorf("Failed to create param %v", err) + } var retSingle1 Param if err := DB.Where("id = ?", cmp2.ID).First(&retSingle1).Error; err != nil { @@ -507,5 +513,15 @@ func TestJSONArrayQuery(t *testing.T) { t.Fatalf("failed to find params with json value and keys, got error %v", err) } AssertEqual(t, len(retMultiple), 1) + + if err := DB.Where(datatypes.JSONArrayQuery("config").In([]string{"c", "a"})).Find(&retMultiple).Error; err != nil { + t.Fatalf("failed to find params with json value, got error %v", err) + } + AssertEqual(t, len(retMultiple), 1) + + if err := DB.Where(datatypes.JSONArrayQuery("config").In([]string{"c", "d"}, "test")).Find(&retMultiple).Error; err != nil { + t.Fatalf("failed to find params with json value and keys, got error %v", err) + } + AssertEqual(t, len(retMultiple), 1) } }