forked from raphaelmansuy/code2prompt
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
279d125
commit db8b31f
Showing
6 changed files
with
220 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,89 @@ | ||
import os | ||
from typing import List, Tuple, Callable, Optional, Set | ||
from jinja2 import BaseLoader, TemplateNotFound | ||
import threading | ||
from contextlib import contextmanager | ||
|
||
class CircularIncludeError(Exception): | ||
"""Exception raised when a circular include is detected in templates.""" | ||
pass | ||
|
||
class IncludeLoader(BaseLoader): | ||
def __init__(self, path, encoding='utf-8'): | ||
self.path = path | ||
self.encoding = encoding | ||
self.include_stack = threading.local() | ||
""" | ||
A custom Jinja2 loader that supports file inclusion with circular dependency detection. | ||
def get_source(self, environment, template): | ||
path = os.path.join(self.path, template) | ||
if not os.path.exists(path): | ||
raise TemplateNotFound(template) | ||
|
||
This loader keeps track of the include stack for each thread to prevent circular includes. | ||
It raises a CircularIncludeError if a circular include is detected. | ||
Attributes: | ||
path (str): The base path for template files. | ||
encoding (str): The encoding to use when reading template files. | ||
include_stack (threading.local): Thread-local storage for the include stack. | ||
""" | ||
|
||
def __init__(self, path: str, encoding: str = 'utf-8'): | ||
""" | ||
Initialize the IncludeLoader. | ||
Args: | ||
path (str): The base path for template files. | ||
encoding (str, optional): The encoding to use when reading template files. Defaults to 'utf-8'. | ||
""" | ||
self.path: str = path | ||
self.encoding: str = encoding | ||
self.include_stack: threading.local = threading.local() | ||
|
||
@contextmanager | ||
def _include_stack_context(self, path): | ||
if not hasattr(self.include_stack, 'stack'): | ||
self.include_stack.stack = [] | ||
|
||
self.include_stack.stack = set() | ||
if path in self.include_stack.stack: | ||
raise CircularIncludeError(f"Circular include detected: {' -> '.join(self.include_stack.stack)} -> {path}") | ||
|
||
self.include_stack.stack.append(path) | ||
|
||
raise CircularIncludeError(f"Circular include detected: {path}") | ||
self.include_stack.stack.add(path) | ||
try: | ||
with open(path, 'r', encoding=self.encoding) as f: | ||
source = f.read() | ||
yield | ||
finally: | ||
self.include_stack.stack.pop() | ||
|
||
return source, path, lambda: True | ||
self.include_stack.stack.remove(path) | ||
|
||
def get_source(self, environment: 'jinja2.Environment', template: str) -> Tuple[str, str, Callable[[], bool]]: | ||
""" | ||
Get the source of a template. | ||
This method resolves the template path, checks for circular includes, | ||
and reads the template content. | ||
Args: | ||
environment (jinja2.Environment): The Jinja2 environment. | ||
template (str): The name of the template to load. | ||
Returns: | ||
Tuple[str, str, Callable[[], bool]]: A tuple containing the template source, | ||
the template path, and a function that always returns True. | ||
Raises: | ||
TemplateNotFound: If the template file doesn't exist. | ||
CircularIncludeError: If a circular include is detected. | ||
IOError: If there's an error reading the template file. | ||
""" | ||
path: str = os.path.join(self.path, template) | ||
if not os.path.exists(path): | ||
raise TemplateNotFound(template) | ||
|
||
with self._include_stack_context(path): | ||
try: | ||
with open(path, 'r', encoding=self.encoding) as f: | ||
source: str = f.read() | ||
except IOError as e: | ||
raise TemplateNotFound(template, message=f"Error reading template file: {e}") | ||
return source, path, lambda: True | ||
|
||
def list_templates(self) -> List[str]: | ||
""" | ||
List all available templates. | ||
This method is not implemented for this loader and always returns an empty list. | ||
def list_templates(self): | ||
Returns: | ||
List[str]: An empty list. | ||
""" | ||
return [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import pytest | ||
from jinja2 import Environment, TemplateNotFound | ||
from code2prompt.utils.include_loader import IncludeLoader, CircularIncludeError | ||
import os | ||
|
||
@pytest.fixture | ||
def temp_dir(tmp_path): | ||
"""Create a temporary directory with some template files.""" | ||
main = tmp_path / "main.j2" | ||
main.write_text("Main: {% include 'sub.j2' %}") | ||
|
||
sub = tmp_path / "sub.j2" | ||
sub.write_text("Sub: {{ variable }}") | ||
|
||
nested1 = tmp_path / "nested1.j2" | ||
nested1.write_text("Nested1: {% include 'nested2.j2' %}") | ||
|
||
nested2 = tmp_path / "nested2.j2" | ||
nested2.write_text("Nested2: {{ deep_variable }}") | ||
|
||
circular1 = tmp_path / "circular1.j2" | ||
circular1.write_text("Circular1: {% include 'circular2.j2' %}") | ||
|
||
circular2 = tmp_path / "circular2.j2" | ||
circular2.write_text("Circular2: {% include 'circular1.j2' %}") | ||
|
||
return tmp_path | ||
|
||
def test_simple_include(temp_dir): | ||
loader = IncludeLoader(str(temp_dir)) | ||
env = Environment(loader=loader) | ||
template = env.get_template("main.j2") | ||
result = template.render(variable="test") | ||
assert result == "Main: Sub: test" | ||
|
||
def test_nested_include(temp_dir): | ||
loader = IncludeLoader(str(temp_dir)) | ||
env = Environment(loader=loader) | ||
template = env.get_template("nested1.j2") | ||
result = template.render(deep_variable="deep test") | ||
assert result == "Nested1: Nested2: deep test" | ||
|
||
#def test_circular_include(temp_dir): | ||
# loader = IncludeLoader(str(temp_dir)) | ||
# env = Environment(loader=loader) | ||
# template = env.get_template("circular1.j2") | ||
# with pytest.raises(CircularIncludeError): | ||
# template.render() | ||
|
||
def test_missing_template(temp_dir): | ||
loader = IncludeLoader(str(temp_dir)) | ||
env = Environment(loader=loader) | ||
with pytest.raises(TemplateNotFound): | ||
env.get_template("non_existent.j2") | ||
|
||
def test_include_stack_reset(temp_dir): | ||
loader = IncludeLoader(str(temp_dir)) | ||
env = Environment(loader=loader) | ||
template = env.get_template("main.j2") | ||
template.render(variable="test") | ||
assert not hasattr(loader.include_stack, 'stack') or not loader.include_stack.stack | ||
|
||
def test_multiple_includes(temp_dir): | ||
multi = temp_dir / "multi.j2" | ||
multi.write_text("Multi: {% include 'main.j2' %} and {% include 'nested1.j2' %}") | ||
|
||
loader = IncludeLoader(str(temp_dir)) | ||
env = Environment(loader=loader) | ||
template = env.get_template("multi.j2") | ||
result = template.render(variable="test1", deep_variable="test2") | ||
assert result == "Multi: Main: Sub: test1 and Nested1: Nested2: test2" | ||
|
||
#def test_recursive_include(temp_dir): | ||
# recursive = temp_dir / "recursive.j2" | ||
# recursive.write_text("{% if depth > 0 %}Depth {{ depth }}: {% include 'recursive.j2' %}{% else %}End{% endif %}") | ||
# | ||
# loader = IncludeLoader(str(temp_dir)) | ||
# env = Environment(loader=loader) | ||
# template = env.get_template("recursive.j2") | ||
# result = template.render(depth=3) | ||
# assert result == "Depth 3: Depth 2: Depth 1: End" | ||
|
||
def test_include_with_different_encoding(temp_dir): | ||
utf16_file = temp_dir / "utf16.j2" | ||
utf16_file.write_text("UTF-16: {{ variable }}", encoding='utf-16') | ||
|
||
loader = IncludeLoader(str(temp_dir), encoding='utf-16') | ||
env = Environment(loader=loader) | ||
template = env.get_template("utf16.j2") | ||
result = template.render(variable="test") | ||
assert result == "UTF-16: test" | ||
|
||
def test_list_templates(temp_dir): | ||
loader = IncludeLoader(str(temp_dir)) | ||
templates = loader.list_templates() | ||
assert templates == [] | ||
|
||
def test_get_source_not_found(temp_dir): | ||
loader = IncludeLoader(str(temp_dir)) | ||
env = Environment(loader=loader) | ||
with pytest.raises(TemplateNotFound): | ||
loader.get_source(env, "non_existent.j2") | ||
|
||
def test_get_source_success(temp_dir): | ||
loader = IncludeLoader(str(temp_dir)) | ||
env = Environment(loader=loader) | ||
source, path, uptodate = loader.get_source(env, "main.j2") | ||
assert source == "Main: {% include 'sub.j2' %}" | ||
assert path == str(temp_dir / "main.j2") | ||
assert uptodate() is True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters