Skip to content

Commit

Permalink
Merge pull request #12 from birbbit/hb/fixes
Browse files Browse the repository at this point in the history
Few small fixes
  • Loading branch information
KillianLucas authored Feb 11, 2024
2 parents 4918546 + 6c60251 commit 9cc0420
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions aifs/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
def format_function_details(func_def):
name = func_def.name
args = [(arg.arg, None if not arg.annotation else ast.unparse(arg.annotation)) for arg in func_def.args.args]
vararg = (func_def.args.vararg.arg, ast.unparse(func_def.args.vararg.annotation)) if func_def.args.vararg else None
vararg = (func_def.args.vararg.arg, ast.unparse(func_def.args.vararg.annotation)) if func_def.args.vararg and func_def.args.vararg.annotation else None
return_annotation = ast.unparse(func_def.returns) if func_def.returns else None
docstring = ast.get_docstring(func_def)

Expand Down Expand Up @@ -64,7 +64,7 @@ def chunk_file(path):
return [c.text for c in chunks]

def index_file(path, python_docstrings_only=False):
if python_docstrings_only and path.lower.endswith(".py"):
if python_docstrings_only and path.lower().endswith(".py"):
return minimally_index_python_file(path)

log(f"Indexing {path}...")
Expand Down Expand Up @@ -106,7 +106,7 @@ def minimally_index_python_file(path):
chunks.append(formatted_string)
representations.append(docstring if docstring else node.name)

embeddings = embed(representations)
embeddings = embed(representations) if representations else []
last_modified = os.path.getmtime(path)

return {
Expand Down Expand Up @@ -195,12 +195,13 @@ def search(query, path=None, max_results=5, verbose=False, python_docstrings_onl
for file_path, file_index in index.items():
ids = [str(id) for id in range(id_counter, id_counter + len(file_index["chunks"]))]
id_counter += len(file_index["chunks"])
collection.add(
ids=ids,
embeddings=file_index["embeddings"],
documents=file_index["chunks"],
metadatas=[{"source": file_path}] * len(file_index["chunks"]),
)
if ids:
collection.add(
ids=ids,
embeddings=file_index["embeddings"],
documents=file_index["chunks"],
metadatas=[{"source": file_path}] * len(file_index["chunks"]),
)

results = collection.query(
query_texts=[query],
Expand Down

0 comments on commit 9cc0420

Please sign in to comment.