Skip to content

Feature/61 질문 재정의 노드 강화 #107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ test_lhm/
.cursorignore
.vscode
table_info_db
ko_reranker_local
ko_reranker_local
28 changes: 23 additions & 5 deletions interface/lang2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

from llm_utils.connect_db import ConnectDB
from llm_utils.graph import builder
from llm_utils.enriched_graph import builder as enriched_builder
from llm_utils.display_chart import DisplayChart
from llm_utils.llm_response_parser import LLMResponseParser


DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
SIDEBAR_OPTIONS = {
"show_total_token_usage": "Show Total Token Usage",
Expand Down Expand Up @@ -77,7 +77,10 @@ def execute_query(

graph = st.session_state.get("graph")
if graph is None:
graph = builder.compile()
graph_builder = (
enriched_builder if st.session_state.get("use_enriched") else builder
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💊 이 부분에 대해서는 @seyoung4503 님이 보여주신 enriched_builder 와 @nonegom 님의 state 분리 등을 잘 고려해서 만들면 더 확장성 있는 형태를 만들 수 있을것 같은데 지금상태로도 너무 좋습니다 :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

의견 감사합니다!
제안주신대로 추후에 확장하면 더욱 좋을것 같아요 🤩

)
graph = graph_builder.compile()
st.session_state["graph"] = graph

res = graph.invoke(
Expand Down Expand Up @@ -198,14 +201,29 @@ def should_show(_key: str) -> bool:

st.title("Lang2SQL")

# 워크플로우 선택(UI)
use_enriched = st.sidebar.checkbox(
"프로파일 추출 & 컨텍스트 보강 워크플로우 사용", value=False
)

# 세션 상태 초기화
if "graph" not in st.session_state:
st.session_state["graph"] = builder.compile()
if (
"graph" not in st.session_state
or st.session_state.get("use_enriched") != use_enriched
):
graph_builder = enriched_builder if use_enriched else builder
st.session_state["graph"] = graph_builder.compile()

# 프로파일 추출 & 컨텍스트 보강 그래프
st.session_state["use_enriched"] = use_enriched
st.info("Lang2SQL이 성공적으로 시작되었습니다.")

# 새로고침 버튼 추가
if st.sidebar.button("Lang2SQL 새로고침"):
st.session_state["graph"] = builder.compile()
graph_builder = (
enriched_builder if st.session_state.get("use_enriched") else builder
)
st.session_state["graph"] = graph_builder.compile()
st.sidebar.success("Lang2SQL이 성공적으로 새로고침되었습니다.")

user_query = st.text_area(
Expand Down
70 changes: 70 additions & 0 deletions llm_utils/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from pydantic import BaseModel, Field

from .llm_factory import get_llm

from dotenv import load_dotenv
from prompt.template_loader import get_prompt_template


env_path = os.path.join(os.getcwd(), ".env")

if os.path.exists(env_path):
Expand All @@ -20,6 +22,16 @@
llm = get_llm()


class QuestionProfile(BaseModel):
is_timeseries: bool = Field(description="시계열 분석 필요 여부")
is_aggregation: bool = Field(description="집계 함수 필요 여부")
has_filter: bool = Field(description="조건 필터 필요 여부")
is_grouped: bool = Field(description="그룹화 필요 여부")
has_ranking: bool = Field(description="정렬/순위 필요 여부")
has_temporal_comparison: bool = Field(description="기간 비교 포함 여부")
intent_type: str = Field(description="질문의 주요 의도 유형")


def create_query_refiner_chain(llm):
prompt = get_prompt_template("query_refiner_prompt")
tool_choice_prompt = ChatPromptTemplate.from_messages(
Expand Down Expand Up @@ -72,8 +84,66 @@ def create_query_maker_chain(llm):
return query_maker_prompt | llm


def create_query_refiner_with_profile_chain(llm):
prompt = get_prompt_template("query_refiner_prompt")

tool_choice_prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(prompt),
MessagesPlaceholder(variable_name="user_input"),
SystemMessagePromptTemplate.from_template(
"다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:"
),
MessagesPlaceholder(variable_name="searched_tables"),
# 프로파일 정보 입력
SystemMessagePromptTemplate.from_template(
"다음은 사용자의 질문을 분석한 프로파일 정보입니다."
),
MessagesPlaceholder("profile_prompt"),
SystemMessagePromptTemplate.from_template(
"""
위 사용자의 입력과 위 조건을 바탕으로
분석 관점에서 **충분히 답변 가능한 형태**로
"구체화된 질문"을 작성하세요.
""",
),
]
)

return tool_choice_prompt | llm


def create_query_enrichment_chain(llm):
prompt = get_prompt_template("query_enrichment_prompt")

enrichment_prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(prompt),
]
)

chain = enrichment_prompt | llm
return chain


def create_profile_extraction_chain(llm):
prompt = get_prompt_template("profile_extraction_prompt")

profile_prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(prompt),
]
)

chain = profile_prompt | llm.with_structured_output(QuestionProfile)
return chain


query_refiner_chain = create_query_refiner_chain(llm)
query_maker_chain = create_query_maker_chain(llm)
profile_extraction_chain = create_profile_extraction_chain(llm)
query_refiner_with_profile_chain = create_query_refiner_with_profile_chain(llm)
query_enrichment_chain = create_query_enrichment_chain(llm)

if __name__ == "__main__":
query_refiner_chain.invoke()
41 changes: 41 additions & 0 deletions llm_utils/enriched_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import json

from langgraph.graph import StateGraph, END
from llm_utils.graph import (
QueryMakerState,
GET_TABLE_INFO,
PROFILE_EXTRACTION,
QUERY_REFINER,
CONTEXT_ENRICHMENT,
QUERY_MAKER,
get_table_info_node,
profile_extraction_node,
query_refiner_with_profile_node,
context_enrichment_node,
query_maker_node,
)

"""
기본 워크플로우에 '프로파일 추출(PROFILE_EXTRACTION)'과 '컨텍스트 보강(CONTEXT_ENRICHMENT)'를
추가한 확장된 그래프입니다.
"""

# StateGraph 생성 및 구성
builder = StateGraph(QueryMakerState)
builder.set_entry_point(GET_TABLE_INFO)

# 노드 추가
builder.add_node(GET_TABLE_INFO, get_table_info_node)
builder.add_node(QUERY_REFINER, query_refiner_with_profile_node)
builder.add_node(PROFILE_EXTRACTION, profile_extraction_node)
builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node)
builder.add_node(QUERY_MAKER, query_maker_node)

# 기본 엣지 설정
builder.add_edge(GET_TABLE_INFO, PROFILE_EXTRACTION)
builder.add_edge(PROFILE_EXTRACTION, QUERY_REFINER)
builder.add_edge(QUERY_REFINER, CONTEXT_ENRICHMENT)
builder.add_edge(CONTEXT_ENRICHMENT, QUERY_MAKER)

# QUERY_MAKER 노드 후 종료
builder.add_edge(QUERY_MAKER, END)
106 changes: 106 additions & 0 deletions llm_utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,23 @@
from llm_utils.chains import (
query_refiner_chain,
query_maker_chain,
query_refiner_with_profile_chain,
profile_extraction_chain,
query_enrichment_chain,
)

from llm_utils.tools import get_info_from_db
from llm_utils.retrieval import search_tables
from llm_utils.utils import profile_to_text

# 노드 식별자 정의
QUERY_REFINER = "query_refiner"
GET_TABLE_INFO = "get_table_info"
TOOL = "tool"
TABLE_FILTER = "table_filter"
QUERY_MAKER = "query_maker"
PROFILE_EXTRACTION = "profile_extraction"
CONTEXT_ENRICHMENT = "context_enrichment"


# 상태 타입 정의 (추가 상태 정보와 메시지들을 포함)
Expand All @@ -31,12 +37,38 @@ class QueryMakerState(TypedDict):
searched_tables: dict[str, dict[str, str]]
best_practice_query: str
refined_input: str
question_profile: dict
generated_query: str
retriever_name: str
top_n: int
device: str


# 노드 함수: PROFILE_EXTRACTION 노드
def profile_extraction_node(state: QueryMakerState):
"""
자연어 쿼리로부터 질문 유형(PROFILE)을 추출하는 노드입니다.

이 노드는 주어진 자연어 쿼리에서 질문의 특성을 분석하여, 해당 질문이 시계열 분석, 집계 함수 사용, 조건 필터 필요 여부,
그룹화, 정렬/순위, 기간 비교 등 다양한 특성을 갖는지 여부를 추출합니다.

추출된 정보는 `QuestionProfile` 모델에 맞춰 저장됩니다. `QuestionProfile` 모델의 필드는 다음과 같습니다:
- `is_timeseries`: 시계열 분석 필요 여부
- `is_aggregation`: 집계 함수 필요 여부
- `has_filter`: 조건 필터 필요 여부
- `is_grouped`: 그룹화 필요 여부
- `has_ranking`: 정렬/순위 필요 여부
- `has_temporal_comparison`: 기간 비교 포함 여부
- `intent_type`: 질문의 주요 의도 유형

"""
result = profile_extraction_chain.invoke({"question": state["messages"][0].content})

state["question_profile"] = result
print("profile_extraction_node : ", result)
return state


# 노드 함수: QUERY_REFINER 노드
def query_refiner_node(state: QueryMakerState):
res = query_refiner_chain.invoke(
Expand All @@ -52,6 +84,80 @@ def query_refiner_node(state: QueryMakerState):
return state


# 노드 함수: QUERY_REFINER 노드
def query_refiner_with_profile_node(state: QueryMakerState):
"""
자연어 쿼리로부터 질문 유형(PROFILE)을 사용해 자연어 질의를 확장하는 노드입니다.

"""

profile_bullets = profile_to_text(state["question_profile"])
res = query_refiner_with_profile_chain.invoke(
input={
"user_input": [state["messages"][0].content],
"user_database_env": [state["user_database_env"]],
"best_practice_query": [state["best_practice_query"]],
"searched_tables": [json.dumps(state["searched_tables"])],
"profile_prompt": [profile_bullets],
}
)
state["messages"].append(res)
state["refined_input"] = res

print("refined_input before context enrichment : ", res.content)
return state


# 노드 함수: CONTEXT_ENRICHMENT 노드
def context_enrichment_node(state: QueryMakerState):
"""
주어진 질문과 관련된 메타데이터를 기반으로 질문을 풍부하게 만드는 노드입니다.

이 함수는 `refined_question`, `profiles`, `related_tables` 정보를 이용하여 자연어 질문을 보강합니다.
보강 과정에서는 질문의 의도를 유지하면서, 추가적인 세부 정보를 제공하거나 잘못된 용어를 수정합니다.

주요 작업:
- 주어진 질문의 메타데이터 (`question_profile` 및 `searched_tables`)를 활용하여, 질문을 수정하거나 추가 정보를 삽입합니다.
- 질문이 시계열 분석 또는 집계 함수 관련인 경우, 이를 명시적으로 강조합니다 (예: "지난 30일 동안").
- 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: ‘미국’ → ‘USA’).
- 보강된 질문을 출력합니다.

Args:
state (QueryMakerState): 쿼리와 관련된 상태 정보를 담고 있는 객체.
상태 객체는 `refined_input`, `question_profile`, `searched_tables` 등의 정보를 포함합니다.

Returns:
QueryMakerState: 보강된 질문이 포함된 상태 객체.

Example:
Given the refined question "What are the total sales in the last month?",
the function would enrich it with additional information such as:
- Ensuring the time period is specified correctly.
- Correcting any column names if necessary.
- Returning the enriched version of the question.
"""

searched_tables = state["searched_tables"]
searched_tables_json = json.dumps(searched_tables, ensure_ascii=False, indent=2)

question_profile = state["question_profile"].model_dump()
question_profile_json = json.dumps(question_profile, ensure_ascii=False, indent=2)

enriched_text = query_enrichment_chain.invoke(
input={
"refined_question": state["refined_input"],
"profiles": question_profile_json,
"related_tables": searched_tables_json,
}
)

state["refined_input"] = enriched_text
state["messages"].append(enriched_text)
print("After context enrichment : ", enriched_text.content)

return state


def get_table_info_node(state: QueryMakerState):
# retriever_name과 top_n을 이용하여 검색 수행
documents_dict = search_tables(
Expand Down
17 changes: 17 additions & 0 deletions llm_utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
def profile_to_text(profile_obj) -> str:
mapping = {
"is_timeseries": "• 시계열 분석 필요",
"is_aggregation": "• 집계 함수 필요",
"has_filter": "• WHERE 조건 필요",
"is_grouped": "• GROUP BY 필요",
"has_ranking": "• 정렬/순위 필요",
"has_temporal_comparison": "• 기간 비교 필요",
}
bullets = [
text for field, text in mapping.items() if getattr(profile_obj, field, False)
]
intent = getattr(profile_obj, "intent_type", None)
if intent:
bullets.append(f"• 의도 유형 → {intent}")

return "\n".join(bullets)
19 changes: 19 additions & 0 deletions prompt/profile_extraction_prompt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Role

You are an assistant that analyzes a user question and extracts the following profiles as JSON:
- is_timeseries (boolean)
- is_aggregation (boolean)
- has_filter (boolean)
- is_grouped (boolean)
- has_ranking (boolean)
- has_temporal_comparison (boolean)
- intent_type (one of: trend, lookup, comparison, distribution)

# Input

Question:
{question}

# Output Example

The output must be a valid JSON matching the QuestionProfile schema.
Loading