From 6832a796414d0414fbf83e36eec80a84e0dd6ade Mon Sep 17 00:00:00 2001 From: Jaume Marhuenda Date: Fri, 19 Oct 2018 15:03:03 -0700 Subject: [PATCH] Add metadata for functions and aggregates (#1204) * Add metadata for functions and aggregates * Review comments * Style nits * Review comments --- cassandra_test.go | 132 +++++++++++++++++++++++++++++++++++ common_test.go | 33 +++++++++ helpers.go | 14 ++++ metadata.go | 171 +++++++++++++++++++++++++++++++++++++++++++++- metadata_test.go | 4 +- 5 files changed, 351 insertions(+), 3 deletions(-) diff --git a/cassandra_test.go b/cassandra_test.go index 023a809ca..9e3b1db6b 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -2184,6 +2184,126 @@ func TestGetColumnMetadata(t *testing.T) { } } +func TestAggregateMetadata(t *testing.T) { + session := createSession(t) + defer session.Close() + createAggregate(t, session) + + aggregates, err := getAggregatesMetadata(session, "gocql_test") + if err != nil { + t.Fatalf("failed to query aggregate metadata with err: %v", err) + } + if aggregates == nil { + t.Fatal("failed to query aggregate metadata, nil returned") + } + if len(aggregates) != 1 { + t.Fatal("expected only a single aggregate") + } + aggregate := aggregates[0] + + expectedAggregrate := AggregateMetadata{ + Keyspace: "gocql_test", + Name: "average", + ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt}}, + InitCond: "(0, 0)", + ReturnType: NativeType{typ: TypeDouble}, + StateType: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeBigInt}, + }, + }, + stateFunc: "avgstate", + finalFunc: "avgfinal", + } + + // In this case cassandra is returning a blob + if flagCassVersion.Before(3, 0, 0) { + expectedAggregrate.InitCond = string([]byte{0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0}) + } + + if !reflect.DeepEqual(aggregate, expectedAggregrate) { + t.Fatalf("aggregate is %+v, but expected %+v", aggregate, expectedAggregrate) + } +} + +func TestFunctionMetadata(t *testing.T) { + session := createSession(t) + defer session.Close() + createFunctions(t, session) + + functions, err := getFunctionsMetadata(session, "gocql_test") + if err != nil { + t.Fatalf("failed to query function metadata with err: %v", err) + } + if functions == nil { + t.Fatal("failed to query function metadata, nil returned") + } + if len(functions) != 2 { + t.Fatal("expected two functions") + } + avgState := functions[1] + avgFinal := functions[0] + + avgStateBody := "if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;" + expectedAvgState := FunctionMetadata{ + Keyspace: "gocql_test", + Name: "avgstate", + ArgumentTypes: []TypeInfo{ + TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeBigInt}, + }, + }, + NativeType{typ: TypeInt}, + }, + ArgumentNames: []string{"state", "val"}, + ReturnType: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeBigInt}, + }, + }, + CalledOnNullInput: true, + Language: "java", + Body: avgStateBody, + } + if !reflect.DeepEqual(avgState, expectedAvgState) { + t.Fatalf("function is %+v, but expected %+v", avgState, expectedAvgState) + } + + finalStateBody := "double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);" + expectedAvgFinal := FunctionMetadata{ + Keyspace: "gocql_test", + Name: "avgfinal", + ArgumentTypes: []TypeInfo{ + TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeBigInt}, + }, + }, + }, + ArgumentNames: []string{"state"}, + ReturnType: NativeType{typ: TypeDouble}, + CalledOnNullInput: true, + Language: "java", + Body: finalStateBody, + } + if !reflect.DeepEqual(avgFinal, expectedAvgFinal) { + t.Fatalf("function is %+v, but expected %+v", avgFinal, expectedAvgFinal) + } +} + // Integration test of querying and composition the keyspace metadata func TestKeyspaceMetadata(t *testing.T) { session := createSession(t) @@ -2192,6 +2312,7 @@ func TestKeyspaceMetadata(t *testing.T) { if err := createTable(session, "CREATE TABLE gocql_test.test_metadata (first_id int, second_id int, third_id int, PRIMARY KEY (first_id, second_id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } + createAggregate(t, session) if err := session.Query("CREATE INDEX index_metadata ON test_metadata ( third_id )").Exec(); err != nil { t.Fatalf("failed to create index with err: %v", err) @@ -2246,6 +2367,17 @@ func TestKeyspaceMetadata(t *testing.T) { // TODO(zariel): scan index info from system_schema t.Errorf("Expected column index named 'index_metadata' but was '%s'", thirdColumn.Index.Name) } + + aggregate, found := keyspaceMetadata.Aggregates["average"] + if !found { + t.Fatal("failed to find the aggreate in metadata") + } + if aggregate.FinalFunc.Name != "avgfinal" { + t.Fatalf("expected final function %s, but got %s", "avgFinal", aggregate.FinalFunc.Name) + } + if aggregate.StateFunc.Name != "avgstate" { + t.Fatalf("expected state function %s, but got %s", "avgstate", aggregate.StateFunc.Name) + } } // Integration test of the routing key calculation diff --git a/common_test.go b/common_test.go index bcd88bbb5..a269101ed 100644 --- a/common_test.go +++ b/common_test.go @@ -170,6 +170,39 @@ func createTestSession() *Session { return session } +func createFunctions(t *testing.T, session *Session) { + if err := session.Query(` + CREATE OR REPLACE FUNCTION gocql_test.avgState ( state tuple, val int ) + CALLED ON NULL INPUT + RETURNS tuple + LANGUAGE java AS + $$if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;$$; `).Exec(); err != nil { + t.Fatalf("failed to create function with err: %v", err) + } + if err := session.Query(` + CREATE OR REPLACE FUNCTION gocql_test.avgFinal ( state tuple ) + CALLED ON NULL INPUT + RETURNS double + LANGUAGE java AS + $$double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);$$ + `).Exec(); err != nil { + t.Fatalf("failed to create function with err: %v", err) + } +} + +func createAggregate(t *testing.T, session *Session) { + createFunctions(t, session) + if err := session.Query(` + CREATE OR REPLACE AGGREGATE gocql_test.average(int) + SFUNC avgState + STYPE tuple + FINALFUNC avgFinal + INITCOND (0,0); + `).Exec(); err != nil { + t.Fatalf("failed to create aggregate with err: %v", err) + } +} + func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator { return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) { return newAddr, newPort diff --git a/helpers.go b/helpers.go index 89112a65f..7259c7467 100644 --- a/helpers.go +++ b/helpers.go @@ -191,6 +191,20 @@ func splitCompositeTypes(name string) []string { return parts } +func apacheToCassandraType(t string) string { + t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) + t = strings.Replace(t, "(", "<", -1) + t = strings.Replace(t, ")", ">", -1) + types := strings.FieldsFunc(t, func(r rune) bool { + return r == '<' || r == '>' || r == ',' + }) + for _, typ := range types { + t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1) + } + // This is done so it exactly matches what Cassandra returns + return strings.Replace(t, ",", ", ", -1) +} + func getApacheCassandraType(class string) Type { switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { case "AsciiType": diff --git a/metadata.go b/metadata.go index db496a6e5..083c51dfa 100644 --- a/metadata.go +++ b/metadata.go @@ -20,6 +20,8 @@ type KeyspaceMetadata struct { StrategyClass string StrategyOptions map[string]interface{} Tables map[string]*TableMetadata + Functions map[string]*FunctionMetadata + Aggregates map[string]*AggregateMetadata } // schema metadata for a table (a.k.a. column family) @@ -52,6 +54,33 @@ type ColumnMetadata struct { Index ColumnIndexMetadata } +// FunctionMetadata holds metadata for function constructs +type FunctionMetadata struct { + Keyspace string + Name string + ArgumentTypes []TypeInfo + ArgumentNames []string + Body string + CalledOnNullInput bool + Language string + ReturnType TypeInfo +} + +// AggregateMetadata holds metadata for aggregate constructs +type AggregateMetadata struct { + Keyspace string + Name string + ArgumentTypes []TypeInfo + FinalFunc FunctionMetadata + InitCond string + ReturnType TypeInfo + StateFunc FunctionMetadata + StateType TypeInfo + + stateFunc string + finalFunc string +} + // the ordering of the column with regard to its comparator type ColumnOrder bool @@ -196,9 +225,17 @@ func (s *schemaDescriber) refreshSchema(keyspaceName string) error { if err != nil { return err } + functions, err := getFunctionsMetadata(s.session, keyspaceName) + if err != nil { + return err + } + aggregates, err := getAggregatesMetadata(s.session, keyspaceName) + if err != nil { + return err + } // organize the schema data - compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns) + compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns, functions, aggregates) // update the cache s.cache[keyspaceName] = keyspace @@ -216,6 +253,8 @@ func compileMetadata( keyspace *KeyspaceMetadata, tables []TableMetadata, columns []ColumnMetadata, + functions []FunctionMetadata, + aggregates []AggregateMetadata, ) { keyspace.Tables = make(map[string]*TableMetadata) for i := range tables { @@ -223,6 +262,16 @@ func compileMetadata( keyspace.Tables[tables[i].Name] = &tables[i] } + keyspace.Functions = make(map[string]*FunctionMetadata, len(functions)) + for i := range functions { + keyspace.Functions[functions[i].Name] = &functions[i] + } + keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates)) + for _, aggregate := range aggregates { + aggregate.FinalFunc = *keyspace.Functions[aggregate.finalFunc] + aggregate.StateFunc = *keyspace.Functions[aggregate.stateFunc] + keyspace.Aggregates[aggregate.Name] = &aggregate + } // add columns from the schema data for i := range columns { @@ -793,6 +842,126 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata, return columns, nil } +func getTypeInfo(t string) TypeInfo { + if strings.HasPrefix(t, apacheCassandraTypePrefix) { + t = apacheToCassandraType(t) + } + return getCassandraType(t) +} + +func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMetadata, error) { + if session.cfg.ProtoVersion == protoVersion1 { + return nil, nil + } + var tableName string + if session.useSystemSchema { + tableName = "system_schema.functions" + } else { + tableName = "system.schema_functions" + } + stmt := fmt.Sprintf(` + SELECT + function_name, + argument_types, + argument_names, + body, + called_on_null_input, + language, + return_type + FROM %s + WHERE keyspace_name = ?`, tableName) + + var functions []FunctionMetadata + + rows := session.control.query(stmt, keyspaceName).Scanner() + for rows.Next() { + function := FunctionMetadata{Keyspace: keyspaceName} + var argumentTypes []string + var returnType string + err := rows.Scan(&function.Name, + &argumentTypes, + &function.ArgumentNames, + &function.Body, + &function.CalledOnNullInput, + &function.Language, + &returnType, + ) + if err != nil { + return nil, err + } + function.ReturnType = getTypeInfo(returnType) + function.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) + for i, argumentType := range argumentTypes { + function.ArgumentTypes[i] = getTypeInfo(argumentType) + } + functions = append(functions, function) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return functions, nil +} + +func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMetadata, error) { + if session.cfg.ProtoVersion == protoVersion1 { + return nil, nil + } + var tableName string + if session.useSystemSchema { + tableName = "system_schema.aggregates" + } else { + tableName = "system.schema_aggregates" + } + + stmt := fmt.Sprintf(` + SELECT + aggregate_name, + argument_types, + final_func, + initcond, + return_type, + state_func, + state_type + FROM %s + WHERE keyspace_name = ?`, tableName) + + var aggregates []AggregateMetadata + + rows := session.control.query(stmt, keyspaceName).Scanner() + for rows.Next() { + aggregate := AggregateMetadata{Keyspace: keyspaceName} + var argumentTypes []string + var returnType string + var stateType string + err := rows.Scan(&aggregate.Name, + &argumentTypes, + &aggregate.finalFunc, + &aggregate.InitCond, + &returnType, + &aggregate.stateFunc, + &stateType, + ) + if err != nil { + return nil, err + } + aggregate.ReturnType = getTypeInfo(returnType) + aggregate.StateType = getTypeInfo(stateType) + aggregate.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) + for i, argumentType := range argumentTypes { + aggregate.ArgumentTypes[i] = getTypeInfo(argumentType) + } + aggregates = append(aggregates, aggregate) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return aggregates, nil +} + // type definition parser state type typeParser struct { input string diff --git a/metadata_test.go b/metadata_test.go index cc4631acc..cb924fc3b 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -94,7 +94,7 @@ func TestCompileMetadata(t *testing.T) { {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "schema_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType"}, {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "tokens", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type)"}, } - compileMetadata(1, keyspace, tables, columns) + compileMetadata(1, keyspace, tables, columns, nil, nil) assertKeyspaceMetadata( t, keyspace, @@ -375,7 +375,7 @@ func TestCompileMetadata(t *testing.T) { Validator: "org.apache.cassandra.db.marshal.UTF8Type", }, } - compileMetadata(2, keyspace, tables, columns) + compileMetadata(2, keyspace, tables, columns, nil, nil) assertKeyspaceMetadata( t, keyspace,