Skip to content

Commit

Permalink
fix(code2prompt/core): handle circular include errors in template pro…
Browse files Browse the repository at this point in the history
…cessing
  • Loading branch information
raphaelmansuy authored and CTY-git committed Sep 23, 2024
1 parent dbdca94 commit 207e0cf
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions code2prompt/core/template_processor.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,41 @@
from typing import OrderedDict
import os
from typing import OrderedDict, Tuple, List
import os
from jinja2 import Environment, FileSystemLoader
from code2prompt.utils.include_loader import CircularIncludeError, IncludeLoader
from prompt_toolkit import prompt
import re

def load_template(template_path):
""" Load a Jinja2 template from a file. """
"""
Load a Jinja2 template from a file.
"""
try:
with open(template_path, 'r', encoding='utf-8') as file:
return file.read()
except IOError as e:
raise IOError(f"Error loading template file: {e}") from e

def get_user_inputs(template_content):
""" Extract user-defined variables from the template and prompt for input. """
def get_user_inputs(template_content: str) -> Tuple[OrderedDict[str, str], List[Tuple[int, int, str]]]:
"""
Extract user-defined variables from the template and prompt for input.
Returns a tuple of user inputs and variable positions.
"""
# Use a regex pattern that excludes Jinja execute blocks and matches the new input syntax
pattern = r'{{\s*input:([^{}]+?)\s*}}'
user_vars = re.findall(pattern, template_content)
user_inputs = {}
matches = list(re.finditer(pattern, template_content))

for var in user_vars:
# Strip whitespace from the variable name
clean_var = var.strip()
user_inputs = OrderedDict()
positions = []

for match in matches:
var_name = match.group(1).strip()
positions.append((match.start(), match.end(), var_name))

# Only prompt for non-empty variable names that haven't been prompted before
if clean_var and clean_var not in user_inputs:
user_inputs[clean_var] = prompt(f"Enter value for {clean_var}: ")
if var_name and var_name not in user_inputs:
user_inputs[var_name] = prompt(f"Enter value for {var_name}: ")

return user_inputs

return user_inputs, positions

def process_template(template_content, files_data, user_inputs, template_path):
try:
Expand All @@ -38,6 +45,15 @@ def process_template(template_content, files_data, user_inputs, template_path):
autoescape=True,
keep_trailing_newline=True
)

# Get user inputs and variable positions
user_inputs, positions = get_user_inputs(template_content)

# Replace input placeholders with user-provided values
for start, end, var_name in reversed(positions):
replacement = user_inputs.get(var_name, '')
template_content = template_content[:start] + replacement + template_content[end:]

template = env.from_string(template_content)
return template.render(files=files_data, **user_inputs)
except CircularIncludeError as e:
Expand Down

0 comments on commit 207e0cf

Please sign in to comment.