Skip to content

Commit

Permalink
remove iter.Scan (#1484)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zariel authored Jun 1, 2024
1 parent 50492f1 commit b0c9c24
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 121 deletions.
67 changes: 27 additions & 40 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1342,96 +1342,83 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
}
}

func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) {
func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) *Iter {
q := c.session.Query(statement, values...).Consistency(One)
q.trace = nil
q.skipPrepare = true
q.disableSkipMetadata = true
return c.executeQuery(ctx, q)
}

func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
func (c *Conn) awaitSchemaAgreement(ctx context.Context) error {
const (
peerSchemas = "SELECT * FROM system.peers"
localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
)

var versions map[string]struct{}
var schemaVersion string

endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
for time.Now().Before(endDeadline) {
iter := c.query(ctx, peerSchemas)

versions = make(map[string]struct{})

rows, err := iter.SliceMap()
fetchVersions := func() (map[string]struct{}, error) {
rows, err := c.query(ctx, peerSchemas).SliceMap()
if err != nil {
goto cont
return nil, err
}

versions := make(map[string]struct{}, len(rows)+1)
for _, row := range rows {
host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port})
if err != nil {
goto cont
return nil, err
}
if !isValidPeer(host) || host.schemaVersion == "" {
Logger.Printf("invalid peer or peer with empty schema_version: peer=%q", host)
continue
}

versions[host.schemaVersion] = struct{}{}
}

if err = iter.Close(); err != nil {
goto cont
}

iter = c.query(ctx, localSchemas)
for iter.Scan(&schemaVersion) {
sc := c.query(ctx, localSchemas).Scanner()
for sc.Next() {
var schemaVersion string
if err := sc.Scan(&schemaVersion); err != nil {
return nil, err
}
versions[schemaVersion] = struct{}{}
schemaVersion = ""
}

if err = iter.Close(); err != nil {
goto cont
if err := sc.Err(); err != nil {
return nil, err
}
return versions, nil
}

if len(versions) <= 1 {
endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
for time.Now().Before(endDeadline) {
versions, err := fetchVersions()
if err == nil && len(versions) == 1 {
return nil
}

cont:
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(200 * time.Millisecond):
}
}

if err != nil {
return err
}

schemas := make([]string, 0, len(versions))
for schema := range versions {
schemas = append(schemas, schema)
}

// not exported
return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas)
return errors.New("gocql: cluster schema versions not consistent")
}

func (c *Conn) localHostInfo(ctx context.Context) (*HostInfo, error) {
row, err := c.query(ctx, "SELECT * FROM system.local WHERE key='local'").rowMap()
if err != nil {
return nil, err
m := make(map[string]interface{})
iter := c.query(ctx, "SELECT * FROM system.local WHERE key='local'")
if ok := iter.MapScan(m); !ok {
return nil, iter.err
}

port := c.conn.RemoteAddr().(*net.TCPAddr).Port

// TODO(zariel): avoid doing this here
host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.connectAddress, port: port})
host, err := c.session.hostInfoFromMap(m, &HostInfo{connectAddress: c.host.connectAddress, port: port})
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ func TestCancel(t *testing.T) {

go func() {
if err := qry.Exec(); err != context.Canceled {
t.Fatalf("expected to get context cancel error: '%v', got '%v'", context.Canceled, err)
t.Errorf("expected to get context cancel error: '%v', got '%v'", context.Canceled, err)
}
wg.Done()
}()
Expand Down
89 changes: 47 additions & 42 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,66 +299,74 @@ func TupleColumnName(c string, n int) string {
return fmt.Sprintf("%s[%d]", c, n)
}

func (iter *Iter) RowData() (RowData, error) {
if iter.err != nil {
return RowData{}, iter.err
}

columns := make([]string, 0, len(iter.Columns()))
values := make([]interface{}, 0, len(iter.Columns()))
func (iter *Iter) rowData() (RowData, error) {
// TODO: unexport this? What is it used for?

columns := make([]string, 0, len(iter.meta.columns))
for _, column := range iter.Columns() {
if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
val := column.TypeInfo.New()
columns = append(columns, column.Name)
values = append(values, val)
} else {
for i, elem := range c.Elems {
for i := range c.Elems {
columns = append(columns, TupleColumnName(column.Name, i))
values = append(values, elem.New())
}
}
}

rowData := RowData{
return RowData{
Columns: columns,
Values: values,
}

return rowData, nil
Values: iter.resultValues(),
}, nil
}

// TODO(zariel): is it worth exporting this?
func (iter *Iter) rowMap() (map[string]interface{}, error) {
if iter.err != nil {
return nil, iter.err
func (iter *Iter) resultValues() []interface{} {
values := make([]interface{}, 0, len(iter.meta.columns))

for _, column := range iter.meta.columns {
if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
val := column.TypeInfo.New()
values = append(values, val)
} else {
for _, elem := range c.Elems {
values = append(values, elem.New())
}
}
}
return values
}

rowData, _ := iter.RowData()
iter.Scan(rowData.Values...)
m := make(map[string]interface{}, len(rowData.Columns))
rowData.rowMap(m)
return m, nil
func (iter *Iter) scanRow(dest ...interface{}) error {
sc := iter.Scanner()
if sc.Next() {
if err := sc.Scan(dest...); err != nil {
return err
}
}
return sc.Err()
}

// SliceMap is a helper function to make the API easier to use
// returns the data from the query in the form of []map[string]interface{}
func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
if iter.err != nil {
return nil, iter.err
rowData, err := iter.rowData()
if err != nil {
return nil, err
}

// Not checking for the error because we just did
rowData, _ := iter.RowData()
dataToReturn := make([]map[string]interface{}, 0)
for iter.Scan(rowData.Values...) {
var dataToReturn []map[string]interface{}
sc := iter.Scanner()
for sc.Next() {
row := make([]interface{}, len(rowData.Values))
m := make(map[string]interface{}, len(rowData.Columns))
rowData.rowMap(m)
for i, col := range rowData.Columns {
m[col] = row[i]
}
dataToReturn = append(dataToReturn, m)
}
if iter.err != nil {
return nil, iter.err
if err := sc.Err(); err != nil {
return nil, err
}

return dataToReturn, nil
}

Expand Down Expand Up @@ -401,20 +409,18 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
// fmt.Printf("First: %s Age: %d Address: %q\n", fullName.FirstName, age, address)
// }
func (iter *Iter) MapScan(m map[string]interface{}) bool {
if iter.err != nil {
rowData, err := iter.rowData()
if err != nil {
return false
}

// Not checking for the error because we just did
rowData, _ := iter.RowData()

for i, col := range rowData.Columns {
if dest, ok := m[col]; ok {
for i, col := range iter.meta.columns {
if dest, ok := m[col.Name]; ok {
rowData.Values[i] = dest
}
}

if iter.Scan(rowData.Values...) {
if iter.scan(rowData.Values...) {
rowData.rowMap(m)
return true
}
Expand All @@ -434,5 +440,4 @@ func LookupIP(host string) ([]net.IP, error) {
return nil, &net.DNSError{}
}
return net.LookupIP(host)

}
14 changes: 6 additions & 8 deletions metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetada
if iter.NumRows() == 0 {
return nil, ErrKeyspaceDoesNotExist
}
iter.Scan(&keyspace.DurableWrites, &replication)
iter.scan(&keyspace.DurableWrites, &replication)
err := iter.Close()
if err != nil {
return nil, fmt.Errorf("Error querying keyspace schema: %v", err)
Expand All @@ -541,7 +541,7 @@ func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetada
if iter.NumRows() == 0 {
return nil, ErrKeyspaceDoesNotExist
}
iter.Scan(&keyspace.DurableWrites, &keyspace.StrategyClass, &strategyOptionsJSON)
iter.scan(&keyspace.DurableWrites, &keyspace.StrategyClass, &strategyOptionsJSON)
err := iter.Close()
if err != nil {
return nil, fmt.Errorf("Error querying keyspace schema: %v", err)
Expand Down Expand Up @@ -590,14 +590,12 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
}

scan = func(iter *Iter, table *TableMetadata) bool {
r := iter.Scan(
&table.Name,
)
r := iter.scan(&table.Name)
if !r {
iter = switchIter()
if iter != nil {
switchIter = func() *Iter { return nil }
r = iter.Scan(&table.Name)
r = iter.scan(&table.Name)
}
}
return r
Expand All @@ -617,7 +615,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
WHERE keyspace_name = ?`

scan = func(iter *Iter, table *TableMetadata) bool {
return iter.Scan(
return iter.scan(
&table.Name,
&table.KeyValidator,
&table.Comparator,
Expand All @@ -638,7 +636,7 @@ func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, e
WHERE keyspace_name = ?`

scan = func(iter *Iter, table *TableMetadata) bool {
return iter.Scan(
return iter.scan(
&table.Name,
&table.KeyValidator,
&table.Comparator,
Expand Down
Loading

0 comments on commit b0c9c24

Please sign in to comment.