Merge branch 'lve'
lbeurerkellner committed Oct 12, 2023
2 parents 2882d74 + 95f1c97 commit 14594de
Showing 22 changed files with 299 additions and 128 deletions.
8 changes: 5 additions & 3 deletions docs/docs/lib/inference-certificates.md
@@ -40,7 +40,7 @@ say_hello(certificate="my-certificate.json")
"Authorization": "<removed>",
"Content-Type": "application/json"
},
"tokenier": "<LMQLTokenizer 'gpt-3.5-turbo-instruct' using tiktoken <Encoding 'cl100k_base'>>",
"tokenizer": "<LMQLTokenizer 'gpt-3.5-turbo-instruct' using tiktoken <Encoding 'cl100k_base'>>",
"kwargs": {
"model": "gpt-3.5-turbo-instruct",
"prompt": [
@@ -117,8 +117,10 @@ with lmql.traced("my-context") as t:

This produces one certificate for all calls made in the defined context, where each query is represented as a separate item in the list of `children` certificates. Recorded events are nested in child certificates. Additionally, an aggregated `metrics` object ranging over all (recursive) calls is included in the top-level certificate.

## Certificate Callbacks
## Certificate Callbacks And Return Values

As an alternative to directly writing certificates to a file, the `certificate` argument can also be used to specify a callback function that is called with the generated certificate as an argument. This can be used to integrate certificate generation into custom workflows.
As an alternative to directly writing certificates to a file, certificates can also be handled via a callback or returned as a function return value.

To register a callback function that is called with the generated certificate as an argument, pass it as the `certificate=<FCT>` argument.

The callback is provided with a single `certificate` object, which is of type `lmql.InferenceCertificate`. The certificate can be directly serialized to JSON using string conversion, i.e., `str(certificate)`.
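
For illustration, a minimal sketch of the callback pattern described above (the query function `say_hello` is taken from the earlier example in this document; the output file name is arbitrary):

def on_certificate(certificate):
    # 'certificate' is an lmql.InferenceCertificate; string conversion yields its JSON form
    with open("my-certificate.json", "w") as f:
        f.write(str(certificate))

say_hello(certificate=on_certificate)
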
36 changes: 5 additions & 31 deletions scripts/serve-all.py
@@ -30,40 +30,14 @@

summary = ""

summary += "Website on " + termcolor.colored("http://localhost:8080/\n\n", "green")
serve_web = subprocess.Popen(["python", "-m", "http.server", "8080"], cwd="web", stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
processes.append(serve_web)

summary += "Browser Playground on " + termcolor.colored("http://localhost:8081/playground/\n\n", "green")
serve_web_deploy = subprocess.Popen(["python", "-m", "http.server", "8081"], cwd="web-deploy", stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
processes.append(serve_web_deploy)

# autobuild sphinx docs
# onchange "**/*.py" "**/*.md" "**/*.rst" "**/*.css" "**/*.js" "**/*.ipynb" -e "build" -- make html
autobuild_sphinx_p = subprocess.Popen(["onchange", "**/*.py", "**/*.md", "**/*.rst", "**/*.css", "**/*.js", "**/*.ipynb", "-e", "build", "--", "make", "html"], cwd="docs")
processes.append(autobuild_sphinx_p)
# serve docs/build/html on http://localhost:8081/
summary += "Docs on " + termcolor.colored("http://localhost:8082/\n\n", "green")
# if docs are empty, build with make html
if not os.path.exists("docs/build/html/index.html"):
subprocess.run(["make", "html"], cwd="docs")
serve_docs = subprocess.Popen(["python", "-m", "http.server", "8082"], cwd="docs/build/html", stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
processes.append(serve_docs)

# autobuild blog
# onchange ../../docs/build/html/blog/*.html -- node generate.js
summary += "Blog on " + termcolor.colored("http://localhost:8080/blog\n\n", "green")
autobuild_blog_p = subprocess.Popen(["onchange", "../../docs/build/html/blog/*.html", "--", "node", "generate.js"], cwd="web/blog")
processes.append(autobuild_blog_p)

# autobuild web/
# onchange "index.template.html" "**/*.js" "**/*.css" "**/*.md" -e "./index.html" -- node generate.js
autobuild_web_p = subprocess.Popen(["onchange", "index.template.html", "**/*.js", "**/*.css", "**/*.md", "-e", "./index.html", "--", "node", "generate.js"], cwd="web")
processes.append(autobuild_web_p)

# autobuild web/actions
auto_build_actions_p = subprocess.Popen(["onchange", "**/*.js", "**/*.css", "**/*.md", "**/*.html", "**/*.pd", "**/*.json", "-e", "./index.html", "--", "node", "generate.js"], cwd="web/actions")
processes.append(auto_build_actions_p)
# autobuild web
summary += "Web on " + termcolor.colored("http://localhost:5173/\n", "green")
auto_build_docs = subprocess.Popen(["yarn", "run", "docs:dev"], cwd="docs")
processes.append(auto_build_docs)

while True:
try:
@@ -75,7 +49,7 @@
elif command == "bb": # browser build
print(termcolor.colored("Building browser playground...", "yellow"))
# run bash deploy.sh in web/
subprocess.run(["bash", "deploy.sh"], cwd="web")
subprocess.run(["bash", "deploy-web.sh"], cwd="scripts/")
print(termcolor.colored("Done!", "green"))
elif command == "docs-clean":
# docs/ make clean
4 changes: 2 additions & 2 deletions src/lmql/api/llm.py
@@ -105,7 +105,7 @@ async def generate(self, prompt, max_tokens=None, **kwargs):
kwargs["chunksize"] = max_tokens
max_tokens = max_tokens + 1

name = "lmql.generate({}, {}, **{})".format(prompt, max_tokens, kwargs)
name = "lmql.generate".format("...", max_tokens, kwargs)
result = await generate_query(prompt, max_tokens=max_tokens, __name__=name, **kwargs)

if len(result) == 0:
@@ -312,7 +312,7 @@ async def generate_query(prompt, max_tokens=32):
"{prompt}[RESPONSE]" where len(TOKENS(RESPONSE)) < max_tokens
else:
"{prompt}[RESPONSE]"
return context.prompt
return RESPONSE
'''

def model(model_identifier, **kwargs) -> LLM:
2 changes: 1 addition & 1 deletion src/lmql/language/compiler.py
@@ -202,7 +202,7 @@ def transform_qexpr(qexpr):
warnings.warn(f"Warning: distribution variable {qexpr.name} is decorated with {qexpr.decorator_exprs}, but decorators are ignored for distribution variables.")
return DistributionVariable(qexpr.name)
if type(qexpr) is FExpression:
return FExpression(f"lmql.lmql_runtime.f_escape({qexpr.expr})")
return FExpression(f"lmql.lmql_runtime.format({qexpr.expr})")
elif type(qexpr) is TagExpression:
return TagExpression(f"lmql.lmql_runtime.tag(\"{qexpr.tag[1:]}\")")
elif type(qexpr) is TemplateVariable:
21 changes: 9 additions & 12 deletions src/lmql/models/lmtp/backends/random_model.py
@@ -7,7 +7,7 @@

class UniformRandomSamplingLLM(LMTPModel):
def __init__(self, seed=None, vocab=None, **kwargs):
self.seed = seed
self.seed = seed or 0
self.kwargs = kwargs

if kwargs.get("verbose", False):
@@ -37,24 +37,18 @@ def eos_token_id(self):
def vocab_size(self):
return self._vocab_size

# def score(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, **model_kwargs) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
# return super().score(input_ids, attention_mask, **model_kwargs)

def generate(self, input_ids, attention_mask,
temperature: float, max_new_tokens: int,
bias_tensor, streamer: TokenStreamer, **kwargs) -> LMTPModelResult:
if self.seed is not None:
seed = input_ids.sum() + self.seed
rng = np.random.RandomState(seed)
else:
rng = np.random.RandomState()

scores = []

if bias_tensor is not None:
bias_tensor = self.make_bias_tensor(bias_tensor, self.vocab_size)

for i in range(max_new_tokens):
seed = input_ids.sum() + self.seed
rng = np.random.RandomState(seed)

logits = np.zeros([len(input_ids), self.vocab_size])

if bias_tensor is not None:
@@ -65,12 +59,15 @@ def generate(self, input_ids, attention_mask,

next_ids = np.array([rng.choice(logits.shape[-1], size=1, p=probs[i]) for i in range(len(probs))]).reshape(-1,1)

for i,j in enumerate(next_ids):
logits[i, j.item()] += 1e-2
for k,j in enumerate(next_ids):
logits[k, j.item()] += 1e-2

scores += [nputil.log_softmax(logits, axis=-1)]
input_ids = np.concatenate([input_ids, next_ids], axis=-1)

if i+1 >= max_new_tokens:
break

streamer(input_ids, scores)

return LMTPModelResult(sequences=input_ids, scores=scores)
20 changes: 14 additions & 6 deletions src/lmql/models/lmtp/lmtp_dcmodel.py
@@ -229,18 +229,20 @@ async def generate(self, s, temperature, top_logprobs = 1, chunk_size=None, **kw
chunk_size = self.model.chunk_size
kwargs = {**self.model_args, **kwargs}

# get token masks from interpreter
constrained_seqs = np.array([s.is_query_constrained], dtype=np.bool_)
logits_mask_result = await self.compute_logits_mask(s.input_ids.reshape(1, -1), [s.user_data], constrained_seqs, [s], **kwargs)

mask = logits_mask_result.logits_mask[0]

assert kwargs.get("num_samples", 1) == 1, "LMTP does not support num_samples > 1 right now. Please, duplicate your dc.seq to obtain multiple sampled continuations."

# merge interpreter user data with previous/decoder data
if s.user_data is None:
s.user_data = {}
s.user_data = dc.deepmerge(dc.deepcopy(s.user_data), logits_mask_result.user_data[0])
s.user_data["set_by"] = "where"

# convert token mask to LMTP format
if mask is not None:
num_allowed = masks.mask_num_allowed(mask)
if num_allowed == 1:
@@ -255,25 +257,31 @@ async def generate(self, s, temperature, top_logprobs = 1, chunk_size=None, **kw
mask_value = 100 if invert else -100
mask = {int(idx): mask_value for idx in np.nonzero(masked)[0]}

# convert seq to input IDs
ids = self.tokenizer.convert_bytes_to_ids(s.input_ids)

if len(ids) > 0 and self.tokenizer.bos_token_id is not None and ids[0] != self.tokenizer.bos_token_id:
if len(ids) == 0 or (len(ids) > 0 and self.tokenizer.bos_token_id is not None and ids[0] != self.tokenizer.bos_token_id):
ids = [self.tokenizer.bos_token_id] + ids

# derive max_tokens
max_tokens = logits_mask_result.max_tokens_hints[0] or chunk_size
# a hint of '-1' means generation is not limited; fall back to a default of 128
if max_tokens == -1: max_tokens = 128

if self.verbose:
text = await self.detokenize(ids)
print("lmtp generate: {} / {} ({} tokens)".format(ids, str([text])[1:-1], len(ids)))
print("lmtp generate: {} / {} ({} tokens, temperature={}, max_tokens={})".format(ids, str([text])[1:-1], len(ids), temperature, max_tokens))

# get token stream
token_stream = self.client.generate(ids, max_tokens=chunk_size, temperature=temperature, logit_bias=mask, top_logprobs=top_logprobs, **self.extra_decoding_parameters)
token_stream = self.client.generate(ids, max_tokens=max_tokens, temperature=temperature, logit_bias=mask, top_logprobs=top_logprobs, **self.extra_decoding_parameters)

if active_tracer().active:
stream_event = active_tracer().event("lmtp.generate", {
"model": await self.model_info(),
"tokenizer": str(self.tokenizer),
"kwargs": {
"ids": ids,
"max_tokens": chunk_size,
"max_tokens": max_tokens,
"temperature": temperature,
**({"logit_bias": mask} if mask is not None else {}),
"top_logprobs": top_logprobs,
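
To spell out the max_tokens derivation above as a standalone sketch (the helper name and example values are hypothetical, not part of the commit):

def derive_max_tokens(hint, chunk_size):
    # mirrors the logic above: a falsy hint (0/None) falls back to chunk_size,
    # -1 means "not limited" and is capped at a default of 128,
    # a positive hint is used directly as the request's max_tokens
    max_tokens = hint or chunk_size
    if max_tokens == -1:
        max_tokens = 128
    return max_tokens

derive_max_tokens(0, 32)   # -> 32
derive_max_tokens(-1, 32)  # -> 128
derive_max_tokens(19, 32)  # -> 19
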
6 changes: 5 additions & 1 deletion src/lmql/ops/booleans.py
@@ -1,5 +1,5 @@
from .node import *
from .booleans import *
from lmql.ops.max_token_hints import *

class NotOp(Node):
def forward(self, op, **kwargs):
@@ -67,3 +67,7 @@ def all(*args):
return args[0]
else:
return AndOp([AndOp.all(*args[:-1]), args[-1]])

def token_hint(self):
operand_hints = [op.token_hint() for op in self.predecessors]
return dict_min_token_hint(operand_hints)
39 changes: 39 additions & 0 deletions src/lmql/ops/max_token_hints.py
@@ -0,0 +1,39 @@
"""
Utility functions for working with max_tokens hints
as produced by LMQL constraint operations.
"""
from typing import List, Dict

TokenHint = Dict[str, int]

def dict_min_token_hint(hints: List[TokenHint]) -> TokenHint:
"""
Takes the element-wise minimum of the given token hints.
"""
if len(hints) == 0:
return {}
elif len(hints) == 1:
return hints[0]
else:
merged = {}
for h in hints:
for k,v in h.items():
if v != 0:
merged[k] = min(merged.get(k, v), v)
return merged

def concrete_hints(hints: dict):
"""
Returns the concrete (non-zero) hints from the given token hint dict.
"""
return {k:v for k,v in hints.items() if v != 0}

def most_restrictive_hint(hints: List[int]):
"""
Returns the most restrictive (smallest non-zero) hint from the given list of hint values, or 0 if no concrete hint is present.
"""
concrete = [h for h in hints if h != 0]
if len(concrete) == 0:
return 0
else:
return min(concrete)
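
As a usage sketch of the helpers above (variable names are made up for illustration):

dict_min_token_hint([{"RESPONSE": 19}, {}, {"RESPONSE": 32, "SUMMARY": 5}])
# -> {"RESPONSE": 19, "SUMMARY": 5}: per-variable minimum over all non-zero hints

most_restrictive_hint([0, 19, 32])
# -> 19: the smallest non-zero value; 0 entries carry no information
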
20 changes: 19 additions & 1 deletion src/lmql/ops/node.py
@@ -28,6 +28,13 @@ def final(self, args, **kwargs):
def __nodelabel__(self):
return str(type(self))

def token_hint(self):
"""
Hint for the number of tokens that need to be generated until
this operation forces EOS (0 for no hint, -1 for no constraints, >0 for a specific number)
"""
return {}

def postprocess_var(self, var_name):
"""
Returns true if this operations provides postprocessing semantics for complete values for the given variable name.
@@ -252,9 +259,20 @@ def strip_next_token(x):
x = x[:-len(NextToken)]
return x

def token_hint(op, variable_name):
if op is None:
# unconstrained generation
return -1

if is_node(op):
hints = op.token_hint()
return hints.get(variable_name, 0)

return 0 # no hint (no information about no. of tokens to allow)

class postprocessed_value:
def __init__(self, value):
self.value = value
class postprocessed_rewrite:
def __init__(self, rewrite):
self.rewrite = rewrite
self.rewrite = rewrite
56 changes: 56 additions & 0 deletions src/lmql/ops/ops.py
@@ -8,6 +8,7 @@

from lmql.ops.inline_call import InlineCallOp
from lmql.ops.booleans import *
from lmql.ops.max_token_hints import *
from lmql.ops.regex import Regex
from lmql.models.model_info import model_info

@@ -79,6 +80,11 @@ def final(self, *args, **kwargs):
return None
return self.get_handler(kwargs.get("operands")).final(*args, **kwargs)

def token_hint(self):
if self.delegate is not None:
return self.delegate.token_hint()
return super().token_hint()

def __str__(self):
if self.delegate is not None:
return str(self.delegate)
@@ -373,6 +379,33 @@ def final(self, ops, operands=None, result=None, **kwargs):
r = transition_table[op1][op2]

return r

def token_hint(self):
"""
Checks if this Lt operation imposes an upper limit on the
number of tokens that can be generated.
"""
num = [n for n in self.predecessors if type(n) is int]
len_op = [n for n in self.predecessors if isinstance(n, LenOp)]
if len(num) != 1 or len(len_op) != 1:
return super().token_hint()

limit = num[0]

# if limit is not rhs, it is a lower bound, not a limit
if limit != self.predecessors[1]:
return super().token_hint()

tokens_op = [n for n in len_op[0].predecessors if isinstance(n, TokensOp)]
if len(tokens_op) != 1:
return super().token_hint()

var = [n for n in tokens_op[0].predecessors if isinstance(n, Var)]

if len(var) != 1:
return super().token_hint()

return {var[0].name: limit - 1}

def Gt(preds): return Lt(list(reversed(preds)))

@@ -488,6 +521,29 @@ def final(self, operand_final, operands=None, result=None, **kwargs):
return "fin"

return "var"

def token_hint(self):
"""
Checks if this Eq operation imposes an exact limit on the
number of tokens that can be generated.
"""
num = [n for n in self.predecessors if type(n) is int]
len_op = [n for n in self.predecessors if isinstance(n, LenOp)]
if len(num) != 1 or len(len_op) != 1:
return super().token_hint()

limit = num[0]

tokens_op = [n for n in len_op[0].predecessors if isinstance(n, TokensOp)]
if len(tokens_op) != 1:
return super().token_hint()

var = [n for n in tokens_op[0].predecessors if isinstance(n, Var)]

if len(var) != 1:
return super().token_hint()

return {var[0].name: limit}

EqOp = DynamicTypeDispatch("EqOp", (
((int, int), EqOpInt),
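
At the query level, these hints come from length constraints over TOKENS. A hypothetical example (not part of this commit, using the decorator-based query syntax):

import lmql

@lmql.query
def capped(prompt):
    '''lmql
    "{prompt}[RESPONSE]" where len(TOKENS(RESPONSE)) < 20
    return RESPONSE
    '''
# the Lt constraint hints {"RESPONSE": 19} (limit - 1); an == 20 constraint
# would hint exactly 20; the LMTP backend can then request that many tokens
# instead of a full chunk_size chunk
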