Update spice remove cost tracker (#563)
PCSwingle authored Apr 12, 2024
1 parent fd17d74 commit bd609ce
Showing 39 changed files with 143 additions and 444 deletions.
benchmarks/arg_parser.py (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ def common_benchmark_parser():
         "--benchmarks",
         nargs="*",
         default=[],
-        help=("Which benchmarks to run. max_benchmarks ignored when set. Exact meaning" " depends on benchmark."),
+        help=("Which benchmarks to run. max_benchmarks ignored when set. Exact meaning depends on benchmark."),
     )
     parser.add_argument(
         "--directory",
benchmarks/benchmark_runner.py (8 changes: 4 additions & 4 deletions)
@@ -13,6 +13,7 @@
 
 from openai.types.chat.completion_create_params import ResponseFormat
 from spice import SpiceMessage
+from spice.spice import get_model_from_name
 
 from benchmarks.arg_parser import common_benchmark_parser
 from benchmarks.benchmark_result import BenchmarkResult
@@ -22,7 +23,6 @@
 from benchmarks.swe_bench_runner import SWE_BENCH_SAMPLES_DIR, get_swe_samples
 from mentat.config import Config
 from mentat.git_handler import get_git_diff, get_mentat_branch, get_mentat_hexsha
-from mentat.llm_api_handler import model_context_size, prompt_tokens
 from mentat.sampler.sample import Sample
 from mentat.sampler.utils import setup_repo
 from mentat.session_context import SESSION_CONTEXT
@@ -45,20 +45,20 @@ def git_diff_from_comparison_commit(sample: Sample, comparison_commit: str) -> s
 
 async def grade(to_grade, prompt, model="gpt-4-1106-preview"):
     try:
+        llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
         messages: List[SpiceMessage] = [
             {"role": "system", "content": prompt},
             {"role": "user", "content": to_grade},
         ]
-        tokens = prompt_tokens(messages, model)
-        max_tokens = model_context_size(model) - 1000  # Response buffer
+        tokens = llm_api_handler.spice.count_prompt_tokens(messages, model)
+        max_tokens = get_model_from_name(model).context_length - 1000  # Response buffer
         if tokens > max_tokens:
             print("Prompt too long! Truncating... (this may affect results)")
             tokens_to_remove = tokens - max_tokens
             chars_per_token = len(str(messages)) / tokens
             chars_to_remove = int(chars_per_token * tokens_to_remove)
             messages[1]["content"] = messages[1]["content"][:-chars_to_remove]
 
-        llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
         llm_grade = await llm_api_handler.call_llm_api(messages, model, None, False, ResponseFormat(type="json_object"))
         content = llm_grade.text
         return json.loads(content)
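The grading path now delegates token accounting to Spice: count_prompt_tokens replaces mentat's removed prompt_tokens helper, and get_model_from_name(model).context_length replaces model_context_size. A minimal sketch of the resulting truncation pattern, assuming a standalone Spice() client for illustration (mentat itself reuses llm_api_handler.spice):

from spice import Spice
from spice.spice import get_model_from_name

def truncate_messages(messages, model, response_buffer=1000):
    # Count prompt tokens with Spice rather than the removed helpers.
    spice = Spice()  # illustrative; mentat reuses the session's client
    tokens = spice.count_prompt_tokens(messages, model)
    max_tokens = get_model_from_name(model).context_length - response_buffer
    if tokens > max_tokens:
        # Crude character-proportional truncation, mirroring grade() above;
        # assumes messages[1] holds the user content.
        chars_per_token = len(str(messages)) / tokens
        chars_to_remove = int(chars_per_token * (tokens - max_tokens))
        messages[1]["content"] = messages[1]["content"][:-chars_to_remove]
    return messages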
benchmarks/exercism_practice.py (5 changes: 2 additions & 3 deletions)
@@ -114,13 +114,12 @@ async def run_exercise(problem_dir, language="python", max_iterations=2):
     messages = client.get_conversation().literal_messages
     await client.shutdown()
     passed = exercise_runner.passed()
-    cost_tracker = SESSION_CONTEXT.get().cost_tracker
     result = BenchmarkResult(
         iterations=iterations,
         passed=passed,
         name=exercise_runner.name,
-        tokens=cost_tracker.total_tokens,
-        cost=cost_tracker.total_cost,
+        tokens=None,
+        cost=SESSION_CONTEXT.get().llm_api_handler.spice.total_cost / 100,
         transcript={"id": problem_dir, "messages": messages},
     )
     if had_error:
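Two conventions recur throughout this commit: Spice reports total_cost in cents (hence the / 100 conversions to dollars), and mentat no longer aggregates token counts itself (hence tokens=None). A sketch of reading the running total, only meaningful inside an active mentat session:

from mentat.session_context import SESSION_CONTEXT

# Assumed from this commit: spice.total_cost accumulates in cents.
spice = SESSION_CONTEXT.get().llm_api_handler.spice
print(f"Total LLM spend: ${spice.total_cost / 100:.2f}")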
benchmarks/run_sample.py (5 changes: 2 additions & 3 deletions)
@@ -78,7 +78,6 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None, config: Conf
     await mentat.startup()
     session_context = SESSION_CONTEXT.get()
     conversation = session_context.conversation
-    cost_tracker = session_context.cost_tracker
     for msg in sample.message_history:
         if msg["role"] == "user":
             conversation.add_user_message(msg["content"])
@@ -127,8 +126,8 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None, config: Conf
         "id": sample.id,
         "message_eval": message_eval,
         "diff_eval": diff_eval,
-        "cost": cost_tracker.total_cost,
-        "tokens": cost_tracker.total_tokens,
+        "cost": session_context.llm_api_handler.spice.total_cost / 100,
+        "tokens": None,
         "transcript": {
             "id": sample.id,
             "messages": transcript_messages,
dev-requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -5,6 +5,6 @@ gitpython==3.1.41
 isort==5.12.0
 pip-licenses==4.3.3
 plotly==5.18.0
-pyright==1.1.339
+pyright==1.1.358
 pytest-xdist==3.3.1
 ruff==0.0.292
docs/source/developer/mentat.rst (8 changes: 0 additions & 8 deletions)
@@ -97,14 +97,6 @@ mentat.conversation module
    :undoc-members:
    :show-inheritance:
 
-mentat.cost\_tracker module
----------------------------
-
-.. automodule:: mentat.cost_tracker
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
 mentat.diff\_context module
 ---------------------------
 
mentat/agent_handler.py (8 changes: 4 additions & 4 deletions)
@@ -57,11 +57,11 @@ async def enable_agent_mode(self):
             self.agent_file_message += f"{path}\n\n{file_contents}"
 
         ctx.stream.send(
-            "The model has chosen these files to help it determine how to test its" " changes:",
+            "The model has chosen these files to help it determine how to test its changes:",
             style="info",
         )
         ctx.stream.send("\n".join(str(path) for path in paths))
-        ctx.cost_tracker.display_last_api_call()
+        ctx.llm_api_handler.display_cost_stats(response)
 
         messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
         ctx.conversation.add_transcript_message(
@@ -82,7 +82,7 @@ async def _determine_commands(self) -> List[str]:
         try:
             # TODO: Should this even be a separate call or should we collect commands in the edit call?
             response = await ctx.llm_api_handler.call_llm_api(messages, model, ctx.config.provider, False)
-            ctx.cost_tracker.display_last_api_call()
+            ctx.llm_api_handler.display_cost_stats(response)
         except BadRequestError as e:
             ctx.stream.send(f"Error accessing OpenAI API: {e.message}", style="error")
             return []
@@ -113,7 +113,7 @@ async def add_agent_context(self) -> bool:
         run_commands = await ask_yes_no(default_yes=True)
         if not run_commands:
             ctx.stream.send(
-                "Enter a new-line separated list of commands to run, or nothing to" " return control to the user:",
+                "Enter a new-line separated list of commands to run, or nothing to return control to the user:",
                 style="info",
             )
             commands: list[str] = (await collect_user_input()).data.strip().splitlines()
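In agent mode the cost display is now driven by the API response itself rather than a global tracker's last-call state. A hedged sketch of the call-site pattern (ctx, messages, and model mirror the variables in the hunks above):

async def call_and_report(ctx, messages, model):
    # display_cost_stats takes the response, replacing
    # cost_tracker.display_last_api_call()'s reliance on global state.
    response = await ctx.llm_api_handler.call_llm_api(messages, model, ctx.config.provider, False)
    ctx.llm_api_handler.display_cost_stats(response)
    return response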
mentat/auto_completer.py (17 changes: 7 additions & 10 deletions)
@@ -129,8 +129,8 @@ def _find_shlex_last_word_position(self, argument_buffer: str, num_words: int) -
         lex.whitespace_split = True
         for _ in range(num_words - 1):
             lex.get_token()
-        remaining = list(lex.instream)
-        return 0 if not remaining else -len(remaining[0])
+        remaining = list(lex.instream)  # pyright: ignore
+        return 0 if not remaining else -len(remaining[0])  # pyright: ignore
 
     def _command_argument_completion(self, buffer: str) -> List[Completion]:
         if any(buffer.startswith(space) for space in whitespace):
@@ -177,7 +177,7 @@ def _refresh_file_completion(self, file_path: Path):
         try:
             lexer = cast(Lexer, guess_lexer_for_filename(file_path, file_content))
         except ClassNotFound:
-            self._file_completions[file_path] = FileCompletion(datetime.utcnow(), set())
+            self._file_completions[file_path] = FileCompletion(datetime.now(), set())
             return
 
         tokens = list(lexer.get_tokens(file_content))
@@ -188,7 +188,7 @@ def _refresh_file_completion(self, file_path: Path):
             if len(token_value) <= 1:
                 continue
             filtered_tokens.add(token_value)
-        self._file_completions[file_path] = FileCompletion(datetime.utcnow(), filtered_tokens)
+        self._file_completions[file_path] = FileCompletion(datetime.now(), filtered_tokens)
 
     def _refresh_all_file_completions(self):
         ctx = SESSION_CONTEXT.get()
@@ -205,7 +205,7 @@ def _refresh_all_file_completions(self):
             if file_path not in self._file_completions:
                 self._refresh_file_completion(file_path)
             else:
-                modified_at = datetime.utcfromtimestamp(os.path.getmtime(file_path))
+                modified_at = datetime.fromtimestamp(os.path.getmtime(file_path))
                 if self._file_completions[file_path].last_updated < modified_at:
                     self._refresh_file_completion(file_path)
 
@@ -216,13 +216,10 @@ def _refresh_all_file_completions(self):
             self._all_file_completions.add(str(rel_path))
             self._all_file_completions.update(file_completion.syntax_fragments)
 
-        self._last_refresh_at = datetime.utcnow()
+        self._last_refresh_at = datetime.now()
 
     def get_file_completions(self, buffer: str) -> List[Completion]:
-        if (
-            self._last_refresh_at is None
-            or (datetime.utcnow() - self._last_refresh_at).seconds > SECONDS_BETWEEN_REFRESH
-        ):
+        if self._last_refresh_at is None or (datetime.now() - self._last_refresh_at).seconds > SECONDS_BETWEEN_REFRESH:
             self._refresh_all_file_completions()
 
         if not buffer or buffer[-1] == " ":
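The utcnow()-to-now() switch keeps these timestamps comparable with file mtimes: datetime.fromtimestamp() converts os.path.getmtime() to local time, so the refresh clock must be local too (this rationale is inferred, not stated in the commit). A small illustration with an assumed file path:

import os
from datetime import datetime

# fromtimestamp() yields local time; comparing it against now() is
# consistent, while utcnow() would skew the check by the UTC offset.
mtime = datetime.fromtimestamp(os.path.getmtime("README.md"))  # path illustrative
stale = (datetime.now() - mtime).total_seconds() > 30
print(stale)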
mentat/code_context.py (20 changes: 12 additions & 8 deletions)
@@ -16,7 +16,7 @@
     validate_and_format_path,
 )
 from mentat.interval import parse_intervals, split_intervals_from_path
-from mentat.llm_api_handler import count_tokens, get_max_tokens
+from mentat.llm_api_handler import get_max_tokens
 from mentat.session_context import SESSION_CONTEXT
 from mentat.session_stream import SessionStream
 from mentat.utils import get_relative_path, mentat_dir_path
@@ -60,6 +60,7 @@ def __init__(
 
     async def refresh_daemon(self):
         """Call before interacting with context to ensure daemon is up to date."""
+
         if not hasattr(self, "daemon"):
             # Daemon is initialized after setup because it needs the embedding_provider.
             ctx = SESSION_CONTEXT.get()
@@ -76,7 +77,9 @@ async def refresh_daemon(self):
                 annotators=annotators,
                 verbose=False,
                 graph_path=graphs_dir / f"ragdaemon-{cwd.name}.json",
-                spice_client=getattr(llm_api_handler, "spice_client", None),
+                spice_client=llm_api_handler.spice,
+                model=ctx.config.embedding_model,
+                provider=ctx.config.embedding_provider,
             )
         await self.daemon.update()
 
@@ -96,7 +99,7 @@ async def refresh_context_display(self):
 
         total_tokens = await ctx.conversation.count_tokens(include_code_message=True)
 
-        total_cost = ctx.cost_tracker.total_cost
+        total_cost = ctx.llm_api_handler.spice.total_cost
 
         data = ContextStreamMessage(
             cwd=str(ctx.cwd),
@@ -126,6 +129,7 @@ async def get_code_message(
         """
         session_context = SESSION_CONTEXT.get()
         config = session_context.config
+        llm_api_handler = session_context.llm_api_handler
         model = config.model
         cwd = session_context.cwd
         code_file_manager = session_context.code_file_manager
@@ -164,10 +168,10 @@ async def get_code_message(
 
         # If auto-context, replace the context_builder with a new one
         if config.auto_context_tokens > 0 and prompt:
-            meta_tokens = count_tokens("\n".join(header_lines), model, full_message=True)
+            meta_tokens = llm_api_handler.spice.count_tokens("\n".join(header_lines), model, is_message=True)
 
             include_files_message = context_builder.render()
-            include_files_tokens = count_tokens(include_files_message, model, full_message=False)
+            include_files_tokens = llm_api_handler.spice.count_tokens(include_files_message, model, is_message=False)
 
             tokens_used = prompt_tokens + meta_tokens + include_files_tokens
             auto_tokens = min(
@@ -209,10 +213,10 @@ def get_all_features(
         cwd = session_context.cwd
 
         all_features = list[CodeFeature]()
-        for _, data in self.daemon.graph.nodes(data=True):  # pyright: ignore
-            if data is None or "type" not in data or "ref" not in data or data["type"] not in {"file", "chunk"}:
+        for _, data in self.daemon.graph.nodes(data=True):  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
+            if data is None or "type" not in data or "ref" not in data or data["type"] not in {"file", "chunk"}:  # pyright: ignore[reportUnnecessaryComparison]
                 continue
-            path, interval = split_intervals_from_path(data["ref"])  # pyright: ignore
+            path, interval = split_intervals_from_path(data["ref"])
             intervals = parse_intervals(interval)
             if not intervals:
                 all_features.append(CodeFeature(cwd / path))
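Spice's count_tokens distinguishes raw text from chat-message content via is_message, which presumably folds in per-message formatting overhead much like the full_message flag it replaces. A hedged sketch of the two modes (the Spice() construction is illustrative; mentat reuses the session's client):

from spice import Spice

spice = Spice()  # illustrative
text = "def add(a, b):\n    return a + b\n"
raw = spice.count_tokens(text, "gpt-4-1106-preview", is_message=False)
msg = spice.count_tokens(text, "gpt-4-1106-preview", is_message=True)
print(raw, msg)  # msg >= raw once message overhead is counted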
mentat/code_feature.py (7 changes: 4 additions & 3 deletions)
@@ -9,7 +9,6 @@
 
 from mentat.errors import MentatError
 from mentat.interval import INTERVAL_FILE_END, Interval
-from mentat.llm_api_handler import count_tokens
 from mentat.session_context import SESSION_CONTEXT
 from mentat.utils import get_relative_path
 
@@ -61,10 +60,12 @@ def __str__(self, cwd: Optional[Path] = None) -> str:
 
 
 def count_feature_tokens(feature: CodeFeature, model: str) -> int:
-    cwd = SESSION_CONTEXT.get().cwd
+    ctx = SESSION_CONTEXT.get()
+
+    cwd = ctx.cwd
     ref = feature.__str__(cwd)
     document = get_document(ref, cwd)
-    return count_tokens(document, model, full_message=False)
+    return ctx.llm_api_handler.spice.count_tokens(document, model, is_message=False)
 
 
 def get_consolidated_feature_refs(features: list[CodeFeature]) -> list[str]:
mentat/code_file_manager.py (2 changes: 1 addition & 1 deletion)
@@ -90,7 +90,7 @@ async def write_changes_to_files(
 
         if file_edit.is_creation:
             if file_edit.file_path.exists():
-                raise MentatError(f"Model attempted to create file {file_edit.file_path} which" " already exists")
+                raise MentatError(f"Model attempted to create file {file_edit.file_path} which already exists")
             self.create_file(file_edit.file_path)
         elif not file_edit.file_path.exists():
             raise MentatError(f"Attempted to edit non-existent file {file_edit.file_path}")
mentat/command/command.py (2 changes: 1 addition & 1 deletion)
@@ -78,7 +78,7 @@ async def apply(self, *args: str) -> None:
         stream = session_context.stream
 
         stream.send(
-            f"{self.invalid_name} is not a valid command. Use /help to see a list of" " all valid commands",
+            f"{self.invalid_name} is not a valid command. Use /help to see a list of all valid commands",
             style="warning",
         )
 
mentat/command/commands/agent.py (2 changes: 1 addition & 1 deletion)
@@ -33,4 +33,4 @@ def argument_autocompletions(cls, arguments: list[str], argument_position: int)
     @override
     @classmethod
     def help_message(cls) -> str:
-        return "Toggle agent mode. In agent mode Mentat will automatically make changes" " and run commands."
+        return "Toggle agent mode. In agent mode Mentat will automatically make changes and run commands."
mentat/command/commands/config.py (2 changes: 1 addition & 1 deletion)
@@ -28,7 +28,7 @@ async def apply(self, *args: str) -> None:
             value = args[1]
             if attr.fields_dict(type(config))[setting].metadata.get("no_midsession_change"):
                 stream.send(
-                    f"Cannot change {setting} mid-session. Please restart" " Mentat to change this setting.",
+                    f"Cannot change {setting} mid-session. Please restart Mentat to change this setting.",
                     style="warning",
                 )
                 return
mentat/command/commands/screenshot.py (6 changes: 3 additions & 3 deletions)
@@ -19,12 +19,12 @@ async def apply(self, *args: str) -> None:
         model = config.model
 
         if "gpt" in model:
-            if "vision" not in model:
+            if not ("vision" in model or ("gpt-4-turbo" in model and "preview" not in model)):
                 stream.send(
-                    "Using a version of gpt that doesn't support images. Changing to" " gpt-4-vision-preview",
+                    "Using a version of gpt that doesn't support images. Changing to gpt-4-turbo",
                     style="warning",
                 )
-                config.model = "gpt-4-vision-preview"
+                config.model = "gpt-4-turbo"
         else:
             stream.send(
                 "Can't determine if this model supports vision. Attempting anyway.",
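The new check treats gpt-4-turbo models (except -preview variants) as image-capable alongside explicit -vision models, and falls back to gpt-4-turbo instead of the older gpt-4-vision-preview. The predicate, extracted for clarity with example model names:

def supports_vision(model: str) -> bool:
    # Mirrors the condition above: explicit vision models, or gpt-4-turbo
    # variants that are not previews.
    return "vision" in model or ("gpt-4-turbo" in model and "preview" not in model)

print(supports_vision("gpt-4-vision-preview"))  # True
print(supports_vision("gpt-4-turbo"))           # True
print(supports_vision("gpt-4-turbo-preview"))   # False -> switched to gpt-4-turbo
print(supports_vision("gpt-4-1106-preview"))    # False -> switched to gpt-4-turbo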
mentat/command/commands/talk.py (14 changes: 9 additions & 5 deletions)
@@ -43,10 +43,10 @@ async def record(self):
         self.start_time = default_timer()
 
         self.q: queue.Queue[np.ndarray[Any, Any]] = queue.Queue()
-        with sf.SoundFile(  # pyright: ignore[reportUnboundVariable]
+        with sf.SoundFile(  # pyright: ignore[reportPossiblyUnboundVariable]
             self.file, mode="w", samplerate=RATE, channels=1
         ) as file:
-            with sd.InputStream(  # pyright: ignore[reportUnboundVariable]
+            with sd.InputStream(  # pyright: ignore[reportPossiblyUnboundVariable]
                 samplerate=RATE, channels=1, callback=self.callback
             ):
                 while not self.shutdown.is_set():
@@ -75,9 +75,13 @@ async def apply(self, *args: str) -> None:
         await recorder.record()
         ctx.stream.send("Processing audio with whisper...")
         await asyncio.sleep(0.01)
-        transcript = await ctx.llm_api_handler.call_whisper_api(recorder.file)
-        ctx.stream.send(transcript, channel="default_prompt")
-        ctx.cost_tracker.log_whisper_call_stats(recorder.recording_time)
+        response = await ctx.llm_api_handler.call_whisper_api(recorder.file)
+        ctx.stream.send(response.text, channel="default_prompt")
+        if response.cost:
+            ctx.stream.send(
+                f"Whisper audio length and cost: {response.input_length:.2f}s, ${response.cost / 100:.2f}",
+                style="info",
+            )
 
     @override
     @classmethod
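call_whisper_api now returns a response object instead of a bare transcript string; .text, .cost (in cents), and .input_length (seconds) are the attributes used above. A hedged consumption sketch, with handler standing in for ctx.llm_api_handler:

async def transcribe(handler, audio_file) -> str:
    response = await handler.call_whisper_api(audio_file)
    if response.cost:  # cost is reported in cents
        print(f"{response.input_length:.2f}s, ${response.cost / 100:.2f}")
    return response.text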