Skip to content

Commit

Permalink
Add mock support for knowledge graph index (run-llama#600)
Browse files Browse the repository at this point in the history
Co-authored-by: Logan Markewich <Logan.Markewich@yardi.com>
  • Loading branch information
logan-markewich and Logan Markewich authored Mar 4, 2023
1 parent 8a7a6b4 commit 1afe7b1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
16 changes: 15 additions & 1 deletion gpt_index/token_counter/mock_chain_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.prompts.base import Prompt
from gpt_index.prompts.prompt_type import PromptType
from gpt_index.token_counter.utils import mock_extract_keywords_response
from gpt_index.token_counter.utils import (
mock_extract_keywords_response,
mock_extract_kg_triplets_response,
)
from gpt_index.utils import globals_helper

# TODO: consolidate with unit tests in tests/mock_utils/mock_predict.py
Expand Down Expand Up @@ -69,6 +72,13 @@ def _mock_query_keyword_extract(prompt_args: Dict) -> str:
return mock_extract_keywords_response(prompt_args["question"])


def _mock_knowledge_graph_triplet_extract(prompt_args: Dict, max_triplets: int) -> str:
    """Return a canned knowledge-graph triplet extraction for the given prompt args.

    Delegates to ``mock_extract_kg_triplets_response`` using the prompt's
    "text" field, capped at ``max_triplets`` triplets.
    """
    text_chunk = prompt_args["text"]
    return mock_extract_kg_triplets_response(text_chunk, max_triplets=max_triplets)


class MockLLMPredictor(LLMPredictor):
"""Mock LLM Predictor."""

Expand Down Expand Up @@ -99,5 +109,9 @@ def _predict(self, prompt: Prompt, **prompt_args: Any) -> str:
return _mock_keyword_extract(prompt_args)
elif prompt_str == PromptType.QUERY_KEYWORD_EXTRACT:
return _mock_query_keyword_extract(prompt_args)
elif prompt_str == PromptType.KNOWLEDGE_TRIPLET_EXTRACT:
return _mock_knowledge_graph_triplet_extract(
prompt_args, prompt.partial_dict.get("max_knowledge_triplets", 2)
)
else:
raise ValueError("Invalid prompt type.")
14 changes: 14 additions & 0 deletions gpt_index/token_counter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,17 @@ def mock_extract_keywords_response(
text_chunk, max_keywords=max_keywords, filter_stopwords=False
)
)


def mock_extract_kg_triplets_response(
    text_chunk: str, max_triplets: Optional[int] = None
) -> str:
    """Generate one or more fake knowledge-graph triplets.

    Args:
        text_chunk: Source text (ignored; output is a fixed mock).
        max_triplets: Number of mock triplets to emit. ``None`` emits
            exactly one; ``0`` emits none (empty string).

    Returns:
        Newline-terminated mock triplet lines.
    """
    # String multiplication replaces the original unused-index loop with
    # repeated ``+=`` concatenation; ``None`` maps to a single triplet,
    # while an explicit 0 still yields an empty string.
    count = 1 if max_triplets is None else max_triplets
    return "(This is, a mock, triplet)\n" * count

0 comments on commit 1afe7b1

Please sign in to comment.