Skip to content

Commit 0090105

Browse files
CrazyDubyaCopilot
andauthored
Add set comprehension support (#3)
* Add SetComp support and tests * Update src/converter/code_generator_fixed.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent d41a8ca commit 0090105

File tree

4 files changed

+92
-1
lines changed

4 files changed

+92
-1
lines changed

src/analyzer/code_analyzer_fixed.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,10 @@ def _infer_expression_type(self, node: ast.AST) -> str:
485485
elt_type = self._infer_expression_type(node.elts[0])
486486
return f'std::set<{elt_type}>'
487487
return 'std::set<int>'
488+
elif isinstance(node, ast.SetComp):
489+
# Infer type from the element expression of the comprehension
490+
elt_type = self._infer_expression_type(node.elt)
491+
return f'std::set<{elt_type}>'
488492
elif isinstance(node, ast.Tuple):
489493
if node.elts:
490494
elt_types = []

src/converter/code_generator_fixed.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,28 @@ def _translate_expression(self, node: ast.AST, local_vars: Dict[str, str]) -> st
11041104
value_type = self._infer_cpp_type(node.values[0], local_vars)
11051105

11061106
return f"std::map<{key_type}, {value_type}>{{{', '.join(pairs)}}}"
1107+
elif isinstance(node, ast.SetComp):
1108+
# Translate set comprehension using a lambda that fills a std::set
1109+
comp = node.generators[0]
1110+
iter_expr = self._translate_expression(comp.iter, local_vars)
1111+
target = self._translate_expression(comp.target, local_vars)
1112+
element_expr = self._translate_expression(node.elt, local_vars)
1113+
elem_type = self._infer_cpp_type(node.elt, local_vars)
1114+
conditions = ''
1115+
if comp.ifs:
1116+
conds = ' && '.join(f"({self._translate_expression(c, local_vars)})" for c in comp.ifs)
1117+
conditions = f"if ({conds}) "
1118+
1119+
lines = [
1120+
"[&]() {",
1121+
f" std::set<{elem_type}> _set;",
1122+
f" for (auto {target} : {iter_expr}) {{",
1123+
f" {conditions} _set.insert({element_expr});",
1124+
" }",
1125+
" return _set;",
1126+
"}()",
1127+
]
1128+
return "\n".join(lines)
11071129
elif isinstance(node, ast.Tuple):
11081130
# Handle tuple literals
11091131
elements = [self._translate_expression(elt, local_vars) for elt in node.elts]
@@ -1257,6 +1279,15 @@ def _infer_cpp_type(self, node: ast.AST, local_vars: Dict[str, str]) -> str:
12571279
return f"std::map<{key_type}, {value_type}>"
12581280
else:
12591281
return "std::map<std::string, int>"
1282+
elif isinstance(node, ast.Set):
1283+
if node.elts:
1284+
element_type = self._infer_cpp_type(node.elts[0], local_vars)
1285+
return f"std::set<{element_type}>"
1286+
else:
1287+
return "std::set<int>"
1288+
elif isinstance(node, ast.SetComp):
1289+
element_type = self._infer_cpp_type(node.elt, local_vars)
1290+
return f"std::set<{element_type}>"
12601291
elif isinstance(node, ast.Tuple):
12611292
if node.elts:
12621293
element_types = [self._infer_cpp_type(elt, local_vars) for elt in node.elts]

tests/test_code_analyzer_fixed.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,24 @@ def test_inference_expressions(self):
310310
values=[ast.Constant(value=True), ast.Constant(value=False)]
311311
)
312312
assert analyzer._infer_expression_type(bool_op) == 'bool'
313+
314+
def test_set_comprehension_inference(self):
315+
"""Ensure set comprehensions are inferred as std::set."""
316+
analyzer = CodeAnalyzer()
317+
318+
comp = ast.SetComp(
319+
elt=ast.Name(id='x', ctx=ast.Load()),
320+
generators=[
321+
ast.comprehension(
322+
target=ast.Name(id='x', ctx=ast.Store()),
323+
iter=ast.Call(func=ast.Name(id='range', ctx=ast.Load()), args=[ast.Constant(value=5)], keywords=[]),
324+
ifs=[],
325+
is_async=0
326+
)
327+
]
328+
)
329+
330+
assert analyzer._infer_expression_type(comp) == 'std::set<int>'
313331

314332
def test_type_annotation_handling(self):
315333
"""Test handling of Python type annotations."""

tests/test_conversion_fixed.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
22
from pathlib import Path
3+
import tempfile
4+
import os
35
from src.analyzer.code_analyzer_fixed import CodeAnalyzer
46
from src.rules.rule_manager import RuleManager
57
from src.rules.basic_rules import (
@@ -56,4 +58,40 @@ def test_fibonacci_conversion(tmp_path):
5658
# Verify CMake content
5759
cmake_content = (output_dir / "CMakeLists.txt").read_text()
5860
assert "cmake_minimum_required" in cmake_content
59-
assert "project(pytocpp_generated)" in cmake_content
61+
assert "project(pytocpp_generated)" in cmake_content
62+
63+
64+
def test_set_comprehension_translation(tmp_path):
65+
analyzer = CodeAnalyzer()
66+
rule_manager = RuleManager()
67+
68+
rule_manager.register_rule(VariableDeclarationRule())
69+
rule_manager.register_rule(FunctionDefinitionRule())
70+
rule_manager.register_rule(ClassDefinitionRule())
71+
72+
generator = CodeGenerator(rule_manager)
73+
74+
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as temp:
75+
temp.write(
76+
"def make_set(n):\n"
77+
" return {i * 2 for i in range(n)}\n"
78+
)
79+
temp_path = Path(temp.name)
80+
81+
try:
82+
analysis_result = analyzer.analyze_file(temp_path)
83+
84+
rule_manager.set_context({
85+
'type_info': analysis_result.type_info,
86+
'performance_bottlenecks': analysis_result.performance_bottlenecks,
87+
'memory_usage': analysis_result.memory_usage,
88+
'hot_paths': analysis_result.hot_paths
89+
})
90+
91+
output_dir = tmp_path / "generated_set"
92+
generator.generate_code(analysis_result, output_dir)
93+
94+
impl_content = (output_dir / "generated.cpp").read_text()
95+
assert "std::set<int> _set" in impl_content
96+
finally:
97+
os.unlink(temp_path)

0 commit comments

Comments
 (0)