
Commit 6e0548f

benchmark fix

srozb committed Jan 22, 2025
1 parent 136810d

Showing 5 changed files with 63 additions and 7 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ def include(filename) -> str:
 
 setup(
     name="thsensai",
-    version="0.3.0",
+    version="0.3.1",
     description="A library and CLI tool for AI-aided threat hunting and intelligence analysis.",
     long_description=include("README.md"),
     long_description_content_type="text/markdown",
thsensai/cli.py (4 changes: 1 addition & 3 deletions)
@@ -155,9 +155,7 @@ def analyze(
         rp(f"[bold red]Error acquiring intelligence: {e}[/bold red]")
         raise typer.Exit(code=1)
 
-    intel_obj.chunk_size = chunk_size
-    intel_obj.chunk_overlap = chunk_overlap
-    intel_obj.split_content()
+    intel_obj.split_content(chunk_size, chunk_overlap)
 
     if write_intel:
         intel_obj.save_to_disk(output_dir)
thsensai/embeddings.py (53 changes: 53 additions & 0 deletions)
@@ -0,0 +1,53 @@
+"""
+This module will contain utility functions for processing and storing data in the vector store.
+Not implemented yet.
+"""
+
+from typing import List
+from langchain_ollama import OllamaEmbeddings
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_core.documents import Document
+
+
+# Initialize embeddings and vector store
+embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")
+vector_store = InMemoryVectorStore(embeddings)
+
+
+def retrieve_context(query: str, top_k: int = 5) -> str:
+    """
+    Retrieve the most relevant documents from the vector store based on a query.
+
+    The function performs a similarity search using the provided query and returns the
+    concatenated content of the top-k most relevant documents.
+
+    Args:
+        query (str): Query to search for.
+        top_k (int): Number of top results to retrieve (default is 5).
+
+    Returns:
+        str: Concatenated content of the top matching documents.
+    """
+    docs = vector_store.similarity_search(query, k=top_k)
+    return "\n\n".join([doc.page_content for doc in docs])
+
+
+def store_docs(docs: List[Document]) -> None:
+    """
+    Add documents to the vector store.
+
+    Args:
+        docs (List[Document]): List of documents to store.
+    """
+    vector_store.add_documents(documents=docs)
+
+
+def store_data(data: str) -> None:
+    """
+    Process and store data in the vector store.
+
+    This function splits the raw data into smaller chunks and adds them to the vector store.
+
+    Args:
+        data (str): Raw data to be split and stored.
+    """
+    chunks = split_docs(data)
+    store_docs(chunks)
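
For context, here is a minimal usage sketch of the new module (hypothetical, not part of the commit). It assumes a local Ollama instance with the nomic-embed-text model pulled, and it avoids store_data() because that function calls a split_docs() helper that is neither defined nor imported in this module yet, consistent with the "Not implemented yet" note in the docstring:

    # Hypothetical usage sketch for thsensai/embeddings.py (not part of this commit).
    # Assumes Ollama is running locally and `ollama pull nomic-embed-text` was done.
    from langchain_core.documents import Document
    from thsensai.embeddings import store_docs, retrieve_context

    docs = [
        Document(page_content="APT-X uses 198.51.100.7 as a C2 server."),
        Document(page_content="The loader drops payload.dll into C:\\Windows\\Temp."),
    ]
    store_docs(docs)  # embed and index the documents in the in-memory store
    print(retrieve_context("command and control infrastructure", top_k=2))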
thsensai/intel.py (6 changes: 5 additions & 1 deletion)
@@ -63,13 +63,17 @@ def from_source(cls, source: str, css_selector: Optional[str] = None) -> Intel:
         intel_instance.acquire_intel()
         return intel_instance
 
-    def split_content(self):
+    def split_content(self, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None):
         """
         Split the content into smaller chunks for processing.
 
         Returns:
             List[Document]: List of split document chunks.
         """
+        if chunk_size:
+            self.chunk_size = chunk_size
+        if chunk_overlap:
+            self.chunk_overlap = chunk_overlap
         if self.content:
             text_splitter = RecursiveCharacterTextSplitter(
                 chunk_size=self.chunk_size,
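As a quick illustration (hypothetical, not part of the commit), the reworked signature lets callers override the chunking configuration per call while falling back to the values already stored on the instance; the URL below is purely illustrative:

    # Hypothetical sketch of the new split_content() call pattern.
    from thsensai.intel import Intel

    intel_obj = Intel.from_source("https://example.com/threat-report")
    intel_obj.split_content(2000, 200)  # override chunk_size and chunk_overlap for this call
    intel_obj.split_content()           # fall back to the instance's current settings

Note that the guards are truthiness-based (if chunk_size:), so passing 0 also falls back to the stored value rather than overriding it.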
thsensai/test/test_intel.py (5 changes: 3 additions & 2 deletions)
@@ -109,7 +109,7 @@ def process_model_benchmark(
     """
     table = create_benchmark_table(model)
     for chunk_size in chunk_sizes:
-        for chunk_overlap in chunk_overlaps:
+        for chunk_overlap in chunk_overlaps:  # iterate over test_cases first
             table = run_benchmarks_for_configuration(
                 table, model, chunk_size, chunk_overlap
             )
@@ -158,6 +158,7 @@ def run_benchmarks_for_configuration(
     keywords = set(test_case.get("keywords", []))
 
     intel_obj = Intel.from_source(*target)
+    intel_obj.split_content(chunk_size, chunk_overlap)
     scraped_size = calculate_scraped_size(intel_obj.content)
 
     params = {
@@ -219,6 +220,6 @@ def run_extraction_with_timer(
         TimeElapsedColumn(),
         MofNCompleteColumn(),  # pylint: disable=duplicate-code
     ) as progress:
-        iocs_obj = IOCs.from_documents(intel_obj, llm, progress)
+        iocs_obj = IOCs.from_intel(intel_obj, llm, progress)
     total_inference_time = time.time() - start_time
     return total_inference_time, iocs_obj
