Skip to content

Commit

Permalink
Feat(graphql): Add vector support to graphql (#9074)
Browse files Browse the repository at this point in the history
Adds support for vector predicate in GraphQL. Introduced new queries
like similar_to() in graphql
  • Loading branch information
harshil-goel authored Apr 18, 2024
1 parent 3a8de31 commit 041ba15
Show file tree
Hide file tree
Showing 95 changed files with 2,227 additions and 167 deletions.
35 changes: 19 additions & 16 deletions dgraphtest/local_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"io"
"log"
"math/rand"
"net/http"
"os"
"os/exec"
Expand Down Expand Up @@ -346,7 +347,8 @@ func (c *LocalCluster) Start() error {
if err1 := c.Stop(); err1 != nil {
log.Printf("[WARNING] error while stopping :%v", err)
}
c.Cleanup(false)
c.Cleanup(true)
c.conf.prefix = fmt.Sprintf("dgraphtest-%d", rand.NewSource(time.Now().UnixNano()).Int63()%1000000)
if err := c.init(); err != nil {
c.Cleanup(true)
return err
Expand Down Expand Up @@ -449,11 +451,7 @@ func (c *LocalCluster) HealthCheck(zeroOnly bool) error {
if !zo.isRunning {
break
}
url, err := zo.healthURL(c)
if err != nil {
return errors.Wrap(err, "error getting health URL")
}
if err := c.containerHealthCheck(url); err != nil {
if err := c.containerHealthCheck(zo.healthURL); err != nil {
return err
}
log.Printf("[INFO] container [%v] passed health check", zo.containerName)
Expand All @@ -470,11 +468,7 @@ func (c *LocalCluster) HealthCheck(zeroOnly bool) error {
if !aa.isRunning {
break
}
url, err := aa.healthURL(c)
if err != nil {
return errors.Wrap(err, "error getting health URL")
}
if err := c.containerHealthCheck(url); err != nil {
if err := c.containerHealthCheck(aa.healthURL); err != nil {
return err
}
log.Printf("[INFO] container [%v] passed health check", aa.containerName)
Expand All @@ -486,18 +480,27 @@ func (c *LocalCluster) HealthCheck(zeroOnly bool) error {
return nil
}

func (c *LocalCluster) containerHealthCheck(url string) error {
func (c *LocalCluster) containerHealthCheck(url func(c *LocalCluster) (string, error)) error {
endpoint, err := url(c)
if err != nil {
return errors.Wrap(err, "error getting health URL")
}
for i := 0; i < 60; i++ {
time.Sleep(waitDurBeforeRetry)

req, err := http.NewRequest(http.MethodGet, url, nil)
endpoint, err = url(c)
if err != nil {
return errors.Wrap(err, "error getting health URL")
}

req, err := http.NewRequest(http.MethodGet, endpoint, nil)
if err != nil {
log.Printf("[WARNING] error building req for endpoint [%v], err: [%v]", url, err)
log.Printf("[WARNING] error building req for endpoint [%v], err: [%v]", endpoint, err)
continue
}
body, err := doReq(req)
if err != nil {
log.Printf("[WARNING] error hitting health endpoint [%v], err: [%v]", url, err)
log.Printf("[WARNING] error hitting health endpoint [%v], err: [%v]", endpoint, err)
continue
}
resp := string(body)
Expand All @@ -523,7 +526,7 @@ func (c *LocalCluster) containerHealthCheck(url string) error {
return nil
}

return fmt.Errorf("health failed, cluster took too long to come up [%v]", url)
return fmt.Errorf("health failed, cluster took too long to come up [%v]", endpoint)
}

func (c *LocalCluster) waitUntilLogin() error {
Expand Down
2 changes: 1 addition & 1 deletion graphql/bench/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@ type Owner {
hasRestaurants: [Restaurant] @hasInverse(field: owner)
}

# Dgraph.Authorization {"VerificationKey":"secretkey","Header":"X-Test-Auth","Namespace":"https://xyz.io/jwt/claims","Algo":"HS256","Audience":["aud1","63do0q16n6ebjgkumu05kkeian","aud5"]}
# Dgraph.Authorization {"VerificationKey":"secretkey","Header":"X-Test-Auth","Namespace":"https://xyz.io/jwt/claims","Algo":"HS256","Audience":["aud1","63do0q16n6ebjgkumu05kkeian","aud5"]}
2 changes: 1 addition & 1 deletion graphql/bench/schema_auth.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,4 @@ type Owner @auth(
hasRestaurants: [Restaurant] @hasInverse(field: owner)
}

# Dgraph.Authorization {"VerificationKey":"secretkey","Header":"X-Test-Auth","Namespace":"https://xyz.io/jwt/claims","Algo":"HS256","Audience":["aud1","63do0q16n6ebjgkumu05kkeian","aud5"]}
# Dgraph.Authorization {"VerificationKey":"secretkey","Header":"X-Test-Auth","Namespace":"https://xyz.io/jwt/claims","Algo":"HS256","Audience":["aud1","63do0q16n6ebjgkumu05kkeian","aud5"]}
33 changes: 31 additions & 2 deletions graphql/dgraph/graphquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@ import (
// validate query, and so doesn't return an error if query is 'malformed' - it might
// just write something that wouldn't parse as a Dgraph query.
func AsString(queries []*dql.GraphQuery) string {
if queries == nil {
if len(queries) == 0 {
return ""
}

var b strings.Builder
x.Check2(b.WriteString("query {\n"))
queryName := queries[len(queries)-1].Attr
x.Check2(b.WriteString("query "))
addQueryVars(&b, queryName, queries[0].Args)
x.Check2(b.WriteString("{\n"))

numRewrittenQueries := 0
for _, q := range queries {
if q == nil {
Expand All @@ -54,6 +58,24 @@ func AsString(queries []*dql.GraphQuery) string {
return b.String()
}

func addQueryVars(b *strings.Builder, queryName string, args map[string]string) {
dollarFound := false
for name, val := range args {
if strings.HasPrefix(name, "$") {
if !dollarFound {
x.Check2(b.WriteString(queryName + "("))
x.Check2(b.WriteString(name + ": " + val))
dollarFound = true
} else {
x.Check2(b.WriteString(", " + name + ": " + val))
}
}
}
if dollarFound {
x.Check2(b.WriteString(") "))
}
}

func writeQuery(b *strings.Builder, query *dql.GraphQuery, prefix string) {
if query.Var != "" || query.Alias != "" || query.Attr != "" {
x.Check2(b.WriteString(prefix))
Expand Down Expand Up @@ -145,6 +167,9 @@ func writeRoot(b *strings.Builder, q *dql.GraphQuery) {
}

switch {
// TODO: Instead of the hard-coded strings "uid", "type", etc., use the
// pre-defined constants in dql/parser.go such as dql.uidFunc, dql.typFunc,
// etc. This of course will require that we make these constants public.
case q.Func.Name == "uid":
x.Check2(b.WriteString("(func: "))
writeUIDFunc(b, q.Func.UID, q.Func.Args)
Expand All @@ -154,6 +179,10 @@ func writeRoot(b *strings.Builder, q *dql.GraphQuery) {
x.Check2(b.WriteString("(func: eq("))
writeFilterArguments(b, q.Func.Args)
x.Check2(b.WriteRune(')'))
case q.Func.Name == "similar_to":
x.Check2(b.WriteString("(func: similar_to("))
writeFilterArguments(b, q.Func.Args)
x.Check2(b.WriteRune(')'))
}
writeOrderAndPage(b, q, true)
x.Check2(b.WriteRune(')'))
Expand Down
8 changes: 4 additions & 4 deletions graphql/e2e/auth/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ type Contact @auth(
query: { rule: "{$ContactRole: { eq: \"ADMINISTRATOR\"}}" }
) {
id: ID!
nickName: String @search(by: [exact, term, fulltext, regexp])
nickName: String @search(by: ["exact", "term", "fulltext", "regexp"])
adminTasks: [AdminTask] @hasInverse(field: forContact)
tasks: [Task] @hasInverse(field: forContact)
}
Expand All @@ -577,14 +577,14 @@ type AdminTask @auth(
query: { rule: "{$TaskRole: { eq: \"ADMINISTRATOR\"}}" }
) {
id: ID!
name: String @search(by: [exact, term, fulltext, regexp])
name: String @search(by: ["exact", "term", "fulltext", "regexp"])
occurrences: [TaskOccurrence] @hasInverse(field: adminTask)
forContact: Contact @hasInverse(field: adminTasks)
}

type Task {
id: ID!
name: String @search(by: [exact, term, fulltext, regexp])
name: String @search(by: ["exact", "term", "fulltext", "regexp"])
occurrences: [TaskOccurrence] @hasInverse(field: task)
forContact: Contact @hasInverse(field: tasks)
}
Expand All @@ -608,7 +608,7 @@ type TaskOccurrence @auth(
task: Task @hasInverse(field: occurrences)
adminTask: AdminTask @hasInverse(field: occurrences)
isPublic: Boolean @search
role: String @search(by: [exact, term, fulltext, regexp])
role: String @search(by: ["exact", "term", "fulltext", "regexp"])
}

type Author {
Expand Down
2 changes: 1 addition & 1 deletion graphql/e2e/common/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -1285,7 +1285,7 @@ func stringExactFilters(t *testing.T) {

func scalarListFilters(t *testing.T) {

// tags is a list of strings with @search(by: exact). So all the filters
// tags is a list of strings with @search(by: "exact"). So all the filters
// lt, le, ... mean "is there something in the list that's lt 'Dgraph'", etc.

cases := map[string]struct {
Expand Down
6 changes: 3 additions & 3 deletions graphql/e2e/custom_logic/custom_logic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ func TestCustomFieldsShouldPassBody(t *testing.T) {

schema := `
type User {
id: String! @id @search(by: [hash, regexp])
id: String! @id @search(by: ["hash", "regexp"])
address:String
name: String
@custom(
Expand Down Expand Up @@ -2573,7 +2573,7 @@ func TestCustomDQL(t *testing.T) {
}
type Tweets implements Node {
id: ID!
text: String! @search(by: [fulltext, exact])
text: String! @search(by: ["fulltext", "exact"])
user: User
timestamp: DateTime! @search
}
Expand Down Expand Up @@ -2864,7 +2864,7 @@ func TestCustomFieldsWithRestError(t *testing.T) {
}
type User {
id: String! @id @search(by: [hash, regexp])
id: String! @id @search(by: ["hash", "regexp"])
name: String
@custom(
http: {
Expand Down
2 changes: 1 addition & 1 deletion graphql/e2e/directives/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -423,4 +423,4 @@ type CricketTeam implements Team {
type LibraryManager {
name: String! @id
manages: [LibraryMember]
}
}
2 changes: 1 addition & 1 deletion graphql/e2e/normal/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -422,4 +422,4 @@ type CricketTeam implements Team {
type LibraryManager {
name: String! @id
manages: [LibraryMember]
}
}
4 changes: 3 additions & 1 deletion graphql/e2e/schema/apollo_service_response.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ enum DgraphIndex {
day
hour
geo
hnsw
}

input AuthRule {
Expand Down Expand Up @@ -196,7 +197,8 @@ input GenerateMutationParams {
}

directive @hasInverse(field: String!) on FIELD_DEFINITION
directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION
directive @search(by: [String!]) on FIELD_DEFINITION
directive @embedding on FIELD_DEFINITION
directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION
directive @id(interface: Boolean) on FIELD_DEFINITION
directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION
Expand Down
4 changes: 3 additions & 1 deletion graphql/e2e/schema/generatedSchema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ enum DgraphIndex {
day
hour
geo
hnsw
}

input AuthRule {
Expand Down Expand Up @@ -177,7 +178,8 @@ input GenerateMutationParams {
}

directive @hasInverse(field: String!) on FIELD_DEFINITION
directive @search(by: [DgraphIndex!]) on FIELD_DEFINITION
directive @search(by: [String!]) on FIELD_DEFINITION
directive @embedding on FIELD_DEFINITION
directive @dgraph(type: String, pred: String) on OBJECT | INTERFACE | FIELD_DEFINITION
directive @id(interface: Boolean) on FIELD_DEFINITION
directive @withSubscription on OBJECT | INTERFACE | FIELD_DEFINITION
Expand Down
2 changes: 1 addition & 1 deletion graphql/e2e/schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ func TestLargeSchemaUpdate(t *testing.T) {

schema := "type LargeSchema {"
for i := 1; i <= numFields; i++ {
schema = schema + "\n" + fmt.Sprintf("field%d: String! @search(by: [regexp])", i)
schema = schema + "\n" + fmt.Sprintf("field%d: String! @search(by: [\"regexp\"])", i)
}
schema = schema + "\n}"

Expand Down
2 changes: 1 addition & 1 deletion graphql/e2e/subscription/subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ const (
}
type Customer {
username: String! @id @search(by: [hash, regexp])
username: String! @id @search(by: ["hash", "regexp"])
reviews: [Review] @hasInverse(field: by)
}
Expand Down
6 changes: 6 additions & 0 deletions graphql/resolve/mutation_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1675,6 +1675,12 @@ func rewriteObject(
fieldName = fieldName[1 : len(fieldName)-1]
}

if fieldDef.HasEmbeddingDirective() {
// embedding is a JSON array of numbers. Rewrite it as a string, for now
var valBytes []byte
valBytes, _ = json.Marshal(val)
val = string(valBytes)
}
// TODO: Write a function for aggregating data of fragment from child nodes.
switch val := val.(type) {
case map[string]interface{}:
Expand Down
Loading

0 comments on commit 041ba15

Please sign in to comment.