Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(closes #31) adds a base sampler protocol, reorganizes samplers and main #40

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
(closes #31) adds a base sampler protocol, reorganizes samplers and main
  • Loading branch information
qdbp committed Oct 8, 2024
commit c601285638100b9a436f1418b172b64c2e6a8435
91 changes: 62 additions & 29 deletions entropix/main.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, Generic

import jax
import jax.numpy as jnp
import tyro

from entropix.config import LLAMA_1B_PARAMS
from entropix.config import LLAMA_1B_PARAMS, ModelParams
from entropix.kvcache import KVCache
from entropix.model import xfmr
from entropix.sampler import SamplerConfig, sample
from entropix.prompts import create_prompts_from_csv, prompt
from entropix.sampler import sample
from entropix.samplers import ST, Cfg_contra, EntropySampler
from entropix.samplers.baseline_sampler import SamplerConfig as BaselineSamplerConfig
from entropix.samplers.baseline_sampler import sample as baseline_sampler
from entropix.tokenizer import Tokenizer
from entropix.weights import load_weights
from entropix.weights import XfmrWeights, load_weights

DEFAULT_WEIGHTS_PATH = Path(__file__).parent / "../weights"

DEFAULT_WEIGHTS_PATH = Path(__file__).parent / '../weights'

def apply_scaling(freqs: jax.Array):
SCALE_FACTOR = 8
Expand All @@ -36,13 +40,15 @@ def scale_mid(_):
wavelen < high_freq_wavelen,
lambda _: freq,
lambda _: jax.lax.cond(wavelen > low_freq_wavelen, lambda _: freq / SCALE_FACTOR, scale_mid, None),
None
None,
)

return jax.vmap(scale_freq)(freqs)


def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled: bool = False, dtype: jnp.dtype = jnp.float32) -> jax.Array:
def precompute_freqs_cis(
dim: int, end: int, theta: float = 500000.0, use_scaled: bool = False, dtype: jnp.dtype = jnp.float32
) -> jax.Array:
freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
if use_scaled:
freqs = apply_scaling(freqs)
Expand All @@ -54,55 +60,82 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled
def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array:
mask = jnp.zeros((seqlen, seqlen), dtype=jnp.float32)
if seqlen > 1:
mask = jnp.full((seqlen, seqlen), float('-inf'))
mask = jnp.full((seqlen, seqlen), float("-inf"))
mask = jnp.triu(mask, k=1)
mask = jnp.hstack([jnp.zeros((seqlen, start_pos)), mask], dtype=jnp.float32)
return mask


def main(weights_path: Path = DEFAULT_WEIGHTS_PATH.joinpath('1B-Instruct')):
model_params = LLAMA_1B_PARAMS
xfmr_weights = load_weights(weights_path.absolute())
tokenizer = Tokenizer('entropix/tokenizer.model')
# Create the batch of tokens
@dataclass(kw_only=True)
class TokenGenerator(Generic[Cfg_contra, ST]):
weights: XfmrWeights
model_params: ModelParams
tokenizer: Tokenizer
sampler: EntropySampler[Cfg_contra, ST]
sampler_cfg: Cfg_contra

# Create the batch of tokens
def generate(xfmr_weights, model_params, tokens):
def generate_from_prompt(self, init_tokens) -> Generator[str, None, None]:
gen_tokens = None
cur_pos = 0
tokens = jnp.array([tokens], jnp.int32)
tokens = jnp.array([init_tokens], jnp.int32)
bsz, seqlen = tokens.shape
attn_mask = build_attn_mask(seqlen, cur_pos)
freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope)
kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim)
logits, kvcache, _, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
mp = self.model_params
freqs_cis = precompute_freqs_cis(mp.head_dim, mp.max_seq_len, mp.rope_theta, mp.use_scaled_rope)
kvcache = KVCache.new(mp.n_layers, bsz, mp.max_seq_len, mp.n_local_kv_heads, mp.head_dim)
logits, kvcache, _, _ = xfmr(self.weights, mp, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)
next_token = jnp.argmax(logits[:, -1], axis=-1, keepdims=True).astype(jnp.int32)
gen_tokens = next_token
print(tokenizer.decode([next_token.item()]), end='', flush=True)

yield self.tokenizer.decode([next_token.item()])

cur_pos = seqlen
stop = jnp.array([128001, 128008, 128009])
sampler_cfg = SamplerConfig()
state: ST | None = None
while cur_pos < 8192:
cur_pos += 1
logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache)
next_token = sample(gen_tokens, logits, scores, cfg=sampler_cfg)
logits, kvcache, scores, _ = xfmr(
self.weights, mp, next_token, cur_pos, freqs_cis[cur_pos : cur_pos + 1], kvcache
)
next_token, state = self.sampler(gen_tokens, logits, scores, cfg=self.sampler_cfg, state=state)
gen_tokens = jnp.concatenate((gen_tokens, next_token))
print(tokenizer.decode(next_token.tolist()[0]), end='', flush=True)
yield self.tokenizer.decode(next_token.tolist()[0])
if jnp.isin(next_token, stop).any():
break

csv_path = Path('entropix/data/prompts.csv')

def main(weights_path: Path = DEFAULT_WEIGHTS_PATH.joinpath("1B-Instruct")):
model_params = LLAMA_1B_PARAMS
xfmr_weights = load_weights(weights_path.absolute())
# TODO(qdbp) make tokenizer into arg as well
tokenizer = Tokenizer("entropix/tokenizer.model")

csv_path = Path("entropix/data/prompts.csv")
prompts = create_prompts_from_csv(csv_path)
PROMPT_TEST = False

# TODO(qdbp) make these configurable once more are implemented
sampler = baseline_sampler
sampler_cfg = BaselineSamplerConfig()

generator = TokenGenerator(
weights=xfmr_weights, model_params=model_params, tokenizer=tokenizer, sampler=sampler, sampler_cfg=sampler_cfg
)

if PROMPT_TEST:
for p in prompts:
print(p)
tokens = tokenizer.encode(p, bos=False, eos=False, allowed_special='all')
generate(xfmr_weights, model_params, tokens)
tokens = tokenizer.encode(p, bos=False, eos=False, allowed_special="all")
for token in generator.generate_from_prompt(tokens):
print(token, end="", flush=True)

else:
print(prompt)
tokens = tokenizer.encode(prompt, bos=False, eos=False, allowed_special='all')
generate(xfmr_weights, model_params, tokens)
tokens = tokenizer.encode(prompt, bos=False, eos=False, allowed_special="all")
for token in generator.generate_from_prompt(tokens):
print(token, end="", flush=True)


if __name__ == '__main__':
if __name__ == "__main__":
tyro.cli(main)
189 changes: 0 additions & 189 deletions entropix/sampler.py

This file was deleted.

44 changes: 44 additions & 0 deletions entropix/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Protocol, TypeVar

import jax

# TODO(qdbp): these type vars would be much cleaner if we bumped the minimum
# Python version to 3.12 and used the new generics (type-parameter) syntax,
# which also infers variance automatically.

# Type of the sampler's configuration object. In EntropySampler it is only
# ever passed *into* the sampler (the `cfg` argument), hence contravariant.
Cfg_contra = TypeVar("Cfg_contra", contravariant=True)  # input only -> contravariant

# Type of the sampler's state. In EntropySampler it is both accepted as an
# argument and returned, so it has to stay invariant.
ST = TypeVar("ST")  # i/o -> invariant


class EntropySampler(Protocol[Cfg_contra, ST]):
    """
    Structural interface for a sampler: any object that can be called to
    perform a single sampling step (see EntropySampler.__call__).

    Because this is a Protocol, no subclassing is required; plain functions
    with a matching signature count.

    Type parameters:
        Cfg_contra: type of the configuration object the sampler accepts.
        ST: type of the state the sampler accepts and returns.
    """

    def __call__(
        self,
        gen_tokens: jax.Array,
        logits: jax.Array,
        attention_scores: jax.Array,
        *,
        cfg: Cfg_contra,
        state: ST | None = None,
        # NOTE(review): this default is evaluated once, when the class body is
        # executed, so a caller that omits `key` passes the same fixed key on
        # every step. Implementations presumably have to derive fresh
        # randomness themselves -- confirm against the concrete samplers.
        key: jax.Array = jax.random.PRNGKey(1337),
    ) -> tuple[jax.Array, ST]:
        """
        Performs a single sampling step.

        Args:
            gen_tokens: Array of the current token context.
            logits: Array of next-token logits predicted by the model.
            attention_scores: Array of attention scores as returned by xfmr.
            cfg: sampler-specific configuration object encapsulating any other sampling parameters.
            state: sampler state from the previous step, or None when there is
                no previous state (e.g. on the first step).
            key: JAX PRNG key; defaults to a fixed key.

        Returns:
            A tuple ``(next_token, state)``: the next token as a jax.Array, and
            the sampler state, intended to be passed back in as ``state`` on
            the following step.
        """
        ...
Loading