Skip to content

Commit fde0390

Browse files
committed
[ENH]: Optimize GetCollections and remove usage of raw gorm
1 parent 71aae94 commit fde0390

File tree

1 file changed

+52
-31
lines changed

1 file changed

+52
-31
lines changed

go/pkg/sysdb/metastore/db/dao/collection.go

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@ package dao
22

33
import (
44
"errors"
5-
"fmt"
65
"sort"
76
"time"
87

98
"github.com/chroma-core/chroma/go/pkg/common"
10-
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbcore"
119
"github.com/jackc/pgx/v5/pgconn"
1210
"gorm.io/gorm/clause"
1311

@@ -178,15 +176,36 @@ func (s *collectionDb) getCollections(ids []string, name *string, tenantID strin
178176
MetadataUpdatedAt *time.Time `gorm:"column:metadata_updated_at"`
179177
}
180178

181-
query := s.db.Table("collections").
182-
Select("collections.id as collection_id, collections.name as collection_name, collections.configuration_json_str, collections.dimension, collections.database_id AS database_id, collections.ts as collection_ts, collections.is_deleted, collections.created_at as collection_created_at, collections.updated_at as collection_updated_at, collections.log_position, collections.version, collections.version_file_name, collections.root_collection_id, NULLIF(collections.lineage_file_name, '') AS lineage_file_name, collections.total_records_post_compaction, collections.size_bytes_post_compaction, collections.last_compaction_time_secs, databases.name as database_name, databases.tenant_id as db_tenant_id, collections.tenant as tenant").
183-
Joins("INNER JOIN databases ON collections.database_id = databases.id").
184-
Order("collections.created_at ASC")
179+
isQueryOptimized := true && databaseName != "" && tenantID != ""
180+
181+
query := s.db.Table("collections")
182+
collection_targets := "collections.id as collection_id, collections.name as collection_name, collections.configuration_json_str, collections.dimension, collections.database_id AS database_id, collections.ts as collection_ts, collections.is_deleted, collections.created_at as collection_created_at, collections.updated_at as collection_updated_at, collections.log_position, collections.version, collections.version_file_name, collections.root_collection_id, NULLIF(collections.lineage_file_name, '') AS lineage_file_name, collections.total_records_post_compaction, collections.size_bytes_post_compaction, collections.last_compaction_time_secs, "
183+
db_targets := " databases.name as database_name, databases.tenant_id as db_tenant_id, "
184+
collection_tenant := "collections.tenant as tenant"
185+
186+
if isQueryOptimized {
187+
db_id_query := s.db.Model(&dbmodel.Database{}).
188+
Select("id").
189+
Where("tenant_id = ?", tenantID).
190+
Where("name = ?", databaseName).
191+
Limit(1)
192+
193+
// We rewrite the query to get the one database_id with what is hopefully an initplan
194+
// that first gets the database_id and then uses it to do an ordered scan over
195+
// the matching collections.
196+
query = query.Select(collection_targets+"? as database_name, ? as db_tenant_id, "+collection_tenant, databaseName, tenantID).
197+
Where("collections.database_id = (?)", db_id_query)
198+
} else {
199+
query = query.Select(collection_targets + db_targets + collection_tenant).
200+
Joins("INNER JOIN databases ON collections.database_id = databases.id")
201+
}
185202

186-
if databaseName != "" {
203+
query = query.Order("collections.created_at ASC")
204+
205+
if databaseName != "" && !isQueryOptimized {
187206
query = query.Where("databases.name = ?", databaseName)
188207
}
189-
if tenantID != "" {
208+
if tenantID != "" && !isQueryOptimized {
190209
query = query.Where("databases.tenant_id = ?", tenantID)
191210
}
192211
if ids != nil {
@@ -206,26 +225,8 @@ func (s *collectionDb) getCollections(ids []string, name *string, tenantID strin
206225
query = query.Offset(int(*offset))
207226
}
208227

209-
// Use optimized CTE query only if feature flag is enabled
210-
if dbcore.IsOptimizedCollectionQueriesEnabled() && databaseName != "" && tenantID != "" {
211-
var dummy []Result
212-
stmt := query.Session(&gorm.Session{DryRun: true}).Find(&dummy).Statement
213-
sqlString := stmt.SQL.String()
214-
vars := stmt.Vars
215-
216-
cte := fmt.Sprintf(`WITH db AS (
217-
SELECT id
218-
FROM databases
219-
WHERE name = $%d AND tenant_id = $%d
220-
)`, len(vars)+1, len(vars)+2)
221-
222-
fullSQL := cte + `SELECT * FROM (` + sqlString + `) p WHERE p.database_id = (SELECT id FROM db)`
223-
vars = append([]interface{}{databaseName, tenantID}, vars...)
224-
query = s.db.Raw(fullSQL, vars...)
225-
}
226-
227228
var results []Result
228-
err = s.db.Table("(?) as ci", query).
229+
query = s.db.Table("(?) as ci", query).
229230
Select(`
230231
ci.*,
231232
cm.key,
@@ -237,11 +238,31 @@ func (s *collectionDb) getCollections(ids []string, name *string, tenantID strin
237238
cm.created_at as metadata_created_at,
238239
cm.updated_at as metadata_updated_at
239240
`).
240-
Joins("LEFT JOIN collection_metadata cm ON cm.collection_id = ci.collection_id").
241-
Scan(&results).Error
241+
Joins("LEFT JOIN collection_metadata cm ON cm.collection_id = ci.collection_id")
242242

243-
if err != nil {
244-
return nil, err
243+
if isQueryOptimized {
244+
// Setting random_page_cost to 1.1 because that's usually the recommended value
245+
// for SSD based databases. This encourages index usage. The default used
246+
// to be 4.0 which was more for HDD based databases where random seeking
247+
// was way more expensive than sequential access.
248+
var dummy []Result
249+
stmt := query.Session(&gorm.Session{DryRun: true}).Find(&dummy).Statement
250+
sqlString := stmt.SQL.String()
251+
252+
// Use a transaction to execute both commands in a single round trip
253+
err := s.db.Transaction(func(tx *gorm.DB) error {
254+
if err := tx.Exec("SET LOCAL random_page_cost = 1.1").Error; err != nil {
255+
return err
256+
}
257+
return tx.Raw(sqlString, stmt.Vars...).Scan(&results).Error
258+
})
259+
if err != nil {
260+
return nil, err
261+
}
262+
} else {
263+
if err := query.Scan(&results).Error; err != nil {
264+
return nil, err
265+
}
245266
}
246267

247268
var collectionsMap = make(map[string]*dbmodel.CollectionAndMetadata)

0 commit comments

Comments
 (0)