11use backon:: ExponentialBuilder ;
22use backon:: Retryable ;
33use chroma_error:: ChromaValidationError ;
4+ use chroma_types:: Collection ;
45use parking_lot:: Mutex ;
56use reqwest:: Method ;
67use reqwest:: StatusCode ;
@@ -9,7 +10,10 @@ use std::sync::Arc;
910use thiserror:: Error ;
1011
1112use crate :: client:: ChromaClientOptions ;
12- use crate :: types:: { GetUserIdentityResponse , HeartbeatResponse , ListCollectionsRequest } ;
13+ use crate :: collection:: ChromaCollection ;
14+ use crate :: types:: {
15+ CreateCollectionRequest , GetUserIdentityResponse , HeartbeatResponse , ListCollectionsRequest ,
16+ } ;
1317
1418const USER_AGENT : & str = concat ! (
1519 "Chroma Rust Client v" ,
@@ -35,7 +39,7 @@ pub struct ChromaClient {
3539 client : reqwest:: Client ,
3640 retry_policy : ExponentialBuilder ,
3741 tenant_id : Arc < Mutex < Option < String > > > ,
38- default_database_id : Arc < Mutex < Option < String > > > ,
42+ default_database_name : Arc < Mutex < Option < String > > > ,
3943 resolve_tenant_or_database_lock : Arc < tokio:: sync:: Mutex < ( ) > > ,
4044 #[ cfg( feature = "opentelemetry" ) ]
4145 metrics : crate :: client:: metrics:: Metrics ,
@@ -48,7 +52,7 @@ impl Clone for ChromaClient {
4852 client : self . client . clone ( ) ,
4953 retry_policy : self . retry_policy ,
5054 tenant_id : Arc :: new ( Mutex :: new ( self . tenant_id . lock ( ) . clone ( ) ) ) ,
51- default_database_id : Arc :: new ( Mutex :: new ( self . default_database_id . lock ( ) . clone ( ) ) ) ,
55+ default_database_name : Arc :: new ( Mutex :: new ( self . default_database_name . lock ( ) . clone ( ) ) ) ,
5256 resolve_tenant_or_database_lock : Arc :: new ( tokio:: sync:: Mutex :: new ( ( ) ) ) ,
5357 #[ cfg( feature = "opentelemetry" ) ]
5458 metrics : self . metrics . clone ( ) ,
@@ -79,16 +83,16 @@ impl ChromaClient {
7983 client,
8084 retry_policy : options. retry_options . into ( ) ,
8185 tenant_id : Arc :: new ( Mutex :: new ( options. tenant_id ) ) ,
82- default_database_id : Arc :: new ( Mutex :: new ( options. default_database_id ) ) ,
86+ default_database_name : Arc :: new ( Mutex :: new ( options. default_database_name ) ) ,
8387 resolve_tenant_or_database_lock : Arc :: new ( tokio:: sync:: Mutex :: new ( ( ) ) ) ,
8488 #[ cfg( feature = "opentelemetry" ) ]
8589 metrics : crate :: client:: metrics:: Metrics :: new ( ) ,
8690 }
8791 }
8892
89- pub fn set_default_database_id ( & self , database_id : String ) {
90- let mut lock = self . default_database_id . lock ( ) ;
91- * lock = Some ( database_id ) ;
93+ pub fn set_default_database_name ( & self , database_name : String ) {
94+ let mut lock = self . default_database_name . lock ( ) ;
95+ * lock = Some ( database_name ) ;
9296 }
9397
9498 pub async fn create_database ( & self , name : String ) -> Result < ( ) , ChromaClientError > {
@@ -158,79 +162,135 @@ impl ChromaClient {
158162 . await
159163 }
160164
165+ pub async fn get_or_create_collection (
166+ & self ,
167+ params : CreateCollectionRequest ,
168+ ) -> Result < ChromaCollection , ChromaClientError > {
169+ self . common_create_collection ( params, true ) . await
170+ }
171+
172+ pub async fn create_collection (
173+ & self ,
174+ params : CreateCollectionRequest ,
175+ ) -> Result < ChromaCollection , ChromaClientError > {
176+ self . common_create_collection ( params, false ) . await
177+ }
178+
161179 pub async fn list_collections (
162180 & self ,
163181 params : ListCollectionsRequest ,
164- ) -> Result < Vec < String > , ChromaClientError > {
182+ ) -> Result < Vec < ChromaCollection > , ChromaClientError > {
165183 let tenant_id = self . get_tenant_id ( ) . await ?;
166- let database_id = self . get_database_id ( params. database_id ) . await ?;
184+ let database_name = self . get_database_name ( params. database_name ) . await ?;
167185
168186 #[ derive( Serialize ) ]
169187 struct QueryParams {
170188 limit : usize ,
171189 offset : Option < usize > ,
172190 }
173191
174- self . send :: < ( ) , _ , _ > (
175- "list_collections" ,
176- Method :: GET ,
177- format ! (
178- "/api/v2/tenants/{}/databases/{}/collections" ,
179- tenant_id, database_id
180- ) ,
181- None ,
182- Some ( QueryParams {
183- limit : params. limit ,
184- offset : params. offset ,
185- } ) ,
186- )
187- . await
192+ let collections = self
193+ . send :: < ( ) , _ , Vec < Collection > > (
194+ "list_collections" ,
195+ Method :: GET ,
196+ format ! (
197+ "/api/v2/tenants/{}/databases/{}/collections" ,
198+ tenant_id, database_name
199+ ) ,
200+ None ,
201+ Some ( QueryParams {
202+ limit : params. limit ,
203+ offset : params. offset ,
204+ } ) ,
205+ )
206+ . await ?;
207+
208+ Ok ( collections
209+ . into_iter ( )
210+ . map ( |collection| ChromaCollection {
211+ client : self . clone ( ) ,
212+ collection : Arc :: new ( collection) ,
213+ } )
214+ . collect ( ) )
215+ }
216+
217+ async fn common_create_collection (
218+ & self ,
219+ params : CreateCollectionRequest ,
220+ get_or_create : bool ,
221+ ) -> Result < ChromaCollection , ChromaClientError > {
222+ let tenant_id = self . get_tenant_id ( ) . await ?;
223+ let database_name = self . get_database_name ( params. database_name ) . await ?;
224+
225+ let collection: chroma_types:: Collection = self
226+ . send (
227+ "create_collection" ,
228+ Method :: POST ,
229+ format ! (
230+ "/api/v2/tenants/{}/databases/{}/collections" ,
231+ tenant_id, database_name
232+ ) ,
233+ Some ( serde_json:: json!( {
234+ "name" : params. name,
235+ "configuration" : params. configuration,
236+ "metadata" : params. metadata,
237+ "get_or_create" : get_or_create,
238+ } ) ) ,
239+ None :: < ( ) > ,
240+ )
241+ . await ?;
242+
243+ Ok ( ChromaCollection {
244+ client : self . clone ( ) ,
245+ collection : Arc :: new ( collection) ,
246+ } )
188247 }
189248
190- async fn get_database_id (
249+ async fn get_database_name (
191250 & self ,
192- id_override : Option < String > ,
251+ name_override : Option < String > ,
193252 ) -> Result < String , ChromaClientError > {
194- if let Some ( id) = id_override {
253+ if let Some ( id) = name_override {
195254 return Ok ( id) ;
196255 }
197256
198257 {
199- let database_id_lock = self . default_database_id . lock ( ) ;
200- if let Some ( database_id ) = & * database_id_lock {
201- return Ok ( database_id . clone ( ) ) ;
258+ let database_name_lock = self . default_database_name . lock ( ) ;
259+ if let Some ( database_name ) = & * database_name_lock {
260+ return Ok ( database_name . clone ( ) ) ;
202261 }
203262 }
204263
205264 let _guard = self . resolve_tenant_or_database_lock . lock ( ) . await ;
206265
207266 {
208- let database_id_lock = self . default_database_id . lock ( ) ;
209- if let Some ( database_id ) = & * database_id_lock {
210- return Ok ( database_id . clone ( ) ) ;
267+ let database_name_lock = self . default_database_name . lock ( ) ;
268+ if let Some ( database_name ) = & * database_name_lock {
269+ return Ok ( database_name . clone ( ) ) ;
211270 }
212271 }
213272
214273 let identity = self . get_auth_identity ( ) . await ?;
215274
216275 if identity. databases . len ( ) > 1 {
217276 return Err ( ChromaClientError :: CouldNotResolveDatabaseId (
218- "Client has access to multiple databases; please provide a database_id" . to_string ( ) ,
277+ "Client has access to multiple databases; please provide a database_name"
278+ . to_string ( ) ,
219279 ) ) ;
220280 }
221281
222- let database_id = identity. databases . first ( ) . ok_or_else ( || {
282+ let database_name = identity. databases . first ( ) . ok_or_else ( || {
223283 ChromaClientError :: CouldNotResolveDatabaseId (
224284 "Client has access to no databases" . to_string ( ) ,
225285 )
226286 } ) ?;
227287
228288 {
229- let mut database_id_lock = self . default_database_id . lock ( ) ;
230- * database_id_lock = Some ( database_id . clone ( ) ) ;
289+ let mut database_name_lock = self . default_database_name . lock ( ) ;
290+ * database_name_lock = Some ( database_name . clone ( ) ) ;
231291 }
232292
233- Ok ( database_id . clone ( ) )
293+ Ok ( database_name . clone ( ) )
234294 }
235295
236296 async fn get_tenant_id ( & self ) -> Result < String , ChromaClientError > {
@@ -290,7 +350,7 @@ impl ChromaClient {
290350 #[ cfg( feature = "opentelemetry" ) ]
291351 let started_at = std:: time:: Instant :: now ( ) ;
292352
293- let response = request. send ( ) . await ?;
353+ let response = request. send ( ) . await . map_err ( |err| ( err , None ) ) ?;
294354
295355 #[ cfg( feature = "opentelemetry" ) ]
296356 {
@@ -305,13 +365,16 @@ impl ChromaClient {
305365 let _ = operation_name;
306366 }
307367
308- response. error_for_status_ref ( ) ?;
309- Ok :: < _ , reqwest:: Error > ( response)
368+ if let Err ( err) = response. error_for_status_ref ( ) {
369+ return Err ( ( err, Some ( response) ) ) ;
370+ }
371+
372+ Ok :: < reqwest:: Response , ( reqwest:: Error , Option < reqwest:: Response > ) > ( response)
310373 } ;
311374
312375 let response = attempt
313376 . retry ( & self . retry_policy )
314- . notify ( |err, _| {
377+ . notify ( |( err, _ ) , _| {
315378 tracing:: warn!(
316379 url = %url,
317380 method =? method,
@@ -322,13 +385,33 @@ impl ChromaClient {
322385 #[ cfg( feature = "opentelemetry" ) ]
323386 self . metrics . increment_retry ( operation_name) ;
324387 } )
325- . when ( |err| {
388+ . when ( |( err, _ ) | {
326389 err. status ( )
327390 . map ( |status| status == StatusCode :: TOO_MANY_REQUESTS )
328391 . unwrap_or_default ( )
329392 || method == Method :: GET
330393 } )
331- . await ?;
394+ . await ;
395+
396+ let response = match response {
397+ Ok ( response) => response,
398+ Err ( ( err, maybe_response) ) => {
399+ if let Some ( response) = maybe_response {
400+ let json = response. json :: < serde_json:: Value > ( ) . await ?;
401+
402+ if tracing:: enabled!( tracing:: Level :: TRACE ) {
403+ tracing:: trace!(
404+ url = %url,
405+ method =? method,
406+ "Received response: {}" ,
407+ serde_json:: to_string_pretty( & json) . unwrap_or_else( |_| "<failed to serialize>" . to_string( ) )
408+ ) ;
409+ }
410+ }
411+
412+ return Err ( ChromaClientError :: RequestError ( err) ) ;
413+ }
414+ } ;
332415
333416 let json = response. json :: < serde_json:: Value > ( ) . await ?;
334417
@@ -392,22 +475,15 @@ mod tests {
392475 // Create isolated database for test
393476 let database_name = format ! ( "test_db_{}" , uuid:: Uuid :: new_v4( ) ) ;
394477 client. create_database ( database_name. clone ( ) ) . await . unwrap ( ) ;
395- let databases = client. list_databases ( ) . await . unwrap ( ) ;
396- let database_id = databases
397- . iter ( )
398- . find ( |db| db. name == database_name)
399- . unwrap ( )
400- . id
401- . clone ( ) ;
402- client. set_default_database_id ( database_id. clone ( ) ) ;
478+ client. set_default_database_name ( database_name. clone ( ) ) ;
403479
404480 let result = std:: panic:: AssertUnwindSafe ( callback ( client. clone ( ) ) )
405481 . catch_unwind ( )
406482 . await ;
407483
408484 // Delete test database
409- if let Err ( err) = client. delete_database ( database_name) . await {
410- tracing:: error!( "Failed to delete test database {}: {}" , database_id , err) ;
485+ if let Err ( err) = client. delete_database ( database_name. clone ( ) ) . await {
486+ tracing:: error!( "Failed to delete test database {}: {}" , database_name , err) ;
411487 }
412488
413489 result. unwrap ( ) ;
@@ -551,7 +627,62 @@ mod tests {
551627 . unwrap ( ) ;
552628 assert ! ( collections. is_empty( ) ) ;
553629
554- // todo: create collection and assert it's returned, test limit/offset
630+ client
631+ . create_collection (
632+ CreateCollectionRequest :: builder ( )
633+ . name ( "first" . to_string ( ) )
634+ . build ( ) ,
635+ )
636+ . await
637+ . unwrap ( ) ;
638+
639+ client
640+ . create_collection (
641+ CreateCollectionRequest :: builder ( )
642+ . name ( "second" . to_string ( ) )
643+ . build ( ) ,
644+ )
645+ . await
646+ . unwrap ( ) ;
647+
648+ let collections = client
649+ . list_collections ( ListCollectionsRequest :: builder ( ) . build ( ) )
650+ . await
651+ . unwrap ( ) ;
652+ assert_eq ! ( collections. len( ) , 2 ) ;
653+
654+ let collections = client
655+ . list_collections ( ListCollectionsRequest :: builder ( ) . limit ( 1 ) . offset ( 1 ) . build ( ) )
656+ . await
657+ . unwrap ( ) ;
658+ assert_eq ! ( collections. len( ) , 1 ) ;
659+ assert_eq ! ( collections[ 0 ] . collection. name, "second" ) ;
660+ } )
661+ . await ;
662+ }
663+
664+ #[ tokio:: test]
665+ #[ test_log:: test]
666+ async fn test_live_cloud_create_collection ( ) {
667+ with_client ( |client| async move {
668+ let collection = client
669+ . create_collection (
670+ CreateCollectionRequest :: builder ( )
671+ . name ( "foo" . to_string ( ) )
672+ . build ( ) ,
673+ )
674+ . await
675+ . unwrap ( ) ;
676+ assert_eq ! ( collection. collection. name, "foo" ) ;
677+
678+ client
679+ . get_or_create_collection (
680+ CreateCollectionRequest :: builder ( )
681+ . name ( "foo" . to_string ( ) )
682+ . build ( ) ,
683+ )
684+ . await
685+ . unwrap ( ) ;
555686 } )
556687 . await ;
557688 }
0 commit comments