2 changes: 2 additions & 0 deletions server/migrations/2024-11-15-185112_store_chunk_boosts/down.sql
@@ -0,0 +1,2 @@
-- This file should undo anything in `up.sql`
DROP TABLE IF EXISTS chunk_boosts;
15 changes: 15 additions & 0 deletions server/migrations/2024-11-15-185112_store_chunk_boosts/up.sql
@@ -0,0 +1,15 @@
-- Your SQL goes here
CREATE TABLE chunk_boosts (
chunk_id UUID NOT NULL,
fulltext_boost_phrase TEXT,
fulltext_boost_factor FLOAT,
semantic_boost_phrase TEXT,
semantic_boost_factor FLOAT,

PRIMARY KEY (chunk_id),
CONSTRAINT chunk_boosts_chunk_id_fkey FOREIGN KEY (chunk_id) REFERENCES chunk_metadata (id) ON DELETE CASCADE ON UPDATE CASCADE,
CONSTRAINT fulltext_pairs CHECK ((fulltext_boost_phrase IS NULL AND fulltext_boost_factor IS NULL) OR
(fulltext_boost_phrase IS NOT NULL AND fulltext_boost_factor IS NOT NULL)),
CONSTRAINT semantic_pairs CHECK ((semantic_boost_phrase IS NULL AND semantic_boost_factor IS NULL) OR
(semantic_boost_phrase IS NOT NULL AND semantic_boost_factor IS NOT NULL))
)
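
The two CHECK constraints enforce that each phrase/factor pair is set together or not at all. A minimal sketch of a conforming row, using the ChunkBoost model this PR adds in models.rs (illustrative only, not part of the diff):

use trieve_server::data::models::ChunkBoost;

// Satisfies both `fulltext_pairs` and `semantic_pairs`: the fulltext pair
// is fully populated, the semantic pair is fully NULL.
fn example_boost() -> ChunkBoost {
    ChunkBoost {
        chunk_id: uuid::Uuid::new_v4(),
        fulltext_boost_phrase: Some("rust async".to_string()),
        fulltext_boost_factor: Some(1.5),
        semantic_boost_phrase: None,
        semantic_boost_factor: None,
    }
}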
42 changes: 37 additions & 5 deletions server/src/bin/ingestion-worker.rs
@@ -10,7 +10,7 @@ use std::collections::HashMap;
use std::sync::{atomic::AtomicBool, atomic::Ordering, Arc};
use tracing_subscriber::{prelude::*, EnvFilter, Layer};
use trieve_server::data::models::{
self, ChunkMetadata, DatasetConfiguration, QdrantPayload, UnifiedId, WorkerEvent,
self, ChunkBoost, ChunkMetadata, DatasetConfiguration, QdrantPayload, UnifiedId, WorkerEvent,
};
use trieve_server::errors::ServiceError;
use trieve_server::handlers::chunk_handler::{
@@ -20,8 +20,8 @@ use trieve_server::handlers::chunk_handler::{
use trieve_server::handlers::group_handler::dataset_owns_group;
use trieve_server::operators::chunk_operator::{
bulk_insert_chunk_metadata_query, bulk_revert_insert_chunk_metadata_query,
get_row_count_for_organization_id_query, insert_chunk_metadata_query,
update_chunk_metadata_query,
get_row_count_for_organization_id_query, insert_chunk_boost, insert_chunk_metadata_query,
update_chunk_boost_query, update_chunk_metadata_query,
};
use trieve_server::operators::clickhouse_operator::{ClickHouseEvent, EventQueue};
use trieve_server::operators::dataset_operator::{
@@ -1020,6 +1020,23 @@ async fn upload_chunk(
)
.await?;

if payload.chunk.fulltext_boost.is_some() || payload.chunk.semantic_boost.is_some() {
insert_chunk_boost(
ChunkBoost {
chunk_id: inserted_chunk.id,
fulltext_boost_phrase: payload.chunk.fulltext_boost.clone().map(|x| x.phrase),
fulltext_boost_factor: payload.chunk.fulltext_boost.map(|x| x.boost_factor),
semantic_boost_phrase: payload.chunk.semantic_boost.clone().map(|x| x.phrase),
semantic_boost_factor: payload
.chunk
.semantic_boost
.map(|x| x.distance_factor as f64),
},
web_pool.clone(),
)
.await?;
}

insert_tx.finish();

qdrant_point_id = inserted_chunk.qdrant_point_id;
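
For reference, the accesses above (`x.phrase`, `x.boost_factor`, `x.distance_factor as f64`) imply payload types roughly like the following. These are reconstructed from usage only; the real definitions live in `trieve_server::handlers::chunk_handler` and are not part of this diff:

// Reconstructed shapes; field names come from the accesses in this hunk.
pub struct FullTextBoost {
    pub phrase: String,
    pub boost_factor: f64,
}

pub struct SemanticBoost {
    pub phrase: String,
    pub distance_factor: f32, // cast with `as f64` above, so likely f32
}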
@@ -1139,7 +1156,7 @@ async fn update_chunk(
true => {
let embedding = get_dense_vector(
content.to_string(),
payload.semantic_boost,
payload.semantic_boost.clone(),
"doc",
dataset_config.clone(),
)
@@ -1171,7 +1188,7 @@
&& std::env::var("BM25_ACTIVE").unwrap_or("false".to_string()) == "true"
{
let vecs = get_bm25_embeddings(
vec![(content, payload.fulltext_boost)],
vec![(content, payload.fulltext_boost.clone())],
dataset_config.BM25_AVG_LEN,
dataset_config.BM25_B,
dataset_config.BM25_K,
@@ -1236,6 +1253,21 @@
.map_err(|err| ServiceError::BadRequest(err.to_string()))?;
}

// If boosts are changed, reflect changes to chunk_boosts table
if payload.fulltext_boost.is_some() || payload.semantic_boost.is_some() {
update_chunk_boost_query(
ChunkBoost {
chunk_id: payload.chunk_metadata.id,
fulltext_boost_phrase: payload.fulltext_boost.clone().map(|x| x.phrase),
fulltext_boost_factor: payload.fulltext_boost.map(|x| x.boost_factor),
semantic_boost_phrase: payload.semantic_boost.clone().map(|x| x.phrase),
semantic_boost_factor: payload.semantic_boost.map(|x| x.distance_factor as f64),
},
web_pool,
)
.await?;
}

Ok(())
}

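A side note on the repeated `.clone()` calls in both hunks above: `Option::map` consumes the option by value, so the phrase is mapped from a clone and the factor from the original. A standalone sketch of the pattern, with a hypothetical `Boost` type for illustration:

#[derive(Clone)]
struct Boost {
    phrase: String,
    boost_factor: f64,
}

// Without the clone, the first `map` would move `boost` and the second
// line would fail to compile with a use-after-move error.
fn split(boost: Option<Boost>) -> (Option<String>, Option<f64>) {
    let phrase = boost.clone().map(|b| b.phrase);
    let factor = boost.map(|b| b.boost_factor);
    (phrase, factor)
}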
30 changes: 30 additions & 0 deletions server/src/data/models.rs
@@ -5035,6 +5035,36 @@ pub struct ChunkData {
pub semantic_boost: Option<SemanticBoost>,
}

#[derive(Debug, Serialize, Deserialize, Selectable, Queryable, Insertable, Clone)]
#[diesel(table_name = chunk_boosts)]
pub struct ChunkBoost {
pub chunk_id: uuid::Uuid,
pub fulltext_boost_phrase: Option<String>,
pub fulltext_boost_factor: Option<f64>,
pub semantic_boost_phrase: Option<String>,
pub semantic_boost_factor: Option<f64>,
}

#[derive(AsChangeset)]
#[diesel(table_name = chunk_boosts)]
pub struct ChunkBoostChangeset {
fulltext_boost_phrase: Option<String>,
fulltext_boost_factor: Option<f64>,
semantic_boost_phrase: Option<String>,
semantic_boost_factor: Option<f64>,
}

impl From<ChunkBoost> for ChunkBoostChangeset {
fn from(chunk_boost: ChunkBoost) -> Self {
ChunkBoostChangeset {
fulltext_boost_phrase: chunk_boost.fulltext_boost_phrase,
fulltext_boost_factor: chunk_boost.fulltext_boost_factor,
semantic_boost_phrase: chunk_boost.semantic_boost_phrase,
semantic_boost_factor: chunk_boost.semantic_boost_factor,
}
}
}

#[derive(Debug, ToSchema, Serialize, Deserialize, Row)]
#[schema(example = json!({
"search_type": "search",
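The `From` impl above keeps call sites terse: `update_chunk_boost_query` below builds its diesel changeset with a single `.into()`. A minimal sketch of the conversion:

use trieve_server::data::models::{ChunkBoost, ChunkBoostChangeset};

// One-to-one field copy, ready for diesel::update(...).set(&changes).
fn to_changeset(boost: ChunkBoost) -> ChunkBoostChangeset {
    boost.into()
}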
12 changes: 12 additions & 0 deletions server/src/data/schema.rs
@@ -1,5 +1,15 @@
// @generated automatically by Diesel CLI.

diesel::table! {
chunk_boosts (chunk_id) {
chunk_id -> Uuid,
fulltext_boost_phrase -> Nullable<Text>,
fulltext_boost_factor -> Nullable<Float8>,
semantic_boost_phrase -> Nullable<Text>,
semantic_boost_factor -> Nullable<Float8>,
}
}

diesel::table! {
chunk_group (id) {
id -> Uuid,
@@ -281,6 +291,7 @@ diesel::table! {
}
}

diesel::joinable!(chunk_boosts -> chunk_metadata (chunk_id));
diesel::joinable!(chunk_group -> datasets (dataset_id));
diesel::joinable!(chunk_group_bookmarks -> chunk_group (group_id));
diesel::joinable!(chunk_group_bookmarks -> chunk_metadata (chunk_metadata_id));
@@ -307,6 +318,7 @@ diesel::joinable!(user_organizations -> organizations (organization_id));
diesel::joinable!(user_organizations -> users (user_id));

diesel::allow_tables_to_appear_in_same_query!(
chunk_boosts,
chunk_group,
chunk_group_bookmarks,
chunk_metadata,
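With the `table!`, `joinable!`, and `allow_tables_to_appear_in_same_query!` entries in place, the new table can be queried directly (and joined to chunk_metadata when needed). A hedged sketch of a lookup that is not part of this PR, assuming the diesel_async connection type used elsewhere in the diff:

use diesel::prelude::*;
use diesel_async::{AsyncPgConnection, RunQueryDsl};
use crate::data::models::ChunkBoost;
use crate::data::schema::chunk_boosts::dsl as chunk_boosts_columns;

// Fetch the boost row for one chunk, if any. `.optional()` maps the
// NotFound error to Ok(None).
async fn boost_for_chunk(
    id: uuid::Uuid,
    conn: &mut AsyncPgConnection,
) -> QueryResult<Option<ChunkBoost>> {
    chunk_boosts_columns::chunk_boosts
        .filter(chunk_boosts_columns::chunk_id.eq(id))
        .first::<ChunkBoost>(conn)
        .await
        .optional()
}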
108 changes: 105 additions & 3 deletions server/src/operators/chunk_operator.rs
@@ -1,7 +1,8 @@
use crate::data::models::{
uuid_between, ChunkData, ChunkGroup, ChunkGroupBookmark, ChunkMetadataTable, ChunkMetadataTags,
ChunkMetadataTypes, ContentChunkMetadata, Dataset, DatasetConfiguration, DatasetTags,
IngestSpecificChunkMetadata, SlimChunkMetadata, SlimChunkMetadataTable, UnifiedId,
uuid_between, ChunkBoost, ChunkBoostChangeset, ChunkData, ChunkGroup, ChunkGroupBookmark,
ChunkMetadataTable, ChunkMetadataTags, ChunkMetadataTypes, ContentChunkMetadata, Dataset,
DatasetConfiguration, DatasetTags, IngestSpecificChunkMetadata, SlimChunkMetadata,
SlimChunkMetadataTable, UnifiedId,
};
use crate::handlers::chunk_handler::{BulkUploadIngestionMessage, ChunkReqPayload};
use crate::handlers::chunk_handler::{ChunkFilter, UploadIngestionMessage};
@@ -816,6 +817,58 @@ pub async fn bulk_insert_chunk_metadata_query(
})
.collect::<Vec<ChunkData>>();

use crate::data::schema::chunk_boosts::dsl as chunk_boosts_columns;

// Insert the fulltext and semantic boosts
let boosts_to_insert = insertion_data
.iter()
.filter_map(|chunk_data| {
if chunk_data.fulltext_boost.is_none() && chunk_data.semantic_boost.is_none() {
return None;
}
return Some(ChunkBoost {
chunk_id: chunk_data.chunk_metadata.id,
fulltext_boost_phrase: chunk_data
.fulltext_boost
.as_ref()
.map(|boost| boost.phrase.clone()),
fulltext_boost_factor: chunk_data
.fulltext_boost
.as_ref()
.map(|boost| boost.boost_factor),
semantic_boost_phrase: chunk_data
.semantic_boost
.as_ref()
.map(|boost| boost.phrase.clone()),
semantic_boost_factor: chunk_data
.semantic_boost
.as_ref()
.map(|boost| boost.distance_factor as f64),
});
})
.collect::<Vec<ChunkBoost>>();

diesel::insert_into(chunk_boosts_columns::chunk_boosts)
.values(boosts_to_insert)
.on_conflict((chunk_boosts_columns::chunk_id,))
.do_update()
.set((
chunk_boosts_columns::fulltext_boost_phrase
.eq(excluded(chunk_boosts_columns::fulltext_boost_phrase)),
chunk_boosts_columns::fulltext_boost_factor
.eq(excluded(chunk_boosts_columns::fulltext_boost_factor)),
chunk_boosts_columns::semantic_boost_phrase
.eq(excluded(chunk_boosts_columns::semantic_boost_phrase)),
chunk_boosts_columns::semantic_boost_factor
.eq(excluded(chunk_boosts_columns::semantic_boost_factor)),
))
.execute(&mut conn)
.await
.map_err(|e| {
log::error!("Failed to create chunk boosts {:}", e);
ServiceError::InternalServerError("Failed to create chunk boosts".to_string())
})?;
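// Above: `excluded(col)` reads the value proposed for the failed insert
// (Postgres' EXCLUDED pseudo-table), so re-ingesting an existing chunk
// overwrites its boost row instead of erroring on the chunk_id primary key.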

let chunk_group_bookmarks_to_insert: Vec<ChunkGroupBookmark> = insertion_data
.clone()
.iter()
@@ -1135,6 +1188,27 @@ pub async fn insert_chunk_metadata_query(
Ok(chunk_data)
}

#[tracing::instrument(skip(pool))]
pub async fn insert_chunk_boost(
chunk_boost: ChunkBoost,
pool: web::Data<Pool>,
) -> Result<ChunkBoost, ServiceError> {
use crate::data::schema::chunk_boosts::dsl as chunk_boosts_columns;
let mut conn = pool.get().await.map_err(|_e| {
ServiceError::InternalServerError("Failed to get postgres connection".to_string())
})?;
diesel::insert_into(chunk_boosts_columns::chunk_boosts)
.values(&chunk_boost)
.on_conflict_do_nothing()
.execute(&mut conn)
.await
.map_err(|e| {
log::error!("Failed to insert chunk boost {:}", e);
ServiceError::BadRequest("Failed to insert chunk boost".to_string())
})?;
Ok(chunk_boost)
}

#[tracing::instrument(skip(pool))]
pub async fn get_dataset_tags_id_from_names(
pool: web::Data<Pool>,
@@ -1363,6 +1437,34 @@ pub async fn update_chunk_metadata_query(
}
}

#[tracing::instrument(skip(pool))]
pub async fn update_chunk_boost_query(
chunk_boost: ChunkBoost,
pool: web::Data<Pool>,
) -> Result<(), ServiceError> {
use crate::data::schema::chunk_boosts::dsl as chunk_boosts_columns;
let mut conn = pool.get().await.map_err(|_e| {
ServiceError::InternalServerError("Failed to get postgres connection".to_string())
})?;

// Create a changeset based on which fields are present
let changes: ChunkBoostChangeset = chunk_boost.clone().into();

diesel::update(
chunk_boosts_columns::chunk_boosts
.filter(chunk_boosts_columns::chunk_id.eq(chunk_boost.chunk_id)),
)
.set(&changes)
.execute(&mut conn)
.await
.map_err(|e| {
log::error!("Failed to update chunk boost {:}", e);
ServiceError::BadRequest("Failed to update chunk boost".to_string())
})?;

Ok(())
}

#[tracing::instrument(skip(pool))]
pub async fn delete_chunk_metadata_query(
chunk_uuid: Vec<uuid::Uuid>,
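Taken together, the two new operators give callers an insert-then-update lifecycle for boosts. A hedged end-to-end sketch (hypothetical caller, not part of the diff; the `Pool` import path is an assumption):

use actix_web::web;
use trieve_server::data::models::{ChunkBoost, Pool}; // Pool path assumed
use trieve_server::errors::ServiceError;
use trieve_server::operators::chunk_operator::{insert_chunk_boost, update_chunk_boost_query};

async fn boost_lifecycle(
    pool: web::Data<Pool>,
    chunk_id: uuid::Uuid,
) -> Result<(), ServiceError> {
    // First ingestion: create the boost row.
    insert_chunk_boost(
        ChunkBoost {
            chunk_id,
            fulltext_boost_phrase: Some("rust".to_string()),
            fulltext_boost_factor: Some(2.0),
            semantic_boost_phrase: None,
            semantic_boost_factor: None,
        },
        pool.clone(),
    )
    .await?;

    // Later update: diesel's derived AsChangeset skips None fields, so
    // only the fulltext factor changes; the stored phrase is untouched.
    update_chunk_boost_query(
        ChunkBoost {
            chunk_id,
            fulltext_boost_phrase: None,
            fulltext_boost_factor: Some(3.0),
            semantic_boost_phrase: None,
            semantic_boost_factor: None,
        },
        pool,
    )
    .await?;
    Ok(())
}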