Skip to content

[ENH]: Get/Delete collection name validation + Get collection by ID #4381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions clients/js/packages/chromadb-core/src/generated/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1075,14 +1075,14 @@ export const ApiApiFetchParamCreator = function (
* @summary Retrieves a collection by ID or name.
* @param {string} tenant <p>Tenant ID</p>
* @param {string} database <p>Database name</p>
* @param {string} collectionId <p>UUID of the collection</p>
* @param {string} collectionIdOrName <p>ID or Name of the collection</p>
* @param {RequestInit} [options] Override http request option.
* @throws {RequiredError}
*/
getCollection(
tenant: string,
database: string,
collectionId: string,
collectionIdOrName: string,
options: RequestInit = {},
): FetchArgs {
// verify required parameter 'tenant' is not null or undefined
Expand All @@ -1099,18 +1099,21 @@ export const ApiApiFetchParamCreator = function (
"Required parameter database was null or undefined when calling getCollection.",
);
}
// verify required parameter 'collectionId' is not null or undefined
if (collectionId === null || collectionId === undefined) {
// verify required parameter 'collectionIdOrName' is not null or undefined
if (collectionIdOrName === null || collectionIdOrName === undefined) {
throw new RequiredError(
"collectionId",
"Required parameter collectionId was null or undefined when calling getCollection.",
"collectionIdOrName",
"Required parameter collectionIdOrName was null or undefined when calling getCollection.",
);
}
let localVarPath =
`/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}`
`/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id_or_name}`
.replace("{tenant}", encodeURIComponent(String(tenant)))
.replace("{database}", encodeURIComponent(String(database)))
.replace("{collection_id}", encodeURIComponent(String(collectionId)));
.replace(
"{collection_id_or_name}",
encodeURIComponent(String(collectionIdOrName)),
);
const localVarPathQueryStart = localVarPath.indexOf("?");
const localVarRequestOptions: RequestInit = Object.assign(
{ method: "GET" },
Expand Down Expand Up @@ -2435,19 +2438,19 @@ export const ApiApiFp = function (configuration?: Configuration) {
* @summary Retrieves a collection by ID or name.
* @param {string} tenant <p>Tenant ID</p>
* @param {string} database <p>Database name</p>
* @param {string} collectionId <p>UUID of the collection</p>
* @param {string} collectionIdOrName <p>ID or Name of the collection</p>
* @param {RequestInit} [options] Override http request option.
* @throws {RequiredError}
*/
getCollection(
tenant: string,
database: string,
collectionId: string,
collectionIdOrName: string,
options?: RequestInit,
): (fetch?: FetchAPI, basePath?: string) => Promise<Api.Collection> {
const localVarFetchArgs = ApiApiFetchParamCreator(
configuration,
).getCollection(tenant, database, collectionId, options);
).getCollection(tenant, database, collectionIdOrName, options);
return (fetch: FetchAPI = defaultFetch, basePath: string = BASE_PATH) => {
return fetch(
basePath + localVarFetchArgs.url,
Expand Down Expand Up @@ -3312,20 +3315,20 @@ export class ApiApi extends BaseAPI {
* @summary Retrieves a collection by ID or name.
* @param {string} tenant <p>Tenant ID</p>
* @param {string} database <p>Database name</p>
* @param {string} collectionId <p>UUID of the collection</p>
* @param {string} collectionIdOrName <p>ID or Name of the collection</p>
* @param {RequestInit} [options] Override http request option.
* @throws {RequiredError}
*/
public getCollection(
tenant: string,
database: string,
collectionId: string,
collectionIdOrName: string,
options?: RequestInit,
) {
return ApiApiFp(this.configuration).getCollection(
tenant,
database,
collectionId,
collectionIdOrName,
options,
)(this.fetch, this.basePath);
}
Expand Down
6 changes: 4 additions & 2 deletions rust/frontend/src/impls/service_based_frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use chroma_types::{
use opentelemetry::global;
use opentelemetry::metrics::Counter;
use std::collections::HashSet;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
Expand Down Expand Up @@ -323,11 +324,12 @@ impl ServiceBasedFrontend {
..
}: GetCollectionRequest,
) -> Result<GetCollectionResponse, GetCollectionError> {
let cid: Option<CollectionUuid> = CollectionUuid::from_str(collection_name.as_str()).ok();
let mut collections = self
.sysdb_client
.get_collections(
None,
Some(collection_name.clone()),
cid,
cid.is_none().then(|| collection_name.clone()),
Some(tenant_id),
Some(database_name),
None,
Expand Down
136 changes: 134 additions & 2 deletions rust/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,7 @@ async fn create_collection(
/// Retrieves a collection by ID or name.
#[utoipa::path(
get,
path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id}",
path = "/api/v2/tenants/{tenant}/databases/{database}/collections/{collection_id_or_name}",
responses(
(status = 200, description = "Collection found", body = Collection),
(status = 401, description = "Unauthorized", body = ErrorResponse),
Expand All @@ -952,7 +952,7 @@ async fn create_collection(
params(
("tenant" = String, Path, description = "Tenant ID"),
("database" = String, Path, description = "Database name"),
("collection_id" = String, Path, description = "UUID of the collection")
("collection_id_or_name" = String, Path, description = "ID or Name of the collection")
)
)]
async fn get_collection(
Expand Down Expand Up @@ -1961,4 +1961,136 @@ mod tests {
serde_json::Value::String("InvalidArgumentError".to_string())
);
}

#[tokio::test]
async fn test_get_collection_by_id() {
let port = test_server().await;
let client = reqwest::Client::new();
let res = client
.post(format!("http://localhost:{}/api/v2/tenants/default_tenant/databases/default_database/collections", port))
.header("content-type", "application/json")
.body(serde_json::to_string(&serde_json::json!({ "name": "test" })).unwrap())
.send()
.await
.unwrap();

assert_eq!(res.status(), 200);

// Should have returned JSON
assert_eq!(
res.headers().get("content-type").unwrap(),
"application/json"
);
let response_json = res.json::<serde_json::Value>().await.unwrap();
let collection_id = response_json["id"].as_str().unwrap();
assert_eq!(response_json["name"].as_str().unwrap(), "test");
let res = client
.get(format!(
"http://localhost:{}/api/v2/tenants/default_tenant/databases/default_database/collections/{}",
port, collection_id
))
.header("content-type", "application/json")
.send()
.await
.unwrap();
assert_eq!(res.status().clone(), 200);
// Should have returned JSON
assert_eq!(
res.headers().get("content-type").unwrap(),
"application/json"
);
let response_json = res.json::<serde_json::Value>().await.unwrap();
assert_eq!(response_json["name"].as_str().unwrap(), "test");
assert_eq!(response_json["id"].as_str().unwrap(), collection_id);
}

#[tokio::test]
async fn test_get_collection_with_name_validation() {
let port = test_server().await;
let client = reqwest::Client::new();
let res = client
.get(format!(
"http://localhost:{}/api/v2/tenants/default_tenant/databases/default_database/collections/{}",
port, "i"
))
.header("content-type", "application/json")
.send()
.await
.unwrap();
assert_eq!(res.status().clone(), 400);
let response_json = res.json::<serde_json::Value>().await.unwrap();
assert_eq!(
response_json["error"],
serde_json::Value::String("InvalidArgumentError".to_string())
);
assert!(response_json["message"]
.as_str()
.expect("error message to be present")
.contains("Expected a name containing 3-512 characters"));

let res = client
.get(format!(
"http://localhost:{}/api/v2/tenants/default_tenant/databases/default_database/collections/{}",
port, "_invalid_name"
))
.header("content-type", "application/json")
.send()
.await
.unwrap();
assert_eq!(res.status().clone(), 400);
let response_json = res.json::<serde_json::Value>().await.unwrap();
assert_eq!(
response_json["error"],
serde_json::Value::String("InvalidArgumentError".to_string())
);
assert!(response_json["message"]
.as_str()
.expect("error message to be present")
.contains("starting and ending with an alphanumeric character in"));
}

#[tokio::test]
async fn test_delete_collection_with_name_validation() {
let port = test_server().await;
let client = reqwest::Client::new();
let res = client
.delete(format!(
"http://localhost:{}/api/v2/tenants/default_tenant/databases/default_database/collections/{}",
port, "i"
))
.header("content-type", "application/json")
.send()
.await
.unwrap();
assert_eq!(res.status().clone(), 400);
let response_json = res.json::<serde_json::Value>().await.unwrap();
assert_eq!(
response_json["error"],
serde_json::Value::String("InvalidArgumentError".to_string())
);
assert!(response_json["message"]
.as_str()
.expect("error message to be present")
.contains("Expected a name containing 3-512 characters"));

let res = client
.delete(format!(
"http://localhost:{}/api/v2/tenants/default_tenant/databases/default_database/collections/{}",
port, "_invalid_name"
))
.header("content-type", "application/json")
.send()
.await
.unwrap();
assert_eq!(res.status().clone(), 400);
let response_json = res.json::<serde_json::Value>().await.unwrap();
assert_eq!(
response_json["error"],
serde_json::Value::String("InvalidArgumentError".to_string())
);
assert!(response_json["message"]
.as_str()
.expect("error message to be present")
.contains("starting and ending with an alphanumeric character in"));
}
}
2 changes: 2 additions & 0 deletions rust/types/src/api_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ pub type CountCollectionsResponse = u32;
pub struct GetCollectionRequest {
pub tenant_id: String,
pub database_name: String,
#[validate(custom(function = "validate_name"))]
pub collection_name: String,
}

Expand Down Expand Up @@ -722,6 +723,7 @@ impl ChromaError for UpdateCollectionError {
pub struct DeleteCollectionRequest {
pub tenant_id: String,
pub database_name: String,
#[validate(custom(function = "validate_name"))]
pub collection_name: String,
}

Expand Down
2 changes: 1 addition & 1 deletion rust/types/src/validators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub(crate) fn validate_non_empty_metadata<V>(
pub(crate) fn validate_name(name: impl AsRef<str>) -> Result<(), ValidationError> {
let name_str = name.as_ref();
if !ALNUM_RE.is_match(name_str) {
return Err(ValidationError::new("name").with_message(format!("Expected a name containing 3-512 characters from [a-zA-Z0-9._-], starting and ending with a character in [a-zA-Z0-9]. Got: {name_str}").into()));
return Err(ValidationError::new("name").with_message(format!("Expected a name containing 3-512 characters from [a-zA-Z0-9._-], starting and ending with an alphanumeric character in [a-zA-Z0-9]. Got: {name_str}").into()));
}

if DP_RE.is_match(name_str) {
Expand Down
Loading