From bd609ce6d3a54ccce9dcb0c845dfee8ba18af4e1 Mon Sep 17 00:00:00 2001 From: PCSwingle Date: Fri, 12 Apr 2024 13:42:29 -0500 Subject: [PATCH] Update spice remove cost tracker (#563) --- benchmarks/arg_parser.py | 2 +- benchmarks/benchmark_runner.py | 8 +- benchmarks/exercism_practice.py | 5 +- benchmarks/run_sample.py | 5 +- dev-requirements.txt | 2 +- docs/source/developer/mentat.rst | 8 - mentat/agent_handler.py | 8 +- mentat/auto_completer.py | 17 +- mentat/code_context.py | 20 +- mentat/code_feature.py | 7 +- mentat/code_file_manager.py | 2 +- mentat/command/command.py | 2 +- mentat/command/commands/agent.py | 2 +- mentat/command/commands/config.py | 2 +- mentat/command/commands/screenshot.py | 6 +- mentat/command/commands/talk.py | 14 +- mentat/config.py | 13 +- mentat/conversation.py | 29 +-- mentat/cost_tracker.py | 71 ------- mentat/diff_context.py | 4 +- mentat/llm_api_handler.py | 259 ++++---------------------- mentat/python_client/client.py | 4 - mentat/revisor/revisor.py | 7 +- mentat/sampler/sampler.py | 2 +- mentat/session.py | 6 +- mentat/session_context.py | 2 - mentat/splash_messages.py | 2 +- mentat/terminal/client.py | 2 +- mentat/vision/vision_manager.py | 2 +- requirements.txt | 4 +- scripts/git_log_to_transcripts.py | 4 +- scripts/sampler/__main__.py | 9 +- scripts/select_git_transcripts.py | 4 +- tests/benchmark_test.py | 2 +- tests/code_context_test.py | 3 +- tests/commands_test.py | 4 +- tests/conftest.py | 9 +- tests/llm_api_handler_test.py | 33 ---- tests/sampler_test.py | 2 +- 39 files changed, 143 insertions(+), 444 deletions(-) delete mode 100644 mentat/cost_tracker.py delete mode 100644 tests/llm_api_handler_test.py diff --git a/benchmarks/arg_parser.py b/benchmarks/arg_parser.py index 93155a064..a57dea444 100644 --- a/benchmarks/arg_parser.py +++ b/benchmarks/arg_parser.py @@ -19,7 +19,7 @@ def common_benchmark_parser(): "--benchmarks", nargs="*", default=[], - help=("Which benchmarks to run. max_benchmarks ignored when set. Exact meaning" " depends on benchmark."), + help=("Which benchmarks to run. max_benchmarks ignored when set. Exact meaning depends on benchmark."), ) parser.add_argument( "--directory", diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index c2647c242..335436e2c 100755 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -13,6 +13,7 @@ from openai.types.chat.completion_create_params import ResponseFormat from spice import SpiceMessage +from spice.spice import get_model_from_name from benchmarks.arg_parser import common_benchmark_parser from benchmarks.benchmark_result import BenchmarkResult @@ -22,7 +23,6 @@ from benchmarks.swe_bench_runner import SWE_BENCH_SAMPLES_DIR, get_swe_samples from mentat.config import Config from mentat.git_handler import get_git_diff, get_mentat_branch, get_mentat_hexsha -from mentat.llm_api_handler import model_context_size, prompt_tokens from mentat.sampler.sample import Sample from mentat.sampler.utils import setup_repo from mentat.session_context import SESSION_CONTEXT @@ -45,12 +45,13 @@ def git_diff_from_comparison_commit(sample: Sample, comparison_commit: str) -> s async def grade(to_grade, prompt, model="gpt-4-1106-preview"): try: + llm_api_handler = SESSION_CONTEXT.get().llm_api_handler messages: List[SpiceMessage] = [ {"role": "system", "content": prompt}, {"role": "user", "content": to_grade}, ] - tokens = prompt_tokens(messages, model) - max_tokens = model_context_size(model) - 1000 # Response buffer + tokens = llm_api_handler.spice.count_prompt_tokens(messages, model) + max_tokens = get_model_from_name(model).context_length - 1000 # Response buffer if tokens > max_tokens: print("Prompt too long! Truncating... (this may affect results)") tokens_to_remove = tokens - max_tokens @@ -58,7 +59,6 @@ async def grade(to_grade, prompt, model="gpt-4-1106-preview"): chars_to_remove = int(chars_per_token * tokens_to_remove) messages[1]["content"] = messages[1]["content"][:-chars_to_remove] - llm_api_handler = SESSION_CONTEXT.get().llm_api_handler llm_grade = await llm_api_handler.call_llm_api(messages, model, None, False, ResponseFormat(type="json_object")) content = llm_grade.text return json.loads(content) diff --git a/benchmarks/exercism_practice.py b/benchmarks/exercism_practice.py index 00474ce9b..f112facdf 100755 --- a/benchmarks/exercism_practice.py +++ b/benchmarks/exercism_practice.py @@ -114,13 +114,12 @@ async def run_exercise(problem_dir, language="python", max_iterations=2): messages = client.get_conversation().literal_messages await client.shutdown() passed = exercise_runner.passed() - cost_tracker = SESSION_CONTEXT.get().cost_tracker result = BenchmarkResult( iterations=iterations, passed=passed, name=exercise_runner.name, - tokens=cost_tracker.total_tokens, - cost=cost_tracker.total_cost, + tokens=None, + cost=SESSION_CONTEXT.get().llm_api_handler.spice.total_cost / 100, transcript={"id": problem_dir, "messages": messages}, ) if had_error: diff --git a/benchmarks/run_sample.py b/benchmarks/run_sample.py index 738d94c9e..da89f4305 100644 --- a/benchmarks/run_sample.py +++ b/benchmarks/run_sample.py @@ -78,7 +78,6 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None, config: Conf await mentat.startup() session_context = SESSION_CONTEXT.get() conversation = session_context.conversation - cost_tracker = session_context.cost_tracker for msg in sample.message_history: if msg["role"] == "user": conversation.add_user_message(msg["content"]) @@ -127,8 +126,8 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None, config: Conf "id": sample.id, "message_eval": message_eval, "diff_eval": diff_eval, - "cost": cost_tracker.total_cost, - "tokens": cost_tracker.total_tokens, + "cost": session_context.llm_api_handler.spice.total_cost / 100, + "tokens": None, "transcript": { "id": sample.id, "messages": transcript_messages, diff --git a/dev-requirements.txt b/dev-requirements.txt index 401bf3c8e..3e8321097 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -5,6 +5,6 @@ gitpython==3.1.41 isort==5.12.0 pip-licenses==4.3.3 plotly==5.18.0 -pyright==1.1.339 +pyright==1.1.358 pytest-xdist==3.3.1 ruff==0.0.292 diff --git a/docs/source/developer/mentat.rst b/docs/source/developer/mentat.rst index bed64e413..f9c6d1f1b 100644 --- a/docs/source/developer/mentat.rst +++ b/docs/source/developer/mentat.rst @@ -97,14 +97,6 @@ mentat.conversation module :undoc-members: :show-inheritance: -mentat.cost\_tracker module ---------------------------- - -.. automodule:: mentat.cost_tracker - :members: - :undoc-members: - :show-inheritance: - mentat.diff\_context module --------------------------- diff --git a/mentat/agent_handler.py b/mentat/agent_handler.py index a21bc9c91..203725cb3 100644 --- a/mentat/agent_handler.py +++ b/mentat/agent_handler.py @@ -57,11 +57,11 @@ async def enable_agent_mode(self): self.agent_file_message += f"{path}\n\n{file_contents}" ctx.stream.send( - "The model has chosen these files to help it determine how to test its" " changes:", + "The model has chosen these files to help it determine how to test its changes:", style="info", ) ctx.stream.send("\n".join(str(path) for path in paths)) - ctx.cost_tracker.display_last_api_call() + ctx.llm_api_handler.display_cost_stats(response) messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content)) ctx.conversation.add_transcript_message( @@ -82,7 +82,7 @@ async def _determine_commands(self) -> List[str]: try: # TODO: Should this even be a separate call or should we collect commands in the edit call? response = await ctx.llm_api_handler.call_llm_api(messages, model, ctx.config.provider, False) - ctx.cost_tracker.display_last_api_call() + ctx.llm_api_handler.display_cost_stats(response) except BadRequestError as e: ctx.stream.send(f"Error accessing OpenAI API: {e.message}", style="error") return [] @@ -113,7 +113,7 @@ async def add_agent_context(self) -> bool: run_commands = await ask_yes_no(default_yes=True) if not run_commands: ctx.stream.send( - "Enter a new-line separated list of commands to run, or nothing to" " return control to the user:", + "Enter a new-line separated list of commands to run, or nothing to return control to the user:", style="info", ) commands: list[str] = (await collect_user_input()).data.strip().splitlines() diff --git a/mentat/auto_completer.py b/mentat/auto_completer.py index 581e9c2a2..5a6a60e4b 100644 --- a/mentat/auto_completer.py +++ b/mentat/auto_completer.py @@ -129,8 +129,8 @@ def _find_shlex_last_word_position(self, argument_buffer: str, num_words: int) - lex.whitespace_split = True for _ in range(num_words - 1): lex.get_token() - remaining = list(lex.instream) - return 0 if not remaining else -len(remaining[0]) + remaining = list(lex.instream) # pyright: ignore + return 0 if not remaining else -len(remaining[0]) # pyright: ignore def _command_argument_completion(self, buffer: str) -> List[Completion]: if any(buffer.startswith(space) for space in whitespace): @@ -177,7 +177,7 @@ def _refresh_file_completion(self, file_path: Path): try: lexer = cast(Lexer, guess_lexer_for_filename(file_path, file_content)) except ClassNotFound: - self._file_completions[file_path] = FileCompletion(datetime.utcnow(), set()) + self._file_completions[file_path] = FileCompletion(datetime.now(), set()) return tokens = list(lexer.get_tokens(file_content)) @@ -188,7 +188,7 @@ def _refresh_file_completion(self, file_path: Path): if len(token_value) <= 1: continue filtered_tokens.add(token_value) - self._file_completions[file_path] = FileCompletion(datetime.utcnow(), filtered_tokens) + self._file_completions[file_path] = FileCompletion(datetime.now(), filtered_tokens) def _refresh_all_file_completions(self): ctx = SESSION_CONTEXT.get() @@ -205,7 +205,7 @@ def _refresh_all_file_completions(self): if file_path not in self._file_completions: self._refresh_file_completion(file_path) else: - modified_at = datetime.utcfromtimestamp(os.path.getmtime(file_path)) + modified_at = datetime.fromtimestamp(os.path.getmtime(file_path)) if self._file_completions[file_path].last_updated < modified_at: self._refresh_file_completion(file_path) @@ -216,13 +216,10 @@ def _refresh_all_file_completions(self): self._all_file_completions.add(str(rel_path)) self._all_file_completions.update(file_completion.syntax_fragments) - self._last_refresh_at = datetime.utcnow() + self._last_refresh_at = datetime.now() def get_file_completions(self, buffer: str) -> List[Completion]: - if ( - self._last_refresh_at is None - or (datetime.utcnow() - self._last_refresh_at).seconds > SECONDS_BETWEEN_REFRESH - ): + if self._last_refresh_at is None or (datetime.now() - self._last_refresh_at).seconds > SECONDS_BETWEEN_REFRESH: self._refresh_all_file_completions() if not buffer or buffer[-1] == " ": diff --git a/mentat/code_context.py b/mentat/code_context.py index d675123fa..3990e89d6 100644 --- a/mentat/code_context.py +++ b/mentat/code_context.py @@ -16,7 +16,7 @@ validate_and_format_path, ) from mentat.interval import parse_intervals, split_intervals_from_path -from mentat.llm_api_handler import count_tokens, get_max_tokens +from mentat.llm_api_handler import get_max_tokens from mentat.session_context import SESSION_CONTEXT from mentat.session_stream import SessionStream from mentat.utils import get_relative_path, mentat_dir_path @@ -60,6 +60,7 @@ def __init__( async def refresh_daemon(self): """Call before interacting with context to ensure daemon is up to date.""" + if not hasattr(self, "daemon"): # Daemon is initialized after setup because it needs the embedding_provider. ctx = SESSION_CONTEXT.get() @@ -76,7 +77,9 @@ async def refresh_daemon(self): annotators=annotators, verbose=False, graph_path=graphs_dir / f"ragdaemon-{cwd.name}.json", - spice_client=getattr(llm_api_handler, "spice_client", None), + spice_client=llm_api_handler.spice, + model=ctx.config.embedding_model, + provider=ctx.config.embedding_provider, ) await self.daemon.update() @@ -96,7 +99,7 @@ async def refresh_context_display(self): total_tokens = await ctx.conversation.count_tokens(include_code_message=True) - total_cost = ctx.cost_tracker.total_cost + total_cost = ctx.llm_api_handler.spice.total_cost data = ContextStreamMessage( cwd=str(ctx.cwd), @@ -126,6 +129,7 @@ async def get_code_message( """ session_context = SESSION_CONTEXT.get() config = session_context.config + llm_api_handler = session_context.llm_api_handler model = config.model cwd = session_context.cwd code_file_manager = session_context.code_file_manager @@ -164,10 +168,10 @@ async def get_code_message( # If auto-context, replace the context_builder with a new one if config.auto_context_tokens > 0 and prompt: - meta_tokens = count_tokens("\n".join(header_lines), model, full_message=True) + meta_tokens = llm_api_handler.spice.count_tokens("\n".join(header_lines), model, is_message=True) include_files_message = context_builder.render() - include_files_tokens = count_tokens(include_files_message, model, full_message=False) + include_files_tokens = llm_api_handler.spice.count_tokens(include_files_message, model, is_message=False) tokens_used = prompt_tokens + meta_tokens + include_files_tokens auto_tokens = min( @@ -209,10 +213,10 @@ def get_all_features( cwd = session_context.cwd all_features = list[CodeFeature]() - for _, data in self.daemon.graph.nodes(data=True): # pyright: ignore - if data is None or "type" not in data or "ref" not in data or data["type"] not in {"file", "chunk"}: + for _, data in self.daemon.graph.nodes(data=True): # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + if data is None or "type" not in data or "ref" not in data or data["type"] not in {"file", "chunk"}: # pyright: ignore[reportUnnecessaryComparison] continue - path, interval = split_intervals_from_path(data["ref"]) # pyright: ignore + path, interval = split_intervals_from_path(data["ref"]) intervals = parse_intervals(interval) if not intervals: all_features.append(CodeFeature(cwd / path)) diff --git a/mentat/code_feature.py b/mentat/code_feature.py index ef760f981..1782d8bd5 100644 --- a/mentat/code_feature.py +++ b/mentat/code_feature.py @@ -9,7 +9,6 @@ from mentat.errors import MentatError from mentat.interval import INTERVAL_FILE_END, Interval -from mentat.llm_api_handler import count_tokens from mentat.session_context import SESSION_CONTEXT from mentat.utils import get_relative_path @@ -61,10 +60,12 @@ def __str__(self, cwd: Optional[Path] = None) -> str: def count_feature_tokens(feature: CodeFeature, model: str) -> int: - cwd = SESSION_CONTEXT.get().cwd + ctx = SESSION_CONTEXT.get() + + cwd = ctx.cwd ref = feature.__str__(cwd) document = get_document(ref, cwd) - return count_tokens(document, model, full_message=False) + return ctx.llm_api_handler.spice.count_tokens(document, model, is_message=False) def get_consolidated_feature_refs(features: list[CodeFeature]) -> list[str]: diff --git a/mentat/code_file_manager.py b/mentat/code_file_manager.py index eeb446447..734b2a4b6 100644 --- a/mentat/code_file_manager.py +++ b/mentat/code_file_manager.py @@ -90,7 +90,7 @@ async def write_changes_to_files( if file_edit.is_creation: if file_edit.file_path.exists(): - raise MentatError(f"Model attempted to create file {file_edit.file_path} which" " already exists") + raise MentatError(f"Model attempted to create file {file_edit.file_path} which already exists") self.create_file(file_edit.file_path) elif not file_edit.file_path.exists(): raise MentatError(f"Attempted to edit non-existent file {file_edit.file_path}") diff --git a/mentat/command/command.py b/mentat/command/command.py index 5cc46dbe5..99bb44277 100644 --- a/mentat/command/command.py +++ b/mentat/command/command.py @@ -78,7 +78,7 @@ async def apply(self, *args: str) -> None: stream = session_context.stream stream.send( - f"{self.invalid_name} is not a valid command. Use /help to see a list of" " all valid commands", + f"{self.invalid_name} is not a valid command. Use /help to see a list of all valid commands", style="warning", ) diff --git a/mentat/command/commands/agent.py b/mentat/command/commands/agent.py index e84fc56f2..0843c315b 100644 --- a/mentat/command/commands/agent.py +++ b/mentat/command/commands/agent.py @@ -33,4 +33,4 @@ def argument_autocompletions(cls, arguments: list[str], argument_position: int) @override @classmethod def help_message(cls) -> str: - return "Toggle agent mode. In agent mode Mentat will automatically make changes" " and run commands." + return "Toggle agent mode. In agent mode Mentat will automatically make changes and run commands." diff --git a/mentat/command/commands/config.py b/mentat/command/commands/config.py index a787d9e4b..129eabf85 100644 --- a/mentat/command/commands/config.py +++ b/mentat/command/commands/config.py @@ -28,7 +28,7 @@ async def apply(self, *args: str) -> None: value = args[1] if attr.fields_dict(type(config))[setting].metadata.get("no_midsession_change"): stream.send( - f"Cannot change {setting} mid-session. Please restart" " Mentat to change this setting.", + f"Cannot change {setting} mid-session. Please restart Mentat to change this setting.", style="warning", ) return diff --git a/mentat/command/commands/screenshot.py b/mentat/command/commands/screenshot.py index 7ab84a148..03fa648e5 100644 --- a/mentat/command/commands/screenshot.py +++ b/mentat/command/commands/screenshot.py @@ -19,12 +19,12 @@ async def apply(self, *args: str) -> None: model = config.model if "gpt" in model: - if "vision" not in model: + if not ("vision" in model or ("gpt-4-turbo" in model and "preview" not in model)): stream.send( - "Using a version of gpt that doesn't support images. Changing to" " gpt-4-vision-preview", + "Using a version of gpt that doesn't support images. Changing to gpt-4-turbo", style="warning", ) - config.model = "gpt-4-vision-preview" + config.model = "gpt-4-turbo" else: stream.send( "Can't determine if this model supports vision. Attempting anyway.", diff --git a/mentat/command/commands/talk.py b/mentat/command/commands/talk.py index d54c60573..02c45d78a 100644 --- a/mentat/command/commands/talk.py +++ b/mentat/command/commands/talk.py @@ -43,10 +43,10 @@ async def record(self): self.start_time = default_timer() self.q: queue.Queue[np.ndarray[Any, Any]] = queue.Queue() - with sf.SoundFile( # pyright: ignore[reportUnboundVariable] + with sf.SoundFile( # pyright: ignore[reportPossiblyUnboundVariable] self.file, mode="w", samplerate=RATE, channels=1 ) as file: - with sd.InputStream( # pyright: ignore[reportUnboundVariable] + with sd.InputStream( # pyright: ignore[reportPossiblyUnboundVariable] samplerate=RATE, channels=1, callback=self.callback ): while not self.shutdown.is_set(): @@ -75,9 +75,13 @@ async def apply(self, *args: str) -> None: await recorder.record() ctx.stream.send("Processing audio with whisper...") await asyncio.sleep(0.01) - transcript = await ctx.llm_api_handler.call_whisper_api(recorder.file) - ctx.stream.send(transcript, channel="default_prompt") - ctx.cost_tracker.log_whisper_call_stats(recorder.recording_time) + response = await ctx.llm_api_handler.call_whisper_api(recorder.file) + ctx.stream.send(response.text, channel="default_prompt") + if response.cost: + ctx.stream.send( + f"Whisper audio length and cost: {response.input_length:.2f}s, ${response.cost / 100:.2f}", + style="info", + ) @override @classmethod diff --git a/mentat/config.py b/mentat/config.py index 20d969ec6..97093eb7f 100644 --- a/mentat/config.py +++ b/mentat/config.py @@ -8,9 +8,10 @@ import attr from attr import converters, validators +from spice.models import TextModel, models +from spice.spice import EmbeddingModel from mentat.git_handler import get_git_root_for_path -from mentat.llm_api_handler import known_models from mentat.parsers.parser import Parser from mentat.parsers.parser_map import parser_map from mentat.session_context import SESSION_CONTEXT @@ -37,12 +38,12 @@ class Config: # Model specific settings model: str = attr.field( default="gpt-4-0125-preview", - metadata={"auto_completions": list(known_models.keys())}, + metadata={"auto_completions": [model.name for model in models if isinstance(model, TextModel)]}, ) provider: Optional[str] = attr.field(default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]}) embedding_model: str = attr.field( default="text-embedding-ada-002", - metadata={"auto_completions": [model.name for model in known_models.values() if model.embedding_model]}, + metadata={"auto_completions": [model.name for model in models if isinstance(model, EmbeddingModel)]}, ) embedding_provider: Optional[str] = attr.field( default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]} @@ -64,14 +65,14 @@ class Config: token_buffer: int = attr.field( default=1000, metadata={ - "description": ("The amount of tokens to always be reserved as a buffer for user and" " model messages."), + "description": ("The amount of tokens to always be reserved as a buffer for user and model messages."), }, ) parser: Parser = attr.field( # pyright: ignore default="block", metadata={ "description": ( - "The format for the LLM to write code in. You probably don't want to" " mess with this setting." + "The format for the LLM to write code in. You probably don't want to mess with this setting." ), "auto_completions": list(parser_map.keys()), }, @@ -213,7 +214,7 @@ def load_file(self, path: Path) -> None: try: config = json.load(config_file) except JSONDecodeError: - self.error(f"Warning: Config {path} contains invalid json; ignoring user" " configuration file") + self.error(f"Warning: Config {path} contains invalid json; ignoring user configuration file") return for field in config: if hasattr(self, field): diff --git a/mentat/conversation.py b/mentat/conversation.py index 4d17ab041..11bcc8733 100644 --- a/mentat/conversation.py +++ b/mentat/conversation.py @@ -17,9 +17,7 @@ from mentat.llm_api_handler import ( TOKEN_COUNT_WARNING, - count_tokens, get_max_tokens, - prompt_tokens, raise_if_context_exceeds_max, ) from mentat.parsers.file_edit import FileEdit @@ -91,9 +89,11 @@ async def count_tokens( system_prompt: Optional[list[ChatCompletionMessageParam]] = None, include_code_message: bool = False, ) -> int: + ctx = SESSION_CONTEXT.get() + _messages = await self.get_messages(system_prompt=system_prompt, include_code_message=include_code_message) - model = SESSION_CONTEXT.get().config.model - return prompt_tokens(_messages, model) + model = ctx.config.model + return ctx.llm_api_handler.spice.count_prompt_tokens(_messages, model) async def get_messages( self, @@ -126,7 +126,7 @@ async def get_messages( if include_code_message: code_message = await ctx.code_context.get_code_message( - prompt_tokens(_messages, ctx.config.model), + ctx.llm_api_handler.spice.count_prompt_tokens(_messages, ctx.config.model), prompt=( prompt # Prompt can be image as well as text if isinstance(prompt, str) @@ -168,7 +168,6 @@ async def _stream_model_response( config = session_context.config parser = config.parser llm_api_handler = session_context.llm_api_handler - cost_tracker = session_context.cost_tracker stream.send( None, @@ -187,7 +186,7 @@ async def _stream_model_response( terminate=True, ) - num_prompt_tokens = prompt_tokens(messages, config.model) + num_prompt_tokens = llm_api_handler.spice.count_prompt_tokens(messages, config.model) stream.send(f"Total token count: {num_prompt_tokens}", style="info") if num_prompt_tokens > TOKEN_COUNT_WARNING: stream.send( @@ -205,8 +204,7 @@ async def _stream_model_response( for file_edit in parsed_llm_response.file_edits: file_edit.previous_file_lines = code_file_manager.file_lines.get(file_edit.file_path, []).copy() - cost_tracker.log_api_call_stats(response.current_response()) - cost_tracker.display_last_api_call() + llm_api_handler.display_cost_stats(response.current_response()) messages.append( ChatCompletionAssistantMessageParam(role="assistant", content=parsed_llm_response.full_response) @@ -219,9 +217,10 @@ async def get_model_response(self) -> ParsedLLMResponse: session_context = SESSION_CONTEXT.get() stream = session_context.stream config = session_context.config + llm_api_handler = session_context.llm_api_handler messages_snapshot = await self.get_messages(include_code_message=True) - tokens_used = prompt_tokens(messages_snapshot, config.model) + tokens_used = llm_api_handler.spice.count_prompt_tokens(messages_snapshot, config.model) raise_if_context_exceeds_max(tokens_used) try: @@ -238,7 +237,9 @@ async def get_model_response(self) -> ParsedLLMResponse: async def remaining_context(self) -> int | None: ctx = SESSION_CONTEXT.get() - return get_max_tokens() - prompt_tokens(await self.get_messages(), ctx.config.model) + return get_max_tokens() - ctx.llm_api_handler.spice.count_prompt_tokens( + await self.get_messages(), ctx.config.model + ) async def can_add_to_context(self, message: str) -> bool: """ @@ -250,7 +251,9 @@ async def can_add_to_context(self, message: str) -> bool: remaining_context = await self.remaining_context() return ( remaining_context is not None - and remaining_context - count_tokens(message, ctx.config.model, full_message=True) - ctx.config.token_buffer + and remaining_context + - ctx.llm_api_handler.spice.count_tokens(message, ctx.config.model, is_message=True) + - ctx.config.token_buffer > 0 ) @@ -297,7 +300,7 @@ async def run_command(self, command: list[str]) -> bool: return True else: ctx.stream.send( - "Not enough tokens remaining in model's context to add command output" " to model context.", + "Not enough tokens remaining in model's context to add command output to model context.", style="error", ) return False diff --git a/mentat/cost_tracker.py b/mentat/cost_tracker.py deleted file mode 100644 index 8b886201e..000000000 --- a/mentat/cost_tracker.py +++ /dev/null @@ -1,71 +0,0 @@ -import logging -from dataclasses import dataclass - -from spice import SpiceResponse - -from mentat.llm_api_handler import model_price_per_1000_tokens -from mentat.session_context import SESSION_CONTEXT - - -@dataclass -class CostTracker: - total_tokens: int = 0 - total_cost: float = 0 - - last_api_call: str = "" - - def log_api_call_stats( - self, - response: SpiceResponse, - ) -> None: - decimal_places = 2 - - model = response.call_args.model - input_tokens = response.input_tokens - output_tokens = response.output_tokens - total_time = response.total_time - - speed_and_cost_string = "" - self.total_tokens += response.total_tokens - if output_tokens > 0: - tokens_per_second = output_tokens / total_time - speed_and_cost_string += f"Speed: {tokens_per_second:.{decimal_places}f} tkns/s" - cost = model_price_per_1000_tokens(model) - if cost: - prompt_cost = (input_tokens / 1000) * cost[0] - sampled_cost = (output_tokens / 1000) * cost[1] - call_cost = prompt_cost + sampled_cost - self.total_cost += call_cost - if speed_and_cost_string: - speed_and_cost_string += " | " - speed_and_cost_string += f"Cost: ${call_cost:.{decimal_places}f}" - - costs_logger = logging.getLogger("costs") - costs_logger.info(speed_and_cost_string) - self.last_api_call = speed_and_cost_string - - # TODO Grant: Functionality moved to ragdaemon. Move all cost tracking to Spice? - def log_embedding_call_stats(self, tokens: int, model: str, total_time: float): - cost = model_price_per_1000_tokens(model) - # TODO Scott: handle unknown models better / port to spice - if cost is None: - return - - cost = cost[0] - call_cost = (tokens / 1000) * cost - self.total_cost += call_cost - costs_logger = logging.getLogger("costs") - costs_logger.info(f"Cost: ${call_cost:.2f}") - self.last_api_call = f"Embedding call time and cost: {total_time:.2f}s, ${call_cost:.2f}" - - def display_last_api_call(self): - """ - Used so that places that call the llm can print the api call stats after they finish printing everything else. - The api call will not be logged if it gets interrupted! - """ - ctx = SESSION_CONTEXT.get() - if self.last_api_call: - ctx.stream.send(self.last_api_call, style="info") - - def log_whisper_call_stats(self, seconds: float): - self.total_cost += seconds * 0.0001 diff --git a/mentat/diff_context.py b/mentat/diff_context.py index 96af9f609..f33c46dd2 100644 --- a/mentat/diff_context.py +++ b/mentat/diff_context.py @@ -33,7 +33,7 @@ def __init__( # TODO: Once broadcast queue's unread messages and/or config is moved to client, # determine if this should quit or not stream.send( - "Cannot specify more than one type of diff. Disabling diff and" " pr-diff.", + "Cannot specify more than one type of diff. Disabling diff and pr-diff.", style="warning", ) diff = None @@ -61,7 +61,7 @@ def __init__( if not target: # TODO: Same as above todo stream.send( - f"Cannot identify merge base between HEAD and {pr_diff}. Disabling" " pr-diff.", + f"Cannot identify merge base between HEAD and {pr_diff}. Disabling pr-diff.", style="warning", ) return diff --git a/mentat/llm_api_handler.py b/mentat/llm_api_handler.py index 28d17d2a3..eb309e4b4 100644 --- a/mentat/llm_api_handler.py +++ b/mentat/llm_api_handler.py @@ -1,41 +1,28 @@ from __future__ import annotations -import base64 -import io +import logging import os import sys from inspect import iscoroutinefunction from pathlib import Path from typing import ( - TYPE_CHECKING, Any, Callable, - Dict, List, Literal, Optional, - cast, + TypeVar, overload, ) -import attr import sentry_sdk -import tiktoken from dotenv import load_dotenv -from openai.types.chat import ( - ChatCompletionAssistantMessageParam, - ChatCompletionContentPartParam, - ChatCompletionMessageParam, - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, -) from openai.types.chat.completion_create_params import ResponseFormat -from PIL import Image -from spice import APIConnectionError, Spice, SpiceMessage, SpiceResponse, StreamingSpiceResponse -from spice.errors import NoAPIKeyError +from spice import EmbeddingResponse, Spice, SpiceMessage, SpiceResponse, StreamingSpiceResponse, TranscriptionResponse +from spice.errors import APIConnectionError, NoAPIKeyError from spice.models import WHISPER_1 from spice.providers import OPEN_AI -from spice.spice import InvalidModelError +from spice.spice import UnknownModelError, get_model_from_name from mentat.errors import MentatError, ReturnToUser from mentat.session_context import SESSION_CONTEXT @@ -43,10 +30,6 @@ TOKEN_COUNT_WARNING = 32000 -if TYPE_CHECKING: - # This import is slow - from chromadb.api.types import Embeddings - def is_test_environment(): """Returns True if in pytest and not benchmarks""" @@ -58,7 +41,10 @@ def is_test_environment(): ) -def api_guard(func: Callable[..., Any]) -> Callable[..., Any]: +RetType = TypeVar("RetType") + + +def api_guard(func: Callable[..., RetType]) -> Callable[..., RetType]: """Decorator that should be used on any function that calls the OpenAI API It does two things: @@ -68,28 +54,28 @@ def api_guard(func: Callable[..., Any]) -> Callable[..., Any]: if iscoroutinefunction(func): - async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + async def async_wrapper(*args: Any, **kwargs: Any) -> RetType: assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!" try: return await func(*args, **kwargs) except APIConnectionError: raise MentatError("API connection error: please check your internet connection and try again.") - except InvalidModelError: + except UnknownModelError: SESSION_CONTEXT.get().stream.send( "Unknown model. Use /config provider and try again.", style="error" ) raise ReturnToUser() - return async_wrapper + return async_wrapper # pyright: ignore[reportReturnType] else: - def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + def sync_wrapper(*args: Any, **kwargs: Any) -> RetType: assert not is_test_environment(), "OpenAI call attempted in non-benchmark test environment!" try: return func(*args, **kwargs) except APIConnectionError: raise MentatError("API connection error: please check your internet connection and try again.") - except InvalidModelError: + except UnknownModelError: SESSION_CONTEXT.get().stream.send( "Unknown model. Use /config provider and try again.", style="error" ) @@ -103,194 +89,12 @@ def chunk_to_lines(content: str) -> list[str]: return content.splitlines(keepends=True) -def get_encoding_for_model(model: str) -> tiktoken.Encoding: - try: - # OpenAI fine-tuned models are named `ft:::`. If tiktoken - # can't match the full string, it tries to match on startswith, e.g. 'gpt-4' - _model = model.split(":")[1] if model.startswith("ft:") else model - return tiktoken.encoding_for_model(_model) - except KeyError: - return tiktoken.get_encoding("cl100k_base") - - -def count_tokens(message: str, model: str, full_message: bool) -> int: - """ - Calculates the tokens in this message. Will NOT be accurate for a full prompt! - Use prompt_tokens to get the exact amount of tokens for a prompt. - If full_message is true, will include the extra 4 tokens used in a chat completion by this message - if this message is part of a prompt. You do NOT want full_message to be true for a response. - """ - encoding = get_encoding_for_model(model) - return len(encoding.encode(message, disallowed_special=())) + (4 if full_message else 0) - - -def normalize_messages_for_anthropic( - messages: list[ChatCompletionMessageParam], -) -> list[ChatCompletionMessageParam]: - """Claude expects the chat to start with at most one system message and afterwards user and system messages to - alternate. This method consolidates all the system messages at the beginning of the conversation into one system - message delimited by "\n"+"-"*80+"\n and turns future system messages into user messages annotated with "System:" - and combines adjacent assistant or user messages into one assistant or user message. - """ - replace_non_leading_systems = list[ChatCompletionMessageParam]() - for i, message in enumerate(messages): - if message["role"] == "system": - if i == 0 or messages[i - 1]["role"] == "system": - replace_non_leading_systems.append(message) - else: - content = "SYSTEM: " + (message["content"] or "") - replace_non_leading_systems.append(ChatCompletionUserMessageParam(role="user", content=content)) - else: - replace_non_leading_systems.append(message) - - concatenate_adjacent = list[ChatCompletionMessageParam]() - current_role: str = "" - current_content: str = "" - delimiter = "\n" + "-" * 80 + "\n" - for message in replace_non_leading_systems: - if message["role"] == current_role: - current_content += delimiter + str(message["content"]) # type: ignore - else: - if current_role == "user": - concatenate_adjacent.append(ChatCompletionUserMessageParam(role=current_role, content=current_content)) - elif current_role == "system": - concatenate_adjacent.append( - ChatCompletionSystemMessageParam(role=current_role, content=current_content) - ) - elif current_role == "assistant": - concatenate_adjacent.append( - ChatCompletionAssistantMessageParam(role=current_role, content=current_content) - ) - current_role = message["role"] - current_content = str(message["content"]) # type: ignore - - if current_role == "user": - concatenate_adjacent.append(ChatCompletionUserMessageParam(role=current_role, content=current_content)) - elif current_role == "system": - concatenate_adjacent.append(ChatCompletionSystemMessageParam(role=current_role, content=current_content)) - elif current_role == "assistant": - concatenate_adjacent.append(ChatCompletionAssistantMessageParam(role=current_role, content=current_content)) - - return concatenate_adjacent - - -def prompt_tokens(messages: list[ChatCompletionMessageParam], model: str): - """ - Returns the number of tokens used by a prompt if it was sent to OpenAI for a chat completion. - Adapted from https://platform.openai.com/docs/guides/text-generation/managing-tokens - """ - encoding = get_encoding_for_model(model) - - num_tokens = 0 - for message in messages: - # every message follows <|start|>{role/name}\n{content}<|end|>\n - # this has 5 tokens (start token, role, \n, end token, \n), but we count the role token later - num_tokens += 4 - for key, value in message.items(): - if isinstance(value, list) and key == "content": - value = cast(List[ChatCompletionContentPartParam], value) - for entry in value: - if entry["type"] == "text": - num_tokens += len(encoding.encode(entry["text"])) - if entry["type"] == "image_url": - image_base64: str = entry["image_url"]["url"].split(",")[1] - image_bytes: bytes = base64.b64decode(image_base64) - image = Image.open(io.BytesIO(image_bytes)) - size = image.size - # As described here: https://platform.openai.com/docs/guides/vision/calculating-costs - scale = min(1, 2048 / max(size)) - size = (int(size[0] * scale), int(size[1] * scale)) - scale = min(1, 768 / min(size)) - size = (int(size[0] * scale), int(size[1] * scale)) - num_tokens += 85 + 170 * ((size[0] + 511) // 512) * ((size[1] + 511) // 512) - elif isinstance(value, str): - num_tokens += len(encoding.encode(value)) - if key == "name": # if there's a name, the role is omitted - num_tokens -= 1 # role is always required and always 1 token - num_tokens += 2 # every reply is primed with <|start|>assistant - return num_tokens - - -@attr.define -class Model: - name: str = attr.field() - context_size: int = attr.field() - input_cost: float = attr.field() - output_cost: float = attr.field() - embedding_model: bool = attr.field(default=False) - - -class ModelsIndex(Dict[str, Model]): - def __init__(self, models: Dict[str, Model]): - super().update(models) - - def _validate_key(self, key: str) -> str: - """Try to match fine-tuned models to their base models.""" - if not super().__contains__(key) and key.startswith("ft:"): - base_model = key.split(":")[1] # e.g. "ft:gpt-3.5-turbo-1106:abante::8dsQMc4F" - if super().__contains__(base_model): - ctx = SESSION_CONTEXT.get() - ctx.stream.send( - f"Using base model {base_model} for size and cost estimates.", - style="info", - ) - super().__setitem__(key, attr.evolve(super().__getitem__(base_model), name=key)) - return key - return key - - def __getitem__(self, key: str) -> Model: - return super().__getitem__(self._validate_key(key)) - - def __contains__(self, key: object) -> bool: - return super().__contains__(self._validate_key(str(key))) - - -known_models = ModelsIndex( - { - "gpt-4-0125-preview": Model("gpt-4-0125-preview", 128000, 0.01, 0.03), - "gpt-4-1106-preview": Model("gpt-4-1106-preview", 128000, 0.01, 0.03), - "gpt-4-vision-preview": Model("gpt-4-vision-preview", 128000, 0.01, 0.03), - "gpt-4": Model("gpt-4", 8192, 0.03, 0.06), - "gpt-4-32k": Model("gpt-4-32k", 32768, 0.06, 0.12), - "gpt-4-0613": Model("gpt-4-0613", 8192, 0.03, 0.06), - "gpt-4-32k-0613": Model("gpt-4-32k-0613", 32768, 0.06, 0.12), - "gpt-4-0314": Model("gpt-4-0314", 8192, 0.03, 0.06), - "gpt-4-32k-0314": Model("gpt-4-32k-0314", 32768, 0.06, 0.12), - "gpt-3.5-turbo-0125": Model("gpt-3.5-turbo-0125", 16385, 0.0005, 0.0015), - "gpt-3.5-turbo-1106": Model("gpt-3.5-turbo-1106", 16385, 0.001, 0.002), - "gpt-3.5-turbo": Model("gpt-3.5-turbo", 16385, 0.001, 0.002), - "gpt-3.5-turbo-0613": Model("gpt-3.5-turbo-0613", 4096, 0.0015, 0.002), - "gpt-3.5-turbo-16k-0613": Model("gpt-3.5-turbo-16k-0613", 16385, 0.003, 0.004), - "gpt-3.5-turbo-0301": Model("gpt-3.5-turbo-0301", 4096, 0.0015, 0.002), - "text-embedding-ada-002": Model("text-embedding-ada-002", 8191, 0.0001, 0, embedding_model=True), - "claude-3-opus-20240229": Model("claude-3-opus-20240229", 200000, 0.015, 0.075), - "claude-3-sonnet-20240229": Model("claude-3-sonnet-20240229", 200000, 0.003, 0.015), - "claude-3-haiku-20240307": Model("claude-3-haiku-20240307", 200000, 0.00025, 0.00125), - } -) - - -def model_context_size(model: str) -> Optional[int]: - if model not in known_models: - return None - else: - return known_models[model].context_size - - -def model_price_per_1000_tokens(model: str) -> Optional[tuple[float, float]]: - """Returns (input, output) cost per 1000 tokens in USD""" - if model not in known_models: - return None - else: - return known_models[model].input_cost, known_models[model].output_cost - - def get_max_tokens() -> int: session_context = SESSION_CONTEXT.get() stream = session_context.stream config = session_context.config - context_size = model_context_size(config.model) + context_size = get_model_from_name(config.model).context_length maximum_context = config.maximum_context if context_size is not None and maximum_context is not None: @@ -329,14 +133,15 @@ def raise_if_context_exceeds_max(tokens: int): class LlmApiHandler: """Used for any functions that require calling the external LLM API""" + def __init__(self): + self.spice = Spice() + async def initialize_client(self): ctx = SESSION_CONTEXT.get() if not load_dotenv(mentat_dir_path / ".env") and not load_dotenv(ctx.cwd / ".env"): load_dotenv() - self.spice = Spice() - try: self.spice.load_provider(OPEN_AI) except NoAPIKeyError: @@ -384,10 +189,9 @@ async def call_llm_api( ) -> SpiceResponse | StreamingSpiceResponse: session_context = SESSION_CONTEXT.get() config = session_context.config - cost_tracker = session_context.cost_tracker - # Confirm that model has enough tokens remaining. - tokens = prompt_tokens(messages, model) + # Confirm that model has enough tokens remaining + tokens = self.spice.count_prompt_tokens(messages, model) raise_if_context_exceeds_max(tokens) with sentry_sdk.start_span(description="LLM Call") as span: @@ -399,25 +203,36 @@ async def call_llm_api( provider=provider, messages=messages, temperature=config.temperature, - response_format=response_format, # pyright: ignore + response_format=response_format, ) - cost_tracker.log_api_call_stats(response) else: response = await self.spice.stream_response( model=model, provider=provider, messages=messages, temperature=config.temperature, - response_format=response_format, # pyright: ignore + response_format=response_format, ) return response @api_guard - def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-ada-002") -> Embeddings: + def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-ada-002") -> EmbeddingResponse: ctx = SESSION_CONTEXT.get() - return self.spice.get_embeddings_sync(input_texts, model, provider=ctx.config.embedding_provider) # pyright: ignore + return self.spice.get_embeddings_sync(input_texts, model, provider=ctx.config.embedding_provider) @api_guard - async def call_whisper_api(self, audio_path: Path) -> str: + async def call_whisper_api(self, audio_path: Path) -> TranscriptionResponse: return await self.spice.get_transcription(audio_path, model=WHISPER_1) + + def display_cost_stats(self, response: SpiceResponse): + ctx = SESSION_CONTEXT.get() + + display = f"Speed: {response.characters_per_second:.2f} char/s" + if response.cost is not None: + display += f" | Cost: ${response.cost / 100:.2f}" + + costs_logger = logging.getLogger("costs") + costs_logger.info(display) + + ctx.stream.send(display, style="info") diff --git a/mentat/python_client/client.py b/mentat/python_client/client.py index 2618ffb54..4ca48ceaa 100644 --- a/mentat/python_client/client.py +++ b/mentat/python_client/client.py @@ -133,7 +133,3 @@ async def _stop(self): def get_conversation(self): """Returns the current conversation context from the session.""" return self.session.ctx.conversation - - def get_cost_tracker(self): - """Returns the cost tracker from the session context.""" - return self.session.ctx.cost_tracker diff --git a/mentat/revisor/revisor.py b/mentat/revisor/revisor.py index 2a8f713b5..fbe7e5296 100644 --- a/mentat/revisor/revisor.py +++ b/mentat/revisor/revisor.py @@ -9,7 +9,6 @@ ) from mentat.errors import MentatError -from mentat.llm_api_handler import prompt_tokens from mentat.parsers.change_display_helper import get_lexer, highlight_text from mentat.parsers.file_edit import FileEdit from mentat.parsers.git_parser import GitParser @@ -59,7 +58,9 @@ async def revise_edit(file_edit: FileEdit): user_message, ChatCompletionSystemMessageParam(content=f"Diff:\n{diff}", role="system"), ] - code_message = await ctx.code_context.get_code_message(prompt_tokens(messages, ctx.config.model)) + code_message = await ctx.code_context.get_code_message( + ctx.llm_api_handler.spice.count_prompt_tokens(messages, ctx.config.model) + ) messages.insert(1, ChatCompletionSystemMessageParam(content=code_message, role="system")) ctx.stream.send( @@ -121,7 +122,7 @@ async def revise_edit(file_edit: FileEdit): for line in diff_diff: send_formatted_string(line) ctx.stream.send("", delimiter=True) - ctx.cost_tracker.display_last_api_call() + ctx.llm_api_handler.display_cost_stats(response) async def revise_edits(file_edits: List[FileEdit]): diff --git a/mentat/sampler/sampler.py b/mentat/sampler/sampler.py index acf381ef6..87092da5a 100644 --- a/mentat/sampler/sampler.py +++ b/mentat/sampler/sampler.py @@ -96,7 +96,7 @@ async def create_sample(self) -> Sample: ).strip() except subprocess.CalledProcessError: pass - stream.send(f"Found repo URL: {remote_url}. Press 'ENTER' to accept, or enter a new" " URL.") + stream.send(f"Found repo URL: {remote_url}. Press 'ENTER' to accept, or enter a new URL.") response = (await collect_user_input()).data.strip() if response == "y": repo = remote_url diff --git a/mentat/session.py b/mentat/session.py index 06c6775d2..0ebb05078 100644 --- a/mentat/session.py +++ b/mentat/session.py @@ -23,7 +23,6 @@ from mentat.code_file_manager import CodeFileManager from mentat.config import Config from mentat.conversation import Conversation -from mentat.cost_tracker import CostTracker from mentat.errors import MentatError, ReturnToUser, SessionExit, UserError from mentat.llm_api_handler import LlmApiHandler, is_test_environment from mentat.logging_config import setup_logging @@ -77,8 +76,6 @@ def __init__( self.stream = stream self.stream.start() - cost_tracker = CostTracker() - code_context = CodeContext(stream, cwd, diff, pr_diff, ignore_paths) code_file_manager = CodeFileManager() @@ -97,7 +94,6 @@ def __init__( cwd, stream, llm_api_handler, - cost_tracker, config, code_context, code_file_manager, @@ -183,7 +179,7 @@ async def _main(self): if agent_handler.agent_enabled: code_file_manager.history.push_edits() stream.send( - "Use /undo to undo all changes from agent mode since last" " input.", + "Use /undo to undo all changes from agent mode since last input.", style="success", ) message = await collect_input_with_commands() diff --git a/mentat/session_context.py b/mentat/session_context.py index 6c88969cb..4b0db7e67 100644 --- a/mentat/session_context.py +++ b/mentat/session_context.py @@ -13,7 +13,6 @@ from mentat.code_file_manager import CodeFileManager from mentat.config import Config from mentat.conversation import Conversation - from mentat.cost_tracker import CostTracker from mentat.llm_api_handler import LlmApiHandler from mentat.sampler.sampler import Sampler from mentat.session_stream import SessionStream @@ -27,7 +26,6 @@ class SessionContext: cwd: Path = attr.field() stream: SessionStream = attr.field() llm_api_handler: LlmApiHandler = attr.field() - cost_tracker: CostTracker = attr.field() config: Config = attr.field() code_context: CodeContext = attr.field() code_file_manager: CodeFileManager = attr.field() diff --git a/mentat/splash_messages.py b/mentat/splash_messages.py index dbfec6d39..2a8c9a4dd 100644 --- a/mentat/splash_messages.py +++ b/mentat/splash_messages.py @@ -81,6 +81,6 @@ def check_model(): ) if "gpt-3.5" not in model and "claude-3" not in model: ctx.stream.send( - "Warning: Mentat does not know how to calculate costs or context" " size for this model.", + "Warning: Mentat does not know how to calculate costs or context size for this model.", style="warning", ) diff --git a/mentat/terminal/client.py b/mentat/terminal/client.py index cd999ac5a..6757b0bf4 100644 --- a/mentat/terminal/client.py +++ b/mentat/terminal/client.py @@ -227,7 +227,7 @@ def get_parser(): "-g", nargs="*", default=[], - help=("List of file paths, directory paths, or glob patterns to ignore in" " auto-context"), + help=("List of file paths, directory paths, or glob patterns to ignore in auto-context"), ) parser.add_argument( "--diff", diff --git a/mentat/vision/vision_manager.py b/mentat/vision/vision_manager.py index cbff240ee..8e1ff2f2c 100644 --- a/mentat/vision/vision_manager.py +++ b/mentat/vision/vision_manager.py @@ -49,7 +49,7 @@ def _open_browser(self) -> None: except Exception: if safari_installed: ctx.stream.send( - "No suitable browser found. To use Safari, enable" " remote automation.", + "No suitable browser found. To use Safari, enable remote automation.", style="error", ) else: diff --git a/requirements.txt b/requirements.txt index 1427f3f0e..d4d9c2544 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,12 +17,12 @@ pytest-mock==3.11.1 pytest-reportlog==0.4.0 pytest-timeout==2.2.0 python-dotenv==1.0.0 -ragdaemon==0.1.3 +ragdaemon==0.1.5 selenium==4.15.2 sentry-sdk==1.34.0 sounddevice==0.4.6 soundfile==0.12.1 -spiceai==0.1.9 +spiceai==0.1.11 termcolor==2.3.0 textual==0.47.1 textual-autocomplete==2.1.0b0 diff --git a/scripts/git_log_to_transcripts.py b/scripts/git_log_to_transcripts.py index b29944c9a..a6602826a 100755 --- a/scripts/git_log_to_transcripts.py +++ b/scripts/git_log_to_transcripts.py @@ -15,7 +15,6 @@ from mentat.code_context import CodeContext from mentat.code_file_manager import CodeFileManager from mentat.config import Config -from mentat.llm_api import CostTracker, count_tokens from mentat.parsers.git_parser import GitParser from mentat.sampler.utils import clone_repo from mentat.session_context import SESSION_CONTEXT, SessionContext @@ -126,7 +125,7 @@ async def translate_commits_to_transcripts(repo, count=10): # Necessary for CodeContext to work repo.git.checkout(commit.parents[0].hexsha) shown = subprocess.check_output(["git", "show", sha, "-m", "--first-parent"]).decode("utf-8") - if count_tokens(shown, "gpt-4") > 6000: + if session_context.llm_api_handler.spice.count_tokens(shown, "gpt-4") > 6000: print("Skipping because too long") continue @@ -219,7 +218,6 @@ async def translate_commits_to_transcripts(repo, count=10): code_context = CodeContext(stream, os.getcwd()) session_context = SessionContext( stream, - CostTracker(), Path.cwd(), config, code_context, diff --git a/scripts/sampler/__main__.py b/scripts/sampler/__main__.py index 39b22dca8..18534971b 100644 --- a/scripts/sampler/__main__.py +++ b/scripts/sampler/__main__.py @@ -13,7 +13,6 @@ from finetune import generate_finetune from validate import validate_sample -from mentat.llm_api_handler import count_tokens, prompt_tokens from mentat.sampler.sample import Sample from mentat.utils import mentat_dir_path @@ -80,11 +79,11 @@ async def main(): elif args.finetune: try: example = await generate_finetune(sample) - # Toktoken only includes encoding for openAI models, so this isn't always correct + spice = Spice() if "messages" in example: - tokens = prompt_tokens(example["messages"], "gpt-4") + tokens = spice.count_prompt_tokens(example["messages"], "gpt-4") elif "text" in example: - tokens = count_tokens(example["text"], "gpt-4", full_message=False) + tokens = spice.count_tokens(example["text"], "gpt-4", is_message=False) example["tokens"] = tokens print("Generated finetune example" f" {sample.id[:8]} ({example['tokens']} tokens)") logs.append(example) @@ -115,7 +114,7 @@ async def main(): ) if args.validate: - print(f"{sum([log['is_valid'] for log in logs])}/{len(logs)} samples passed" " validation.") + print(f"{sum([log['is_valid'] for log in logs])}/{len(logs)} samples passed validation.") elif args.finetune: # Dump all logs into a .jsonl file timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") diff --git a/scripts/select_git_transcripts.py b/scripts/select_git_transcripts.py index 36382cefd..ca768ff5b 100755 --- a/scripts/select_git_transcripts.py +++ b/scripts/select_git_transcripts.py @@ -4,7 +4,7 @@ import os from pathlib import Path -from mentat.llm_api import count_tokens, model_context_size +from spice.spice import get_model_from_name def select_transcripts( @@ -33,7 +33,7 @@ def select_transcripts( continue if skip_config and info["configuration"]: continue - if count_tokens(json.dumps(info["mocked_conversation"]), model) > model_context_size(model): + if count_tokens(json.dumps(info["mocked_conversation"]), model) > get_model_from_name(model).context_length: continue transcripts.append(info["mocked_conversation"]) diff --git a/tests/benchmark_test.py b/tests/benchmark_test.py index 81de3586e..a037c0a38 100644 --- a/tests/benchmark_test.py +++ b/tests/benchmark_test.py @@ -43,7 +43,7 @@ async def test_calculator_add_power(mock_collect_user_input): calculator_path = "scripts/calculator.py" results = await edit_file_and_run( mock_collect_user_input, - prompts=["Add power as a possible operation, raising the first arg to the power of" " the second"], + prompts=["Add power as a possible operation, raising the first arg to the power of the second"], context_file_paths=[calculator_path], main_file_path=calculator_path, argument_lists=[["power", "15", "3"]], diff --git a/tests/code_context_test.py b/tests/code_context_test.py index 75b401d3e..5c046f5ca 100644 --- a/tests/code_context_test.py +++ b/tests/code_context_test.py @@ -10,7 +10,6 @@ from mentat.git_handler import get_non_gitignored_files from mentat.include_files import is_file_text_encoded from mentat.interval import Interval -from mentat.llm_api_handler import count_tokens from tests.conftest import run_git_command @@ -214,7 +213,7 @@ def func_4(string): mock_session_context.config.auto_context_tokens = 8000 code_message = await code_context.get_code_message(0, prompt="prompt") - assert count_tokens(code_message, "gpt-4", full_message=True) == 95 # Code + assert mock_session_context.llm_api_handler.spice.count_tokens(code_message, "gpt-4", is_message=True) == 95 # Code assert ( code_message == """\ diff --git a/tests/commands_test.py b/tests/commands_test.py index f897fdf59..c4585a9a3 100644 --- a/tests/commands_test.py +++ b/tests/commands_test.py @@ -437,7 +437,7 @@ async def test_screenshot_command(mocker): stream = session_context.stream conversation = session_context.conversation - assert config.model != "gpt-4-vision-preview" + assert config.model != "gpt-4-turbo" mock_vision_manager.screenshot.return_value = "fake_image_data" @@ -445,7 +445,7 @@ async def test_screenshot_command(mocker): await screenshot_command.apply("fake_path") mock_vision_manager.screenshot.assert_called_once_with("fake_path") - assert config.model == "gpt-4-vision-preview" + assert config.model == "gpt-4-turbo" assert stream.messages[-1].data == "Screenshot taken for: fake_path." assert conversation._messages[-1] == { "role": "user", diff --git a/tests/conftest.py b/tests/conftest.py index dcfce3d06..234e661f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,6 @@ from mentat.code_file_manager import CodeFileManager from mentat.config import Config, config_file_name from mentat.conversation import Conversation -from mentat.cost_tracker import CostTracker from mentat.llm_api_handler import LlmApiHandler from mentat.parsers.streaming_printer import StreamingPrinter from mentat.sampler.sampler import Sampler @@ -171,8 +170,6 @@ def mock_session_context(temp_testbed): """ stream = SessionStream() - cost_tracker = CostTracker() - config = Config() llm_api_handler = LlmApiHandler() @@ -194,7 +191,6 @@ def mock_session_context(temp_testbed): Path.cwd(), stream, llm_api_handler, - cost_tracker, config, code_context, code_file_manager, @@ -300,3 +296,8 @@ def mock_user_config(mocker): @pytest.fixture(autouse=True) def mock_sleep_time(mocker): mocker.patch.object(StreamingPrinter, "sleep_time", new=lambda self: 0) + + +@pytest.fixture(autouse=True) +def mock_api_key(): + os.environ["OPENAI_API_KEY"] = "fake_testing_key" diff --git a/tests/llm_api_handler_test.py b/tests/llm_api_handler_test.py deleted file mode 100644 index 31fd1d691..000000000 --- a/tests/llm_api_handler_test.py +++ /dev/null @@ -1,33 +0,0 @@ -import base64 -from io import BytesIO - -from PIL import Image - -from mentat.llm_api_handler import prompt_tokens - - -def test_prompt_tokens(): - messages = [ - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Hi there! How can I help you today?"}, - ] - model = "gpt-4-vision-preview" - - assert prompt_tokens(messages, model) == 24 - - # An image that must be scaled twice and then fits in 6 512x512 panels - img = Image.new("RGB", (768 * 4, 1050 * 4), color="red") - buffer = BytesIO() - img.save(buffer, format="PNG") - buffer.seek(0) - img_base64 = base64.b64encode(buffer.getvalue()).decode() - image_url = f"data:image/png;base64,{img_base64}" - - messages.append( - { - "role": "user", - "content": [{"type": "image_url", "image_url": {"url": image_url}}], - } - ) - - assert prompt_tokens(messages, model) == 24 + 6 * 170 + 85 + 5 diff --git a/tests/sampler_test.py b/tests/sampler_test.py index d268042ed..4d30d614b 100644 --- a/tests/sampler_test.py +++ b/tests/sampler_test.py @@ -187,7 +187,7 @@ async def test_sample_command(temp_testbed, mock_collect_user_input, mock_call_l "message_history": [], "message_prompt": "Add a sha1 function to utils.py", "message_edit": ( - "I will add a new sha1 function to the `utils.py` file.\n\nSteps:\n1." " Add the sha1 function to `utils.py`." + "I will add a new sha1 function to the `utils.py` file.\n\nSteps:\n1. Add the sha1 function to `utils.py`." ), "context": ["mentat/utils.py"], "diff_edit": (