11from __future__ import annotations
22
3- import ast
4- import os
5- import re
6- from collections import defaultdict
7- from typing import TYPE_CHECKING
8-
9- import jedi
10- import tiktoken
113from jedi .api .classes import Name
12-
13- from codeflash .cli_cmds .console import logger
14- from codeflash .code_utils .code_extractor import get_code
154from codeflash .code_utils .code_utils import (
165 get_qualified_name ,
17- module_name_from_file_path ,
18- path_belongs_to_site_packages ,
19- )
20- from codeflash .discovery .functions_to_optimize import FunctionToOptimize
21- from codeflash .models .models import FunctionParent , FunctionSource
22-
23- if TYPE_CHECKING :
24- from pathlib import Path
256
7+ )
268
279def belongs_to_method (name : Name , class_name : str , method_name : str ) -> bool :
2810 """Check if the given name belongs to the specified method."""
@@ -58,242 +40,4 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b
5840 return get_qualified_name (name .module_name , name .full_name ) == qualified_function_name
5941 return False
6042 except ValueError :
61- return False
62-
63- #
64- # def get_type_annotation_context(
65- # function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path
66- # ) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
67- # function_name: str = function.function_name
68- # file_path: Path = function.file_path
69- # file_contents: str = file_path.read_text(encoding="utf8")
70- # try:
71- # module: ast.Module = ast.parse(file_contents)
72- # except SyntaxError as e:
73- # logger.exception(f"get_type_annotation_context - Syntax error in code: {e}")
74- # return [], set()
75- # sources: list[FunctionSource] = []
76- # ast_parents: list[FunctionParent] = []
77- # contextual_dunder_methods = set()
78- #
79- # def get_annotation_source(
80- # j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str
81- # ) -> None:
82- # try:
83- # definition: list[Name] = j_script.goto(
84- # line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False
85- # )
86- # except Exception as ex:
87- # if hasattr(name, "full_name"):
88- # logger.exception(f"Error while getting definition for {name.full_name}: {ex}")
89- # else:
90- # logger.exception(f"Error while getting definition: {ex}")
91- # definition = []
92- # if definition: # TODO can be multiple definitions
93- # definition_path = definition[0].module_path
94- #
95- # # The definition is part of this project and not defined within the original function
96- # if (
97- # str(definition_path).startswith(str(project_root_path) + os.sep)
98- # and definition[0].full_name
99- # and not path_belongs_to_site_packages(definition_path)
100- # and not belongs_to_function(definition[0], function_name)
101- # ):
102- # source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])])
103- # if source_code[0]:
104- # sources.append(
105- # FunctionSource(
106- # fully_qualified_name=definition[0].full_name,
107- # jedi_definition=definition[0],
108- # source_code=source_code[0],
109- # file_path=definition_path,
110- # qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."),
111- # only_function_name=definition[0].name,
112- # )
113- # )
114- # contextual_dunder_methods.update(source_code[1])
115- #
116- # def visit_children(
117- # node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent]
118- # ) -> None:
119- # child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module
120- # for child in ast.iter_child_nodes(node):
121- # visit(child, node_parents)
122- #
123- # def visit_all_annotation_children(
124- # node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent]
125- # ) -> None:
126- # if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
127- # visit_all_annotation_children(node.left, node_parents)
128- # visit_all_annotation_children(node.right, node_parents)
129- # if isinstance(node, ast.Name) and hasattr(node, "id"):
130- # name: str = node.id
131- # line_no: int = node.lineno
132- # col_no: int = node.col_offset
133- # get_annotation_source(jedi_script, name, node_parents, line_no, col_no)
134- # if isinstance(node, ast.Subscript):
135- # if hasattr(node, "slice"):
136- # if isinstance(node.slice, ast.Subscript):
137- # visit_all_annotation_children(node.slice, node_parents)
138- # elif isinstance(node.slice, ast.Tuple):
139- # for elt in node.slice.elts:
140- # if isinstance(elt, (ast.Name, ast.Subscript)):
141- # visit_all_annotation_children(elt, node_parents)
142- # elif isinstance(node.slice, ast.Name):
143- # visit_all_annotation_children(node.slice, node_parents)
144- # if hasattr(node, "value"):
145- # visit_all_annotation_children(node.value, node_parents)
146- #
147- # def visit(
148- # node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module,
149- # node_parents: list[FunctionParent],
150- # ) -> None:
151- # if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
152- # if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
153- # if node.name == function_name and node_parents == function.parents:
154- # arg: ast.arg
155- # for arg in node.args.args:
156- # if arg.annotation:
157- # visit_all_annotation_children(arg.annotation, node_parents)
158- # if node.returns:
159- # visit_all_annotation_children(node.returns, node_parents)
160- #
161- # if not isinstance(node, ast.Module):
162- # node_parents.append(FunctionParent(node.name, type(node).__name__))
163- # visit_children(node, node_parents)
164- # if not isinstance(node, ast.Module):
165- # node_parents.pop()
166- #
167- # visit(module, ast_parents)
168- #
169- # return sources, contextual_dunder_methods
170-
171-
172- # def get_function_variables_definitions(
173- # function_to_optimize: FunctionToOptimize, project_root_path: Path
174- # ) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
175- # function_name = function_to_optimize.function_name
176- # file_path = function_to_optimize.file_path
177- # script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
178- # sources: list[FunctionSource] = []
179- # contextual_dunder_methods = set()
180- # # TODO: The function name condition can be stricter so that it does not clash with other class names etc.
181- # # TODO: The function could have been imported as some other name,
182- # # we should be checking for the translation as well. Also check for the original function name.
183- # names = []
184- # for ref in script.get_names(all_scopes=True, definitions=False, references=True):
185- # if ref.full_name:
186- # if function_to_optimize.parents:
187- # # Check if the reference belongs to the specified class when FunctionParent is provided
188- # if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name):
189- # names.append(ref)
190- # elif belongs_to_function(ref, function_name):
191- # names.append(ref)
192- #
193- # for name in names:
194- # try:
195- # definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
196- # except Exception as e:
197- # try:
198- # logger.exception(f"Error while getting definition for {name.full_name}: {e}")
199- # except Exception as e:
200- # # name.full_name can also throw exceptions sometimes
201- # logger.exception(f"Error while getting definition: {e}")
202- # definitions = []
203- # if definitions:
204- # # TODO: there can be multiple definitions, see how to handle such cases
205- # definition = definitions[0]
206- # definition_path = definition.module_path
207- #
208- # # The definition is part of this project and not defined within the original function
209- # if (
210- # str(definition_path).startswith(str(project_root_path) + os.sep)
211- # and not path_belongs_to_site_packages(definition_path)
212- # and definition.full_name
213- # and not belongs_to_function(definition, function_name)
214- # ):
215- # module_name = module_name_from_file_path(definition_path, project_root_path)
216- # m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name)
217- # parents = []
218- # if m:
219- # parents = [FunctionParent(m.group(1), "ClassDef")]
220- #
221- # source_code = get_code(
222- # [FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)]
223- # )
224- # if source_code[0]:
225- # sources.append(
226- # FunctionSource(
227- # fully_qualified_name=definition.full_name,
228- # jedi_definition=definition,
229- # source_code=source_code[0],
230- # file_path=definition_path,
231- # qualified_name=definition.full_name.removeprefix(definition.module_name + "."),
232- # only_function_name=definition.name,
233- # )
234- # )
235- # contextual_dunder_methods.update(source_code[1])
236- # annotation_sources, annotation_dunder_methods = get_type_annotation_context(
237- # function_to_optimize, script, project_root_path
238- # )
239- # sources[:0] = annotation_sources # prepend the annotation sources
240- # contextual_dunder_methods.update(annotation_dunder_methods)
241- # existing_fully_qualified_names = set()
242- # no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set))
243- # parent_sources = set()
244- # for source in sources:
245- # if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names:
246- # if not source.qualified_name.count("."):
247- # no_parent_sources[source.file_path][source.qualified_name].add(source)
248- # else:
249- # parent_sources.add(source)
250- # existing_fully_qualified_names.add(fully_qualified_name)
251- # deduped_parent_sources = [
252- # source
253- # for source in parent_sources
254- # if source.file_path not in no_parent_sources
255- # or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path]
256- # ]
257- # deduped_no_parent_sources = [
258- # source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2]
259- # ]
260- # return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods
261- #
262- #
263- # MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k
264- #
265- #
266- # def get_constrained_function_context_and_helper_functions(
267- # function_to_optimize: FunctionToOptimize,
268- # project_root_path: Path,
269- # code_to_optimize: str,
270- # max_tokens: int = MAX_PROMPT_TOKENS,
271- # ) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]:
272- # helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path)
273- # tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
274- # code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
275- #
276- # if not function_to_optimize.parents:
277- # helper_functions_sources = [function.source_code for function in helper_functions]
278- # else:
279- # helper_functions_sources = [
280- # function.source_code
281- # for function in helper_functions
282- # if not function.qualified_name.count(".")
283- # or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name
284- # ]
285- # helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources]
286- #
287- # context_list = []
288- # context_len = len(code_to_optimize_tokens)
289- # logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}")
290- # logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}")
291- # for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens):
292- # if context_len + source_len <= max_tokens:
293- # context_list.append(function_source)
294- # context_len += source_len
295- # else:
296- # break
297- # logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}")
298- # helper_code: str = "\n".join(context_list)
299- # return helper_code, helper_functions, dunder_methods
43+ return False
0 commit comments