Skip to content

Commit

Permalink
fix distribution trace bug + move into interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeurerkellner committed Oct 12, 2023
1 parent 14594de commit 70936f0
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 20 deletions.
4 changes: 4 additions & 0 deletions src/lmql/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Dict, Optional, List, Union, NamedTuple, Tuple, Set
from lmql.runtime.postprocessing.conditional_prob import ConditionalDistributionPostprocessor
import numpy as np
import warnings

Expand Down Expand Up @@ -1126,6 +1127,9 @@ async def debug_out(decoder_step):
# set decoder step +1, for all stats logging that happens in postprocessing
self.decoder_step += 1

# applies distribution postprocessor if required
results = await (ConditionalDistributionPostprocessor(self).process(results))

# check if a certificate was requested
if self.certificate != False:
active_tracer().event("lmql.LMQLResult", results, skip_none=True)
Expand Down
11 changes: 1 addition & 10 deletions src/lmql/runtime/lmql_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
from typing import Any, Dict, Optional

from lmql.ops.ops import *
from lmql.runtime.context import Context
from lmql.runtime.langchain import chain, call_sync
from lmql.runtime.output_writer import silent
from lmql.runtime.postprocessing.conditional_prob import \
ConditionalDistributionPostprocessor
from lmql.runtime.postprocessing.group_by import GroupByPostprocessor
from lmql.api.inspect import is_query
from lmql.runtime.formatting import format, tag
Expand Down Expand Up @@ -232,14 +231,6 @@ async def __acall__(self, *args, **kwargs):
finally:
if PromptInterpreter.main == interpreter:
PromptInterpreter.main = None

# applies distribution postprocessor if required
results = await (ConditionalDistributionPostprocessor(interpreter).process(results))

# apply remaining postprocessors
if self.postprocessors is not None:
for postprocessor in self.postprocessors:
results = await postprocessor.process(results, self.output_writer)

interpreter.print_stats()
interpreter.dcmodel.close()
Expand Down
11 changes: 2 additions & 9 deletions src/lmql/runtime/postprocessing/conditional_prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,10 @@ async def score(self, prompt: str, values, dcmodel: dc.DcModel):

async def process(self, results):
model: dc.DcModel = self.interpreter.dcmodel
# optional unpacker for singular results
unpack = lambda v: v

# unpack singular results after processing
if type(results) is not list:
results = [results]
unpack = lambda v: v[0]

# check if distribution is required
if not any(r is not None and hasattr(r, "distribution_variable") and r.distribution_variable is not None for r in results):
return unpack(results)
return results

if len(results) > 1:
if "top1_distribution" in self.interpreter.decoder_kwargs and self.interpreter.decoder_kwargs["top1_distribution"]:
Expand Down Expand Up @@ -77,4 +70,4 @@ async def process(self, results):
result.variables[f"P({distribution_variable})"] = [(value, prob) for value, prob, _ in distribution]
result.variables[f"log P({distribution_variable})"] = [(value, prob) for value, prob, _ in log_distribution]

return unpack(results[0])
return results
4 changes: 3 additions & 1 deletion src/lmql/runtime/tracing/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ def active_tracer() -> Tracer:
different tracers in sub-queries.
"""
_ensure_tracer()
assert len(_tracer.get()) > 0, "No tracer set in this context"
if len(_tracer.get()) == 0:
warnings.warn("An LMQL tracer was requested in a context without active tracer. This indicates that some internal LLM calls may not be traced correctly.")
return NullTracer("null")
return _tracer.get()[-1]

def set_tracer(tracer):
Expand Down
1 change: 1 addition & 0 deletions src/lmql/tests/test_sample_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ async def main():
print(error_buffer.getvalue())
print(e)
print(termcolor.colored("[FAIL]", "red"), f"({time.time() - s:.2f}s)")
sys.exit(1)


if __name__ == "__main__":
Expand Down

0 comments on commit 70936f0

Please sign in to comment.