Skip to content

Commit

Permalink
fix(code2prompt/core): Add support for template include feature
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelmansuy authored and CTY-git committed Sep 23, 2024
1 parent ec92cbf commit 279d125
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 9 deletions.
2 changes: 1 addition & 1 deletion code2prompt/core/generate_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ def generate_content(files_data, options):
if options['template']:
template_content = load_template(options['template'])
user_inputs = get_user_inputs(template_content)
return process_template(template_content, files_data, user_inputs)
return process_template(template_content, files_data, user_inputs, options['template'])
return generate_markdown_content(files_data, options['no_codeblock'])
23 changes: 15 additions & 8 deletions code2prompt/core/template_processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import OrderedDict
from jinja2 import Template, Environment, FileSystemLoader
import os
from jinja2 import Environment, FileSystemLoader
from code2prompt.utils.include_loader import CircularIncludeError, IncludeLoader
from prompt_toolkit import prompt
import re

Expand Down Expand Up @@ -27,13 +29,18 @@ def get_user_inputs(template_content):

return user_inputs

def process_template(template_content, files_data, user_inputs):
""" Process the Jinja2 template with the given data and user inputs. """

def process_template(template_content, files_data, user_inputs, template_path):
try:
# Replace {{input:variable}} with {{variable}} for Jinja2 processing
processed_content = re.sub(r'{{\s*input:([^{}]+?)\s*}}', r'{{\1}}', template_content)

template = Template(processed_content)
template_dir = os.path.dirname(template_path)
env = Environment(
loader=IncludeLoader(template_dir),
autoescape=True,
keep_trailing_newline=True
)
template = env.from_string(template_content)
return template.render(files=files_data, **user_inputs)
except CircularIncludeError as e:
raise ValueError(f"Circular include detected: {str(e)}")
except Exception as e:
raise ValueError(f"Error processing template: {e}") from e
raise ValueError(f"Error processing template: {e}")
36 changes: 36 additions & 0 deletions code2prompt/utils/include_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
from jinja2 import BaseLoader, TemplateNotFound
import threading

class CircularIncludeError(Exception):
pass

class IncludeLoader(BaseLoader):
def __init__(self, path, encoding='utf-8'):
self.path = path
self.encoding = encoding
self.include_stack = threading.local()

def get_source(self, environment, template):
path = os.path.join(self.path, template)
if not os.path.exists(path):
raise TemplateNotFound(template)

if not hasattr(self.include_stack, 'stack'):
self.include_stack.stack = []

if path in self.include_stack.stack:
raise CircularIncludeError(f"Circular include detected: {' -> '.join(self.include_stack.stack)} -> {path}")

self.include_stack.stack.append(path)

try:
with open(path, 'r', encoding=self.encoding) as f:
source = f.read()
finally:
self.include_stack.stack.pop()

return source, path, lambda: True

def list_templates(self):
return []
116 changes: 116 additions & 0 deletions tests/test_template_include.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import pytest
from code2prompt.core.template_processor import process_template
import os

def test_include_feature(tmp_path):
# Create a main template
main_template = tmp_path / "main.j2"
main_template.write_text("Main: {% include 'sub.j2' %}")

# Create a sub-template
sub_template = tmp_path / "sub.j2"
sub_template.write_text("Sub: {{ variable }}")

template_content = main_template.read_text()
files_data = []
user_inputs = {"variable": "test"}

result = process_template(template_content, files_data, user_inputs, str(main_template))
assert result == "Main: Sub: test"

def test_nested_include(tmp_path):
# Create a main template
main_template = tmp_path / "main.j2"
main_template.write_text("Main: {% include 'sub1.j2' %}")

# Create sub-templates
sub1_template = tmp_path / "sub1.j2"
sub1_template.write_text("Sub1: {% include 'sub2.j2' %}")

sub2_template = tmp_path / "sub2.j2"
sub2_template.write_text("Sub2: {{ variable }}")

template_content = main_template.read_text()
files_data = []
user_inputs = {"variable": "nested"}

result = process_template(template_content, files_data, user_inputs, str(main_template))
assert result == "Main: Sub1: Sub2: nested"

def test_multiple_includes(tmp_path):
# Create a main template
main_template = tmp_path / "main.j2"
main_template.write_text("Main: {% include 'sub1.j2' %} and {% include 'sub2.j2' %}")

# Create sub-templates
sub1_template = tmp_path / "sub1.j2"
sub1_template.write_text("Sub1: {{ var1 }}")

sub2_template = tmp_path / "sub2.j2"
sub2_template.write_text("Sub2: {{ var2 }}")

template_content = main_template.read_text()
files_data = []
user_inputs = {"var1": "first", "var2": "second"}

result = process_template(template_content, files_data, user_inputs, str(main_template))
assert result == "Main: Sub1: first and Sub2: second"

def test_include_with_context(tmp_path):
# Create a main template
main_template = tmp_path / "main.j2"
main_template.write_text("Main: {% include 'sub.j2' %}")

# Create a sub-template
sub_template = tmp_path / "sub.j2"
sub_template.write_text("Sub: {{ main_var }} and {{ sub_var }}")

template_content = main_template.read_text()
files_data = []
user_inputs = {"main_var": "from main", "sub_var": "from sub"}

result = process_template(template_content, files_data, user_inputs, str(main_template))
assert result == "Main: Sub: from main and from sub"

def test_include_missing_file(tmp_path):
# Create a main template
main_template = tmp_path / "main.j2"
main_template.write_text("Main: {% include 'missing.j2' %}")

template_content = main_template.read_text()
files_data = []
user_inputs = {}

with pytest.raises(ValueError, match="Error processing template"):
process_template(template_content, files_data, user_inputs, str(main_template))

def test_include_with_files_data(tmp_path):
# Create a main template
main_template = tmp_path / "main.j2"
main_template.write_text("Main: {% include 'sub.j2' %}")

# Create a sub-template
sub_template = tmp_path / "sub.j2"
sub_template.write_text("Sub: {{ files[0].name }}")

template_content = main_template.read_text()
files_data = [{"name": "test_file.py", "content": "print('Hello')"}]
user_inputs = {}

result = process_template(template_content, files_data, user_inputs, str(main_template))
assert result == "Main: Sub: test_file.py"

def test_circular_include(tmp_path):
# Create templates with circular inclusion
template1 = tmp_path / "template1.j2"
template1.write_text("T1: {% include 'template2.j2' %}")

template2 = tmp_path / "template2.j2"
template2.write_text("T2: {% include 'template1.j2' %}")

template_content = template1.read_text()
files_data = []
user_inputs = {}

with pytest.raises(ValueError, match="Circular include detected"):
process_template(template_content, files_data, user_inputs, str(template1))

0 comments on commit 279d125

Please sign in to comment.