Skip to content

Commit

Permalink
metadata: fix all types being NativeType (apache#1052)
Browse files Browse the repository at this point in the history
Parse the correct TypeInfo from the cassandra string in the db. Fixing
representation to recursivly parse the nested types.
  • Loading branch information
Zariel authored Jan 23, 2018
1 parent d93886f commit dd47639
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 23 deletions.
58 changes: 47 additions & 11 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func dereference(i interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(i)).Interface()
}

func getCassandraType(name string) Type {
func getCassandraBaseType(name string) Type {
switch name {
case "ascii":
return TypeAscii
Expand All @@ -92,8 +92,10 @@ func getCassandraType(name string) Type {
return TypeTimestamp
case "uuid":
return TypeUUID
case "varchar", "text":
case "varchar":
return TypeVarchar
case "text":
return TypeText
case "varint":
return TypeVarint
case "timeuuid":
Expand All @@ -109,19 +111,53 @@ func getCassandraType(name string) Type {
case "TupleType":
return TypeTuple
default:
if strings.HasPrefix(name, "set") {
return TypeSet
} else if strings.HasPrefix(name, "list") {
return TypeList
} else if strings.HasPrefix(name, "map") {
return TypeMap
} else if strings.HasPrefix(name, "tuple") {
return TypeTuple
}
return TypeCustom
}
}

func getCassandraType(name string) TypeInfo {
if strings.HasPrefix(name, "frozen<") {
return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"))
} else if strings.HasPrefix(name, "set<") {
return CollectionType{
NativeType: NativeType{typ: TypeSet},
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<")),
}
} else if strings.HasPrefix(name, "list<") {
return CollectionType{
NativeType: NativeType{typ: TypeList},
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<")),
}
} else if strings.HasPrefix(name, "map<") {
names := strings.Split(strings.TrimPrefix(name[:len(name)-1], "map<"), ", ")
if len(names) != 2 {
panic(fmt.Sprintf("invalid map type: %v", name))
}

return CollectionType{
NativeType: NativeType{typ: TypeMap},
Key: getCassandraType(names[0]),
Elem: getCassandraType(names[1]),
}
} else if strings.HasPrefix(name, "tuple<") {
names := strings.Split(strings.TrimPrefix(name[:len(name)-1], "tuple<"), ", ")
types := make([]TypeInfo, len(names))

for i, name := range names {
types[i] = getCassandraType(name)
}

return TupleTypeInfo{
NativeType: NativeType{typ: TypeTuple},
Elems: types,
}
} else {
return NativeType{
typ: getCassandraBaseType(name),
}
}
}

func getApacheCassandraType(class string) Type {
switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
case "AsciiType":
Expand Down
74 changes: 74 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package gocql

import (
"reflect"
"testing"
)

func TestGetCassandraType_Set(t *testing.T) {
typ := getCassandraType("set<text>")
set, ok := typ.(CollectionType)
if !ok {
t.Fatalf("expected CollectionType got %T", typ)
} else if set.typ != TypeSet {
t.Fatalf("expected type %v got %v", TypeSet, set.typ)
}

inner, ok := set.Elem.(NativeType)
if !ok {
t.Fatalf("expected to get NativeType got %T", set.Elem)
} else if inner.typ != TypeText {
t.Fatalf("expected to get %v got %v for set value", TypeText, set.typ)
}
}

func TestGetCassandraType(t *testing.T) {
tests := []struct {
input string
exp TypeInfo
}{
{
"set<text>", CollectionType{
NativeType: NativeType{typ: TypeSet},

Elem: NativeType{typ: TypeText},
},
},
{
"map<text, varchar>", CollectionType{
NativeType: NativeType{typ: TypeMap},

Key: NativeType{typ: TypeText},
Elem: NativeType{typ: TypeVarchar},
},
},
{
"list<int>", CollectionType{
NativeType: NativeType{typ: TypeList},
Elem: NativeType{typ: TypeInt},
},
},
{
"tuple<int, int, text>", TupleTypeInfo{
NativeType: NativeType{typ: TypeTuple},

Elems: []TypeInfo{
NativeType{typ: TypeInt},
NativeType{typ: TypeInt},
NativeType{typ: TypeText},
},
},
},
}

for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
got := getCassandraType(test.input)

// TODO(zariel): define an equal method on the types?
if !reflect.DeepEqual(got, test.exp) {
t.Fatalf("expected %v got %v", test.exp, got)
}
})
}
}
11 changes: 11 additions & 0 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2053,6 +2053,17 @@ type TupleTypeInfo struct {
Elems []TypeInfo
}

func (t TupleTypeInfo) String() string {
var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("%s(", t.typ))
for _, elem := range t.Elems {
buf.WriteString(fmt.Sprintf("%s, ", elem))
}
buf.Truncate(buf.Len() - 2)
buf.WriteByte(')')
return buf.String()
}

func (t TupleTypeInfo) New() interface{} {
return reflect.New(goType(t)).Interface()
}
Expand Down
25 changes: 13 additions & 12 deletions metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,25 +226,26 @@ func compileMetadata(

// add columns from the schema data
for i := range columns {
col := &columns[i]
// decode the validator for TypeInfo and order
if columns[i].ClusteringOrder != "" { // Cassandra 3.x+
columns[i].Type = NativeType{typ: getCassandraType(columns[i].Validator)}
columns[i].Order = ASC
if columns[i].ClusteringOrder == "desc" {
columns[i].Order = DESC
if col.ClusteringOrder != "" { // Cassandra 3.x+
col.Type = getCassandraType(col.Validator)
col.Order = ASC
if col.ClusteringOrder == "desc" {
col.Order = DESC
}
} else {
validatorParsed := parseType(columns[i].Validator)
columns[i].Type = validatorParsed.types[0]
columns[i].Order = ASC
validatorParsed := parseType(col.Validator)
col.Type = validatorParsed.types[0]
col.Order = ASC
if validatorParsed.reversed[0] {
columns[i].Order = DESC
col.Order = DESC
}
}

table := keyspace.Tables[columns[i].Table]
table.Columns[columns[i].Name] = &columns[i]
table.OrderedColumns = append(table.OrderedColumns, columns[i].Name)
table := keyspace.Tables[col.Table]
table.Columns[col.Name] = col
table.OrderedColumns = append(table.OrderedColumns, col.Name)
}

if protoVersion == protoVersion1 {
Expand Down

0 comments on commit dd47639

Please sign in to comment.