Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: "3.12"
- run: uv sync --group dev
- run: uv sync --group dev --extra hf-tokenizers
- run: uv run ruff check .
- run: uv run mypy .
- run: uv run mypy .
15 changes: 15 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# CI workflow: run the unit test suite on every pull request.
name: Tests
on: [pull_request]
jobs:
unit_tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
# Install uv so dependency sync and the test run use the same tool as local dev.
- uses: astral-sh/setup-uv@v4
with:
version: "latest"
- uses: actions/setup-python@v4
with:
python-version: "3.12"
# Sync dev group plus the hf-tokenizers extra so optional-tokenizer tests can import.
- run: uv sync --group dev --extra hf-tokenizers
- run: uv run pytest tests
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ __pycache__
*.pyc
experiments
solutions
scratch_gpt.yaml
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ repo is educational, so the aim is to keep the code as legible as possible.

[x] Switch to uv
[x] Make it easy to modify with a config file
[] Extract the loss calculation from the model
[] Rename main to train
[x] Extract the loss calculation from the model
[x] Rename main to train
[x] Create or check tokenizer interface
[] Create an easy to use interface
[] Create or check tokenizer interface
[] Make it into a package
[] Apply SOTA optimizations

Expand Down Expand Up @@ -80,7 +80,7 @@ uv run tiktoken

## Project Structure

- `scratchgpt/main.py`: Main training script
- `scratchgpt/train.py`: Main training script
- `scratchgpt/infer.py`: Inference script for text generation
- `scratchgpt/model_io.py`: Utilities for saving and loading models
- `scratchgpt/tokenizer/`: Tokenizer implementations
Expand Down
13 changes: 10 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "scratchgpt"
version = "0.2.0"
version = "0.3.0"
description = "ScratchGPT: an educational GPT language model implemented from scratch"
authors = [
{ name = "Aleksandr Yeganov", email = "ayeganov@gmail.com"},
Expand All @@ -19,6 +19,13 @@ dependencies = [
"types-tqdm>=4.67.0.20250809",
]

[project.optional-dependencies]
hf-tokenizers = [
"tokenizers>=0.19.0",
"huggingface-hub>=0.34.4",
]


[dependency-groups]
dev = [
"bandit>=1.8.6",
Expand All @@ -42,7 +49,7 @@ strict = true
exclude = [".venv"]

[[tool.mypy.overrides]]
module = ["ptflops"]
module = ["ptflops", "tokenizers.*", "huggingface_hub.*"]
ignore_missing_imports = true

[tool.ruff]
Expand Down Expand Up @@ -72,6 +79,6 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[project.scripts]
train = "scratchgpt.main:main"
train = "scratchgpt.train:main"
infer = "scratchgpt.infer:main"
tiktoken = "scratchgpt.tokenizer.tiktoken:main"
21 changes: 20 additions & 1 deletion scratchgpt/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pydantic import Field
import math
from typing import Annotated, Literal

from pydantic import AfterValidator, Field
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
Expand All @@ -7,6 +10,20 @@
)


def ensure_split_is_valid(v: tuple[float, float, float]) -> tuple[float, float, float]:
    """
    Validate that the (train, validation, test) split fractions sum to 1.0.

    Args:
        v: A 3-tuple of split fractions.

    Returns:
        The tuple unchanged when valid.

    Raises:
        ValueError: If the fractions do not sum to 1.0.
    """
    splits_sum = sum(v)
    # isclose guards against float representation error, e.g. 0.8 + 0.1 + 0.1.
    if not math.isclose(splits_sum, 1.0):
        # Include the offending values so the config error is actionable.
        raise ValueError(f"Invalid data 'split': {v} sums to {splits_sum}, expected 1.0")
    return v


# Pydantic-validated (train, validation, test) split; AfterValidator enforces sum == 1.0.
SplitType = Annotated[tuple[float, float, float], AfterValidator(ensure_split_is_valid)]


class ScratchGPTArchitecture(BaseSettings):
"""
All settings for training the model.
Expand Down Expand Up @@ -35,6 +52,8 @@ class ScratchGPTTraining(BaseSettings):
batch_size: int = 32
dropout_rate: float = 0.2
random_seed: int = 1337
device: Literal["cuda", "cpu"] = "cuda"
splits: SplitType = (0.8, 0.1, 0.1)

model_config = SettingsConfigDict(
env_prefix="TRAINING_",
Expand Down
22 changes: 2 additions & 20 deletions scratchgpt/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal, override
from typing import override

import numpy as np
import torch
Expand All @@ -10,8 +10,6 @@

from .tokenizer.base_tokenizer import Tokenizer

DEFAULT_DTYPE = np.dtype(np.uint16)


class TextProvider(ABC):
@abstractmethod
Expand Down Expand Up @@ -64,28 +62,12 @@ def __init__(
text_provider: TextProvider,
tokenizer: Tokenizer,
block_size: int,
split: Literal["train", "validation", "test"],
train_ratio: float = 0.8,
val_ratio: float = 0.1,
) -> None:
self.tokenizer = tokenizer
self.block_size = block_size

self.data = torch.tensor(self.tokenizer.encode(text_provider.get_text()), dtype=torch.long)

total_size = len(self.data)
train_size = int(total_size * train_ratio)
val_size = int(total_size * val_ratio)

if split == "train":
self.data = self.data[:train_size]
elif split == "validation":
self.data = self.data[train_size : train_size + val_size]
elif split == "test":
self.data = self.data[train_size + val_size :]
else:
raise ValueError(f"Invalid split: {split}. Must be 'train', 'validation', or 'test'.")

def __len__(self) -> int:
    # Number of valid window start positions: each sample needs block_size
    # tokens (plus room for the shifted target) after its start index.
    return len(self.data) - self.block_size

Expand All @@ -100,7 +82,7 @@ def __init__(
self,
token_file: Path,
block_size: int,
dtype: np.dtype = DEFAULT_DTYPE,
dtype: np.dtype,
) -> None:
super().__init__()
self.block_size = block_size
Expand Down
58 changes: 43 additions & 15 deletions scratchgpt/model_io.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
import os
import pickle
from pathlib import Path

import torch

from scratchgpt.model.model import TransformerLanguageModel

from .tokenizer.base_tokenizer import Tokenizer
from .tokenizer.tiktoken import TiktokenWrapper
from scratchgpt.tokenizer import char_tokenizer, hf_tokenizer # noqa
from scratchgpt.tokenizer.base_tokenizer import TOKENIZER_REGISTRY, SerializableTokenizer, Tokenizer
from scratchgpt.tokenizer.tiktoken import TiktokenWrapper


class ModelLoadFailedError(Exception):
Expand All @@ -23,7 +23,7 @@ def get_latest_model_weights_path(exp_folder: Path) -> Path:


def get_tokenizer_path(exp_folder: Path) -> Path:
    """Return the directory inside the experiment folder where the tokenizer is stored."""
    return exp_folder / "tokenizer"


def load_model(model_path: Path, model: TransformerLanguageModel, device: torch.device) -> TransformerLanguageModel:
Expand All @@ -41,17 +41,45 @@ def load_model(model_path: Path, model: TransformerLanguageModel, device: torch.


def get_tokenizer(exp_path: Path) -> Tokenizer:
    """
    Load a tokenizer from the experiment directory.

    Reads `tokenizer/tokenizer_config.json` to determine the tokenizer type,
    looks that type up in TOKENIZER_REGISTRY, and delegates to the registered
    class's `load` method. When no saved tokenizer exists, falls back to
    Tiktoken 'cl100k_base'.

    Args:
        exp_path: Root directory of the experiment.

    Returns:
        The restored tokenizer, or a default TiktokenWrapper.

    Raises:
        ValueError: If the config lacks a 'tokenizer_type' field or names a
            type that is not in TOKENIZER_REGISTRY.
    """
    tokenizer_dir = get_tokenizer_path(exp_path)
    config_path = tokenizer_dir / "tokenizer_config.json"

    # Guard clause: no saved state means we use the default encoding.
    if not config_path.is_file():
        print("No saved tokenizer found. Defaulting to Tiktoken 'cl100k_base'.")
        return TiktokenWrapper("cl100k_base")

    print(f"Found tokenizer config at: {config_path}")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    tokenizer_type = config.get("tokenizer_type")
    if not tokenizer_type:
        raise ValueError("Tokenizer config is missing 'tokenizer_type' field.")

    tokenizer_class = TOKENIZER_REGISTRY.get(tokenizer_type)
    if tokenizer_class is None:
        raise ValueError(f"Unknown tokenizer type '{tokenizer_type}' in config.")

    print(f"Loading tokenizer of type '{tokenizer_type}'...")
    return tokenizer_class.load(tokenizer_dir)


def save_tokenizer(exp_path: Path, tokenizer: Tokenizer) -> None:
    """
    Save a tokenizer under the experiment directory.

    Only tokenizers implementing SerializableTokenizer can be persisted;
    any other tokenizer is skipped with a notice rather than failing the run.

    Args:
        exp_path: Root directory of the experiment.
        tokenizer: The tokenizer instance to persist.
    """
    # Guard clause: non-serializable tokenizers are a supported, no-op case.
    if not isinstance(tokenizer, SerializableTokenizer):
        print(f"Tokenizer of type '{type(tokenizer).__name__}' is not serializable and will not be saved.")
        return

    tokenizer_path = get_tokenizer_path(exp_path)
    tokenizer.save(tokenizer_path)
    print(f"Saved tokenizer to path: {tokenizer_path}")
15 changes: 13 additions & 2 deletions scratchgpt/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def __call__(
"""


class FilePreprocessor(Protocol):
    """
    Structural interface for preprocessors that operate on the file system
    (reading a source path and writing a destination path), as opposed to
    stream-based preprocessors.
    """

    def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1024 * 1024) -> None:
        """
        Preprocess `input_path` and write the result to `output_path`.

        Args:
            input_path: Source file (or directory, for implementations that
                accept one) to preprocess.
            output_path: Destination file for the processed output.
            chunk_size: Read granularity in bytes (default 10 MiB).
        """

class TokenizerPreprocessor(Preprocessor):
"""
Default pre-processor. Tokenizes a text stream and writes the output
Expand Down Expand Up @@ -69,7 +80,7 @@ def __call__(
pbar.update(len(chunk.encode("utf-8", errors="ignore")))


class File2FileTokenizerPreprocessor:
class File2FileTokenizerPreprocessor(FilePreprocessor):
"""
Orchestrates preprocessing for a single source file to a single destination file.
"""
Expand All @@ -95,7 +106,7 @@ def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1
print(f"Successfully preprocessed '{input_path}' to '{output_path}'")


class Folder2FileTokenizerPreprocessor:
class Folder2FileTokenizerPreprocessor(FilePreprocessor):
"""
Orchestrates preprocessing for a directory of source files to a single destination file.
"""
Expand Down
68 changes: 67 additions & 1 deletion scratchgpt/tokenizer/base_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from pathlib import Path
from typing import Self, TypeVar

T_SerializableTokenizer = TypeVar("T_SerializableTokenizer", bound="SerializableTokenizer")


TOKENIZER_REGISTRY: dict[str, type["SerializableTokenizer"]] = {}
"""
A simple registry to map tokenizer types to their classes.
This helps in dynamically loading the correct tokenizer.
"""

class Tokenizer(ABC):

class Tokenizer(ABC):
@abstractmethod
def encode(self, text: str) -> list[int]:
"""Convert a string into a sequence of token IDs."""
Expand All @@ -20,3 +31,58 @@ def vocab_size(self) -> int:
@abstractmethod
def vocabulary(self) -> list[str]:
"""Return the learned vocabulary"""


class SerializableTokenizer(Tokenizer):
    """
    An extension of the Tokenizer ABC that adds methods for saving and loading.

    Concrete subclasses should also be registered via `register_tokenizer`
    so they can be rediscovered from a saved `tokenizer_config.json`.
    """

    @abstractmethod
    def save(self, tokenizer_path: Path) -> None:
        """
        Save the tokenizer's state to a specified directory.

        Implementations should create a `tokenizer_config.json` with metadata
        and any other necessary data files (e.g., vocabulary). Subclasses may
        call `super().save(tokenizer_path)` to get the directory created.

        Args:
            tokenizer_path: The directory path to save the tokenizer to.
        """
        # Shared convenience for overriders: make sure the target directory exists.
        tokenizer_path.mkdir(exist_ok=True, parents=True)

    @classmethod
    @abstractmethod
    def load(cls, tokenizer_path: Path) -> Self:
        """
        Load a tokenizer from a specified directory.

        Args:
            tokenizer_path: The directory containing the tokenizer's state.

        Returns:
            An instance of the tokenizer.

        Raises:
            FileNotFoundError: If `tokenizer_config.json` is absent.
        """
        config_path = tokenizer_path / "tokenizer_config.json"
        if not config_path.is_file():
            raise FileNotFoundError(f"Tokenizer config not found at: {config_path}")

        # NOTE(review): calling super().load() validates the config file and then
        # always raises — subclasses must construct the instance themselves.
        raise NotImplementedError


def register_tokenizer(
    name: str,
) -> Callable[[type[T_SerializableTokenizer]], type[T_SerializableTokenizer]]:
    """
    Decorator factory that records a tokenizer class in TOKENIZER_REGISTRY
    under *name*, returning the class unchanged so its static type survives.
    """

    def _register(tokenizer_cls: type[T_SerializableTokenizer]) -> type[T_SerializableTokenizer]:
        # The TypeVar bound enforces this statically; keep the runtime guard too.
        if not issubclass(tokenizer_cls, SerializableTokenizer):
            raise TypeError("Registered tokenizer must be a subclass of SerializableTokenizer.")
        TOKENIZER_REGISTRY[name] = tokenizer_cls
        return tokenizer_cls

    return _register
Loading