From e928f18ba78a6a877397f67ffe8c49a7e2a741f4 Mon Sep 17 00:00:00 2001
From: ThomasFaria
Date: Thu, 6 Feb 2025 16:37:50 +0000
Subject: [PATCH] [feat] start working on vdb

---
 build-vector-db.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt   |  1 +
 2 files changed, 63 insertions(+)
 create mode 100644 build-vector-db.py

diff --git a/build-vector-db.py b/build-vector-db.py
new file mode 100644
index 0000000..63a4a59
--- /dev/null
+++ b/build-vector-db.py
@@ -0,0 +1,62 @@
+"""Build the LangChain documents describing the NAF 2025 codes.
+
+Load the mapping table and the explanatory notes, keep each NAF 2025 code
+found in the mapping once, and render one text document per code.
+"""
+
+import pandas as pd
+from langchain_community.document_loaders import DataFrameLoader
+
+from src.constants.paths import (
+    URL_EXPLANATORY_NOTES,
+    URL_MAPPING_TABLE,
+)
+from src.mappings.mappings import get_mapping
+from src.utils.cache_models import get_file_system
+
+
+def _build_content(row):
+    """Return the text of one row, leaving out the sections that are empty."""
+    sections = [
+        f"# {row.code} : {row.label}",
+        f"## Explications des activités incluses dans la sous-classe\n{row.notes}"
+        if row.notes
+        else None,
+        f"## Liste d'exemples d'activités incluses dans la sous-classe\n{row.include}"
+        if row.include
+        else None,
+        f"## Liste d'exemples d'activités non incluses dans la sous-classe\n{row.not_include}"
+        if row.not_include
+        else None,
+    ]
+    # filter(None, ...) drops the sections whose source field was falsy.
+    return "\n\n".join(filter(None, sections))
+
+
+fs = get_file_system()
+
+with fs.open(URL_MAPPING_TABLE) as f:
+    table_corres = pd.read_excel(f, dtype=str)
+
+with fs.open(URL_EXPLANATORY_NOTES) as f:
+    notes_ex = pd.read_excel(f, dtype=str)
+
+mapping = get_mapping(notes_ex, table_corres)
+
+# Keep the first occurrence of each NAF 2025 code, in mapping order. A set
+# makes every membership test constant-time instead of a list scan.
+codes_naf2025 = []
+seen_codes = set()
+for code08 in mapping:
+    for code25 in code08.naf2025:
+        if code25.code not in seen_codes:
+            seen_codes.add(code25.code)
+            codes_naf2025.append(code25)
+
+df_naf2025 = pd.DataFrame([c.__dict__ for c in codes_naf2025])
+
+df_naf2025.loc[:, "content"] = [
+    _build_content(row) for row in df_naf2025.itertuples()
+]
+
+document_list = DataFrameLoader(df_naf2025, page_content_column="content").load()
diff --git a/requirements.txt b/requirements.txt
index 0b2d79e..4dfd52f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,4 @@ mlflow>=2.16.2
 ipywidgets>=8.1.0
 compressed_tensors>=0.7.0
 optimum>=1.23.1
+langchain_community>=0.3.1