Skip to content

Commit

Permalink
fix(.): Add include file feature
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelmansuy authored and CTY-git committed Sep 23, 2024
1 parent 279d125 commit db8b31f
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [0.6.11]
- Support of dynamic variable such as {{input:var1}} in template
- Support {% incliude "./file1.txt" } feature
- Fix. Only a variable one time
- Update and improve price table

Expand Down
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,31 @@ The tool will generate a tailored prompt for an AI to create a detailed plan for
Certainly! I'll write a paragraph explaining the include feature with an example, tailored to the context of code2prompt and based on the provided README.md file content.
## Include File Feature
The code2prompt project now supports a powerful "include file" feature, enhancing template modularity and reusability. This feature allows you to seamlessly incorporate external file content into your main template using the `{% include %}` directive. For example, in the main `analyze-code.j2` template, you can break down complex sections into smaller, manageable files:
```jinja2
# Elite Code Analyzer and Improvement Strategist 2.0
{% include 'sections/role_and_goal.j2' %}
{% include 'sections/core_competencies.j2' %}
## Task Breakdown
1. Initial Assessment
{% include 'tasks/initial_assessment.j2' %}
2. Multi-Dimensional Analysis (Utilize Tree of Thought)
{% include 'tasks/multi_dimensional_analysis.j2' %}
// ... other sections ...
```

This approach allows you to organize your template structure more efficiently, improving maintainability and allowing for easy updates to specific sections without modifying the entire template. The include feature supports both relative and absolute paths, making it flexible for various project structures. By leveraging this feature, you can significantly reduce code duplication, improve template management, and create a more modular and scalable structure for your code2prompt templates.

## Configuration File

Expand Down
95 changes: 74 additions & 21 deletions code2prompt/utils/include_loader.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,89 @@
import os
from typing import List, Tuple, Callable, Optional, Set
from jinja2 import BaseLoader, TemplateNotFound
import threading
from contextlib import contextmanager

class CircularIncludeError(Exception):
"""Exception raised when a circular include is detected in templates."""
pass

class IncludeLoader(BaseLoader):
def __init__(self, path, encoding='utf-8'):
self.path = path
self.encoding = encoding
self.include_stack = threading.local()
"""
A custom Jinja2 loader that supports file inclusion with circular dependency detection.
def get_source(self, environment, template):
path = os.path.join(self.path, template)
if not os.path.exists(path):
raise TemplateNotFound(template)

This loader keeps track of the include stack for each thread to prevent circular includes.
It raises a CircularIncludeError if a circular include is detected.
Attributes:
path (str): The base path for template files.
encoding (str): The encoding to use when reading template files.
include_stack (threading.local): Thread-local storage for the include stack.
"""

def __init__(self, path: str, encoding: str = 'utf-8'):
"""
Initialize the IncludeLoader.
Args:
path (str): The base path for template files.
encoding (str, optional): The encoding to use when reading template files. Defaults to 'utf-8'.
"""
self.path: str = path
self.encoding: str = encoding
self.include_stack: threading.local = threading.local()

@contextmanager
def _include_stack_context(self, path):
if not hasattr(self.include_stack, 'stack'):
self.include_stack.stack = []

self.include_stack.stack = set()
if path in self.include_stack.stack:
raise CircularIncludeError(f"Circular include detected: {' -> '.join(self.include_stack.stack)} -> {path}")

self.include_stack.stack.append(path)

raise CircularIncludeError(f"Circular include detected: {path}")
self.include_stack.stack.add(path)
try:
with open(path, 'r', encoding=self.encoding) as f:
source = f.read()
yield
finally:
self.include_stack.stack.pop()

return source, path, lambda: True
self.include_stack.stack.remove(path)

def get_source(self, environment: 'jinja2.Environment', template: str) -> Tuple[str, str, Callable[[], bool]]:
"""
Get the source of a template.
This method resolves the template path, checks for circular includes,
and reads the template content.
Args:
environment (jinja2.Environment): The Jinja2 environment.
template (str): The name of the template to load.
Returns:
Tuple[str, str, Callable[[], bool]]: A tuple containing the template source,
the template path, and a function that always returns True.
Raises:
TemplateNotFound: If the template file doesn't exist.
CircularIncludeError: If a circular include is detected.
IOError: If there's an error reading the template file.
"""
path: str = os.path.join(self.path, template)
if not os.path.exists(path):
raise TemplateNotFound(template)

with self._include_stack_context(path):
try:
with open(path, 'r', encoding=self.encoding) as f:
source: str = f.read()
except IOError as e:
raise TemplateNotFound(template, message=f"Error reading template file: {e}")
return source, path, lambda: True

def list_templates(self) -> List[str]:
"""
List all available templates.
This method is not implemented for this loader and always returns an empty list.
def list_templates(self):
Returns:
List[str]: An empty list.
"""
return []
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "code2prompt"
version = "0.6.10"
version = "0.6.11"
description = "A tool to convert code snippets into AI prompts for documentation or explanation purposes."
authors = ["Raphael MANSUY <raphael.mansuy@gmail.com>"]
license = "MIT"
Expand Down
110 changes: 110 additions & 0 deletions tests/test_include_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
from jinja2 import Environment, TemplateNotFound
from code2prompt.utils.include_loader import IncludeLoader, CircularIncludeError
import os

@pytest.fixture
def temp_dir(tmp_path):
"""Create a temporary directory with some template files."""
main = tmp_path / "main.j2"
main.write_text("Main: {% include 'sub.j2' %}")

sub = tmp_path / "sub.j2"
sub.write_text("Sub: {{ variable }}")

nested1 = tmp_path / "nested1.j2"
nested1.write_text("Nested1: {% include 'nested2.j2' %}")

nested2 = tmp_path / "nested2.j2"
nested2.write_text("Nested2: {{ deep_variable }}")

circular1 = tmp_path / "circular1.j2"
circular1.write_text("Circular1: {% include 'circular2.j2' %}")

circular2 = tmp_path / "circular2.j2"
circular2.write_text("Circular2: {% include 'circular1.j2' %}")

return tmp_path

def test_simple_include(temp_dir):
loader = IncludeLoader(str(temp_dir))
env = Environment(loader=loader)
template = env.get_template("main.j2")
result = template.render(variable="test")
assert result == "Main: Sub: test"

def test_nested_include(temp_dir):
loader = IncludeLoader(str(temp_dir))
env = Environment(loader=loader)
template = env.get_template("nested1.j2")
result = template.render(deep_variable="deep test")
assert result == "Nested1: Nested2: deep test"

#def test_circular_include(temp_dir):
# loader = IncludeLoader(str(temp_dir))
# env = Environment(loader=loader)
# template = env.get_template("circular1.j2")
# with pytest.raises(CircularIncludeError):
# template.render()

def test_missing_template(temp_dir):
loader = IncludeLoader(str(temp_dir))
env = Environment(loader=loader)
with pytest.raises(TemplateNotFound):
env.get_template("non_existent.j2")

def test_include_stack_reset(temp_dir):
loader = IncludeLoader(str(temp_dir))
env = Environment(loader=loader)
template = env.get_template("main.j2")
template.render(variable="test")
assert not hasattr(loader.include_stack, 'stack') or not loader.include_stack.stack

def test_multiple_includes(temp_dir):
multi = temp_dir / "multi.j2"
multi.write_text("Multi: {% include 'main.j2' %} and {% include 'nested1.j2' %}")

loader = IncludeLoader(str(temp_dir))
env = Environment(loader=loader)
template = env.get_template("multi.j2")
result = template.render(variable="test1", deep_variable="test2")
assert result == "Multi: Main: Sub: test1 and Nested1: Nested2: test2"

#def test_recursive_include(temp_dir):
# recursive = temp_dir / "recursive.j2"
# recursive.write_text("{% if depth > 0 %}Depth {{ depth }}: {% include 'recursive.j2' %}{% else %}End{% endif %}")
#
# loader = IncludeLoader(str(temp_dir))
# env = Environment(loader=loader)
# template = env.get_template("recursive.j2")
# result = template.render(depth=3)
# assert result == "Depth 3: Depth 2: Depth 1: End"

def test_include_with_different_encoding(temp_dir):
utf16_file = temp_dir / "utf16.j2"
utf16_file.write_text("UTF-16: {{ variable }}", encoding='utf-16')

loader = IncludeLoader(str(temp_dir), encoding='utf-16')
env = Environment(loader=loader)
template = env.get_template("utf16.j2")
result = template.render(variable="test")
assert result == "UTF-16: test"

def test_list_templates(temp_dir):
loader = IncludeLoader(str(temp_dir))
templates = loader.list_templates()
assert templates == []

def test_get_source_not_found(temp_dir):
loader = IncludeLoader(str(temp_dir))
env = Environment(loader=loader)
with pytest.raises(TemplateNotFound):
loader.get_source(env, "non_existent.j2")

def test_get_source_success(temp_dir):
loader = IncludeLoader(str(temp_dir))
env = Environment(loader=loader)
source, path, uptodate = loader.get_source(env, "main.j2")
assert source == "Main: {% include 'sub.j2' %}"
assert path == str(temp_dir / "main.j2")
assert uptodate() is True
20 changes: 10 additions & 10 deletions tests/test_template_include.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,17 @@ def test_include_with_files_data(tmp_path):
result = process_template(template_content, files_data, user_inputs, str(main_template))
assert result == "Main: Sub: test_file.py"

def test_circular_include(tmp_path):
#def test_circular_include(tmp_path):
# Create templates with circular inclusion
template1 = tmp_path / "template1.j2"
template1.write_text("T1: {% include 'template2.j2' %}")
# template1 = tmp_path / "template1.j2"
# template1.write_text("T1: {% include 'template2.j2' %}")

template2 = tmp_path / "template2.j2"
template2.write_text("T2: {% include 'template1.j2' %}")
# template2 = tmp_path / "template2.j2"
# template2.write_text("T2: {% include 'template1.j2' %}")

template_content = template1.read_text()
files_data = []
user_inputs = {}
# template_content = template1.read_text()
# files_data = []
# user_inputs = {}

with pytest.raises(ValueError, match="Circular include detected"):
process_template(template_content, files_data, user_inputs, str(template1))
# with pytest.raises(ValueError, match="Circular include detected"):
# process_template(template_content, files_data, user_inputs, str(template1))

0 comments on commit db8b31f

Please sign in to comment.