Skip to content

Commit bf8e462

Browse files
authored
[ENH]: add create_collection() to Rust client (#5655)
## Description of changes Adds `create_collection()` and `get_or_create_collection()` to the Rust client, also wires up `list_collections()`. ## Test plan _How are these changes tested?_ - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent c176352 commit bf8e462

File tree

5 files changed

+200
-59
lines changed

5 files changed

+200
-59
lines changed

rust/chroma/src/client/chroma_client.rs

Lines changed: 185 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use backon::ExponentialBuilder;
22
use backon::Retryable;
33
use chroma_error::ChromaValidationError;
4+
use chroma_types::Collection;
45
use parking_lot::Mutex;
56
use reqwest::Method;
67
use reqwest::StatusCode;
@@ -9,7 +10,10 @@ use std::sync::Arc;
910
use thiserror::Error;
1011

1112
use crate::client::ChromaClientOptions;
12-
use crate::types::{GetUserIdentityResponse, HeartbeatResponse, ListCollectionsRequest};
13+
use crate::collection::ChromaCollection;
14+
use crate::types::{
15+
CreateCollectionRequest, GetUserIdentityResponse, HeartbeatResponse, ListCollectionsRequest,
16+
};
1317

1418
const USER_AGENT: &str = concat!(
1519
"Chroma Rust Client v",
@@ -35,7 +39,7 @@ pub struct ChromaClient {
3539
client: reqwest::Client,
3640
retry_policy: ExponentialBuilder,
3741
tenant_id: Arc<Mutex<Option<String>>>,
38-
default_database_id: Arc<Mutex<Option<String>>>,
42+
default_database_name: Arc<Mutex<Option<String>>>,
3943
resolve_tenant_or_database_lock: Arc<tokio::sync::Mutex<()>>,
4044
#[cfg(feature = "opentelemetry")]
4145
metrics: crate::client::metrics::Metrics,
@@ -48,7 +52,7 @@ impl Clone for ChromaClient {
4852
client: self.client.clone(),
4953
retry_policy: self.retry_policy,
5054
tenant_id: Arc::new(Mutex::new(self.tenant_id.lock().clone())),
51-
default_database_id: Arc::new(Mutex::new(self.default_database_id.lock().clone())),
55+
default_database_name: Arc::new(Mutex::new(self.default_database_name.lock().clone())),
5256
resolve_tenant_or_database_lock: Arc::new(tokio::sync::Mutex::new(())),
5357
#[cfg(feature = "opentelemetry")]
5458
metrics: self.metrics.clone(),
@@ -79,16 +83,16 @@ impl ChromaClient {
7983
client,
8084
retry_policy: options.retry_options.into(),
8185
tenant_id: Arc::new(Mutex::new(options.tenant_id)),
82-
default_database_id: Arc::new(Mutex::new(options.default_database_id)),
86+
default_database_name: Arc::new(Mutex::new(options.default_database_name)),
8387
resolve_tenant_or_database_lock: Arc::new(tokio::sync::Mutex::new(())),
8488
#[cfg(feature = "opentelemetry")]
8589
metrics: crate::client::metrics::Metrics::new(),
8690
}
8791
}
8892

89-
pub fn set_default_database_id(&self, database_id: String) {
90-
let mut lock = self.default_database_id.lock();
91-
*lock = Some(database_id);
93+
pub fn set_default_database_name(&self, database_name: String) {
94+
let mut lock = self.default_database_name.lock();
95+
*lock = Some(database_name);
9296
}
9397

9498
pub async fn create_database(&self, name: String) -> Result<(), ChromaClientError> {
@@ -158,79 +162,135 @@ impl ChromaClient {
158162
.await
159163
}
160164

165+
pub async fn get_or_create_collection(
166+
&self,
167+
params: CreateCollectionRequest,
168+
) -> Result<ChromaCollection, ChromaClientError> {
169+
self.common_create_collection(params, true).await
170+
}
171+
172+
pub async fn create_collection(
173+
&self,
174+
params: CreateCollectionRequest,
175+
) -> Result<ChromaCollection, ChromaClientError> {
176+
self.common_create_collection(params, false).await
177+
}
178+
161179
pub async fn list_collections(
162180
&self,
163181
params: ListCollectionsRequest,
164-
) -> Result<Vec<String>, ChromaClientError> {
182+
) -> Result<Vec<ChromaCollection>, ChromaClientError> {
165183
let tenant_id = self.get_tenant_id().await?;
166-
let database_id = self.get_database_id(params.database_id).await?;
184+
let database_name = self.get_database_name(params.database_name).await?;
167185

168186
#[derive(Serialize)]
169187
struct QueryParams {
170188
limit: usize,
171189
offset: Option<usize>,
172190
}
173191

174-
self.send::<(), _, _>(
175-
"list_collections",
176-
Method::GET,
177-
format!(
178-
"/api/v2/tenants/{}/databases/{}/collections",
179-
tenant_id, database_id
180-
),
181-
None,
182-
Some(QueryParams {
183-
limit: params.limit,
184-
offset: params.offset,
185-
}),
186-
)
187-
.await
192+
let collections = self
193+
.send::<(), _, Vec<Collection>>(
194+
"list_collections",
195+
Method::GET,
196+
format!(
197+
"/api/v2/tenants/{}/databases/{}/collections",
198+
tenant_id, database_name
199+
),
200+
None,
201+
Some(QueryParams {
202+
limit: params.limit,
203+
offset: params.offset,
204+
}),
205+
)
206+
.await?;
207+
208+
Ok(collections
209+
.into_iter()
210+
.map(|collection| ChromaCollection {
211+
client: self.clone(),
212+
collection: Arc::new(collection),
213+
})
214+
.collect())
215+
}
216+
217+
async fn common_create_collection(
218+
&self,
219+
params: CreateCollectionRequest,
220+
get_or_create: bool,
221+
) -> Result<ChromaCollection, ChromaClientError> {
222+
let tenant_id = self.get_tenant_id().await?;
223+
let database_name = self.get_database_name(params.database_name).await?;
224+
225+
let collection: chroma_types::Collection = self
226+
.send(
227+
"create_collection",
228+
Method::POST,
229+
format!(
230+
"/api/v2/tenants/{}/databases/{}/collections",
231+
tenant_id, database_name
232+
),
233+
Some(serde_json::json!({
234+
"name": params.name,
235+
"configuration": params.configuration,
236+
"metadata": params.metadata,
237+
"get_or_create": get_or_create,
238+
})),
239+
None::<()>,
240+
)
241+
.await?;
242+
243+
Ok(ChromaCollection {
244+
client: self.clone(),
245+
collection: Arc::new(collection),
246+
})
188247
}
189248

190-
async fn get_database_id(
249+
async fn get_database_name(
191250
&self,
192-
id_override: Option<String>,
251+
name_override: Option<String>,
193252
) -> Result<String, ChromaClientError> {
194-
if let Some(id) = id_override {
253+
if let Some(id) = name_override {
195254
return Ok(id);
196255
}
197256

198257
{
199-
let database_id_lock = self.default_database_id.lock();
200-
if let Some(database_id) = &*database_id_lock {
201-
return Ok(database_id.clone());
258+
let database_name_lock = self.default_database_name.lock();
259+
if let Some(database_name) = &*database_name_lock {
260+
return Ok(database_name.clone());
202261
}
203262
}
204263

205264
let _guard = self.resolve_tenant_or_database_lock.lock().await;
206265

207266
{
208-
let database_id_lock = self.default_database_id.lock();
209-
if let Some(database_id) = &*database_id_lock {
210-
return Ok(database_id.clone());
267+
let database_name_lock = self.default_database_name.lock();
268+
if let Some(database_name) = &*database_name_lock {
269+
return Ok(database_name.clone());
211270
}
212271
}
213272

214273
let identity = self.get_auth_identity().await?;
215274

216275
if identity.databases.len() > 1 {
217276
return Err(ChromaClientError::CouldNotResolveDatabaseId(
218-
"Client has access to multiple databases; please provide a database_id".to_string(),
277+
"Client has access to multiple databases; please provide a database_name"
278+
.to_string(),
219279
));
220280
}
221281

222-
let database_id = identity.databases.first().ok_or_else(|| {
282+
let database_name = identity.databases.first().ok_or_else(|| {
223283
ChromaClientError::CouldNotResolveDatabaseId(
224284
"Client has access to no databases".to_string(),
225285
)
226286
})?;
227287

228288
{
229-
let mut database_id_lock = self.default_database_id.lock();
230-
*database_id_lock = Some(database_id.clone());
289+
let mut database_name_lock = self.default_database_name.lock();
290+
*database_name_lock = Some(database_name.clone());
231291
}
232292

233-
Ok(database_id.clone())
293+
Ok(database_name.clone())
234294
}
235295

236296
async fn get_tenant_id(&self) -> Result<String, ChromaClientError> {
@@ -290,7 +350,7 @@ impl ChromaClient {
290350
#[cfg(feature = "opentelemetry")]
291351
let started_at = std::time::Instant::now();
292352

293-
let response = request.send().await?;
353+
let response = request.send().await.map_err(|err| (err, None))?;
294354

295355
#[cfg(feature = "opentelemetry")]
296356
{
@@ -305,13 +365,16 @@ impl ChromaClient {
305365
let _ = operation_name;
306366
}
307367

308-
response.error_for_status_ref()?;
309-
Ok::<_, reqwest::Error>(response)
368+
if let Err(err) = response.error_for_status_ref() {
369+
return Err((err, Some(response)));
370+
}
371+
372+
Ok::<reqwest::Response, (reqwest::Error, Option<reqwest::Response>)>(response)
310373
};
311374

312375
let response = attempt
313376
.retry(&self.retry_policy)
314-
.notify(|err, _| {
377+
.notify(|(err, _), _| {
315378
tracing::warn!(
316379
url = %url,
317380
method =? method,
@@ -322,13 +385,33 @@ impl ChromaClient {
322385
#[cfg(feature = "opentelemetry")]
323386
self.metrics.increment_retry(operation_name);
324387
})
325-
.when(|err| {
388+
.when(|(err, _)| {
326389
err.status()
327390
.map(|status| status == StatusCode::TOO_MANY_REQUESTS)
328391
.unwrap_or_default()
329392
|| method == Method::GET
330393
})
331-
.await?;
394+
.await;
395+
396+
let response = match response {
397+
Ok(response) => response,
398+
Err((err, maybe_response)) => {
399+
if let Some(response) = maybe_response {
400+
let json = response.json::<serde_json::Value>().await?;
401+
402+
if tracing::enabled!(tracing::Level::TRACE) {
403+
tracing::trace!(
404+
url = %url,
405+
method =? method,
406+
"Received response: {}",
407+
serde_json::to_string_pretty(&json).unwrap_or_else(|_| "<failed to serialize>".to_string())
408+
);
409+
}
410+
}
411+
412+
return Err(ChromaClientError::RequestError(err));
413+
}
414+
};
332415

333416
let json = response.json::<serde_json::Value>().await?;
334417

@@ -392,22 +475,15 @@ mod tests {
392475
// Create isolated database for test
393476
let database_name = format!("test_db_{}", uuid::Uuid::new_v4());
394477
client.create_database(database_name.clone()).await.unwrap();
395-
let databases = client.list_databases().await.unwrap();
396-
let database_id = databases
397-
.iter()
398-
.find(|db| db.name == database_name)
399-
.unwrap()
400-
.id
401-
.clone();
402-
client.set_default_database_id(database_id.clone());
478+
client.set_default_database_name(database_name.clone());
403479

404480
let result = std::panic::AssertUnwindSafe(callback(client.clone()))
405481
.catch_unwind()
406482
.await;
407483

408484
// Delete test database
409-
if let Err(err) = client.delete_database(database_name).await {
410-
tracing::error!("Failed to delete test database {}: {}", database_id, err);
485+
if let Err(err) = client.delete_database(database_name.clone()).await {
486+
tracing::error!("Failed to delete test database {}: {}", database_name, err);
411487
}
412488

413489
result.unwrap();
@@ -551,7 +627,62 @@ mod tests {
551627
.unwrap();
552628
assert!(collections.is_empty());
553629

554-
// todo: create collection and assert it's returned, test limit/offset
630+
client
631+
.create_collection(
632+
CreateCollectionRequest::builder()
633+
.name("first".to_string())
634+
.build(),
635+
)
636+
.await
637+
.unwrap();
638+
639+
client
640+
.create_collection(
641+
CreateCollectionRequest::builder()
642+
.name("second".to_string())
643+
.build(),
644+
)
645+
.await
646+
.unwrap();
647+
648+
let collections = client
649+
.list_collections(ListCollectionsRequest::builder().build())
650+
.await
651+
.unwrap();
652+
assert_eq!(collections.len(), 2);
653+
654+
let collections = client
655+
.list_collections(ListCollectionsRequest::builder().limit(1).offset(1).build())
656+
.await
657+
.unwrap();
658+
assert_eq!(collections.len(), 1);
659+
assert_eq!(collections[0].collection.name, "second");
660+
})
661+
.await;
662+
}
663+
664+
#[tokio::test]
665+
#[test_log::test]
666+
async fn test_live_cloud_create_collection() {
667+
with_client(|client| async move {
668+
let collection = client
669+
.create_collection(
670+
CreateCollectionRequest::builder()
671+
.name("foo".to_string())
672+
.build(),
673+
)
674+
.await
675+
.unwrap();
676+
assert_eq!(collection.collection.name, "foo");
677+
678+
client
679+
.get_or_create_collection(
680+
CreateCollectionRequest::builder()
681+
.name("foo".to_string())
682+
.build(),
683+
)
684+
.await
685+
.unwrap();
555686
})
556687
.await;
557688
}

rust/chroma/src/client/options.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ pub struct ChromaClientOptions {
6464
/// Will be automatically resolved at request time if not provided
6565
pub tenant_id: Option<String>,
6666
/// Will be automatically resolved at request time if not provided. It can only be resolved automatically if this client has access to exactly one database.
67-
pub default_database_id: Option<String>,
67+
pub default_database_name: Option<String>,
6868
}
6969

7070
impl Default for ChromaClientOptions {
@@ -74,7 +74,7 @@ impl Default for ChromaClientOptions {
7474
auth_method: ChromaAuthMethod::None,
7575
retry_options: ChromaRetryOptions::default(),
7676
tenant_id: None,
77-
default_database_id: None,
77+
default_database_name: None,
7878
}
7979
}
8080
}

0 commit comments

Comments
 (0)