Standalone nanotron config (#285)
What does this implement/fix? Explain your changes.
---------------------------------------------------
This PR moves the lighteval config into the lighteval codebase.
- Enforces `lighteval_config_path` as the only way to read the lighteval config. The lighteval section of the nanotron config is now ignored, so breaking changes to that config are less disruptive (see the launch sketch below).
- Some typing corrections
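
A minimal launch sketch of the new call path, mirroring the dispatch shown in `src/lighteval/__main__.py` below; the file paths are hypothetical placeholders, not part of this PR:

```python
# Sketch only: mirrors main_nanotron(args.checkpoint_config_path, args.lighteval_config_path, args.cache_dir).
# The paths below are hypothetical examples.
from lighteval.main_nanotron import main as main_nanotron

main_nanotron(
    "checkpoints/10/config.yaml",                                  # nanotron checkpoint config
    "examples/nanotron/lighteval_config_override_template.yaml",   # standalone lighteval config
    None,                                                          # cache_dir
)
```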

---------

Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
Co-authored-by: Nathan Habib <nathan.habib@huggingface.co>
Co-authored-by: Hynek Kydlicek <kydlicek.hynek@huggingface.co>
4 people authored Sep 4, 2024
1 parent 21934d5 commit aaa8bbf
Showing 11 changed files with 163 additions and 76 deletions.
18 changes: 9 additions & 9 deletions examples/nanotron/lighteval_config_override_template.yaml
@@ -1,15 +1,15 @@
batch_size: 16
checkpoints_path: null
# As of right now auto batch size doesn't work, so we use some default
batch_size: 8
generation: null
logging:
  hub_repo_details: null
  hub_repo_results: null
  hub_repo_tensorboard: null
  local_output_path: ./output_dir
  push_details_to_hub: false
  output_dir: "outputs"
  save_details: false
  push_results_to_hub: false
  push_results_to_tensorboard: true
  tensorboard_metric_prefix: e
  push_details_to_hub: false
  push_results_to_tensorboard: false
  public_run: false
  results_org: null
  tensorboard_metric_prefix: "eval"
parallelism:
  dp: 1
  pp: 1
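For readability, the override template as it stands after this change reads roughly as follows, reconstructed from the added keys above (the tail of the `parallelism` section is cut off in the diff, so only the visible keys are reproduced):

```yaml
batch_size: 16
checkpoints_path: null
generation: null
logging:
  output_dir: "outputs"
  save_details: false
  push_results_to_hub: false
  push_details_to_hub: false
  push_results_to_tensorboard: false
  public_run: false
  results_org: null
  tensorboard_metric_prefix: "eval"
parallelism:
  dp: 1
  pp: 1
```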
2 changes: 1 addition & 1 deletion src/lighteval/__main__.py
@@ -60,7 +60,7 @@ def cli_evaluate():
elif args.subcommand == "nanotron":
from lighteval.main_nanotron import main as main_nanotron

main_nanotron(args.checkpoint_config_path, args.lighteval_override, args.cache_dir)
main_nanotron(args.checkpoint_config_path, args.lighteval_config_path, args.cache_dir)

elif args.subcommand == "tasks":
if args.list:
100 changes: 100 additions & 0 deletions src/lighteval/config/lighteval_config.py
@@ -0,0 +1,100 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from dataclasses import dataclass
from typing import Dict, Optional, Union

from nanotron.config import Config
from nanotron.config.parallelism_config import ParallelismArgs
from nanotron.generation.sampler import SamplerType
from nanotron.logging import get_logger


logger = get_logger(__name__)

DEFAULT_GENERATION_SEED = 42


@dataclass
class GenerationArgs:
    sampler: Optional[Union[str, SamplerType]] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    n_samples: Optional[int] = None
    eos: Optional[str] = None
    seed: Optional[int] = None
    use_cache: Optional[bool] = False

    def __post_init__(self):
        if isinstance(self.sampler, str):
            self.sampler = SamplerType[self.sampler.upper()]
        if self.seed is None:
            self.seed = DEFAULT_GENERATION_SEED


@dataclass
class LightEvalLoggingArgs:
    """Arguments related to logging for LightEval"""

    output_dir: str
    save_details: bool = True
    push_results_to_hub: bool = False
    push_details_to_hub: bool = False
    push_results_to_tensorboard: bool = False
    public_run: bool = False
    results_org: str | None = None
    tensorboard_metric_prefix: str = "eval"


@dataclass
class LightEvalTasksArgs:
    """Arguments related to tasks for LightEval"""

    tasks: str
    custom_tasks: Optional[str] = None
    max_samples: Optional[int] = None
    num_fewshot_seeds: Optional[int] = None

    dataset_loading_processes: int = 8
    multichoice_continuations_start_space: Optional[bool] = None


@dataclass
class LightEvalConfig:
    """Arguments related to running LightEval on checkpoints.
    All is optional because you can also use this class to later supply arguments to override
    the saved config when running LightEval after training.
    """

    logging: LightEvalLoggingArgs
    tasks: LightEvalTasksArgs
    parallelism: ParallelismArgs
    batch_size: int = 0
    generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None


@dataclass
class FullNanotronConfig:
    lighteval_config: LightEvalConfig
    nanotron_config: Config
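
A rough usage sketch of these new dataclasses, following the loading pattern in `src/lighteval/main_nanotron.py` below; it assumes nanotron is installed, and the file paths are placeholders:

```python
# Sketch only: load the nanotron checkpoint config and the standalone lighteval
# config separately, then pair them in a FullNanotronConfig.
from nanotron.config import Config, get_config_from_file

from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig

model_config = get_config_from_file(
    "checkpoints/10/config.yaml",          # hypothetical nanotron checkpoint config
    config_class=Config,
    skip_unused_config_keys=True,
    skip_null_keys=True,
)
lighteval_config = get_config_from_file(
    "lighteval_config_override_template.yaml",
    config_class=LightEvalConfig,
)
nanotron_config = FullNanotronConfig(lighteval_config, model_config)
```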
11 changes: 5 additions & 6 deletions src/lighteval/logging/evaluation_tracker.py
@@ -94,8 +94,8 @@ class EvaluationTracker:

def __init__(
self,
output_dir: str = None,
hub_results_org: str = "",
output_dir: str,
hub_results_org: str | None = None,
push_results_to_hub: bool = False,
push_details_to_hub: bool = False,
push_results_to_tensorboard: bool = False,
@@ -133,14 +133,13 @@ def __init__(

self.output_dir = output_dir

self.hub_results_org = hub_results_org # will also contain tensorboard results
if hub_results_org in ["", None] and any(
[push_details_to_hub, push_results_to_hub, push_results_to_tensorboard]
):
if hub_results_org in [None] and any([push_details_to_hub, push_results_to_hub, push_results_to_tensorboard]):
raise Exception(
"You need to select which org to push to, using `--results_org`, if you want to save information to the hub."
)

self.hub_results_org = hub_results_org # will also contain tensorboard results

self.hub_results_repo = f"{hub_results_org}/results"
self.hub_private_results_repo = f"{hub_results_org}/private-results"
self.push_results_to_hub = push_results_to_hub
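Given the updated signature, a tracker could now be constructed roughly as below. This is a sketch with placeholder values: `output_dir` is required, and enabling any `push_*` option without an org now raises an exception.

```python
# Sketch of the updated contract; all values are placeholders.
import os

from lighteval.logging.evaluation_tracker import EvaluationTracker

evaluation_tracker = EvaluationTracker(
    output_dir="outputs",
    hub_results_org=None,              # must name an org/user if anything is pushed
    push_results_to_hub=False,
    push_details_to_hub=False,
    push_results_to_tensorboard=False,
    token=os.getenv("HF_TOKEN"),
)
```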
25 changes: 12 additions & 13 deletions src/lighteval/main_nanotron.py
@@ -24,6 +24,7 @@
import os
from typing import Optional

from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import htrack, htrack_block
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
@@ -34,7 +35,7 @@
if not is_nanotron_available():
raise ImportError(NO_NANOTRON_ERROR_MSG)

from nanotron.config import Config, LightEvalConfig, get_config_from_file
from nanotron.config import Config, get_config_from_file


SEED = 1234
@@ -60,28 +61,26 @@ def main(
skip_unused_config_keys=True,
skip_null_keys=True,
)
if lighteval_config_path:
lighteval_config = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)
model_config.lighteval = lighteval_config
else:
lighteval_config = model_config.lighteval

# We are getting an type error, because the get_config_from_file is not correctly typed,
lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig) # type: ignore
nanotron_config = FullNanotronConfig(lighteval_config, model_config)

evaluation_tracker = EvaluationTracker(
token=os.getenv("HF_TOKEN"),
output_dir=lighteval_config.logging.local_output_path,
hub_results_org=lighteval_config.logging.hub_repo_tensorboard,
output_dir=lighteval_config.logging.output_dir,
hub_results_org=lighteval_config.logging.results_org,
tensorboard_metric_prefix=lighteval_config.logging.tensorboard_metric_prefix,
nanotron_run_info=model_config.general,
nanotron_run_info=nanotron_config.nanotron_config.general,
)

pipeline_parameters = PipelineParameters(
launcher_type=ParallelismManager.NANOTRON,
env_config=env_config,
job_id=os.environ.get("SLURM_JOB_ID", None),
job_id=os.environ.get("SLURM_JOB_ID", 0),
nanotron_checkpoint_path=checkpoint_config_path,
dataset_loading_processes=lighteval_config.tasks.dataset_loading_processes,
custom_tasks_directory=lighteval_config.tasks.custom_tasks,
override_batch_size=None,
override_batch_size=lighteval_config.batch_size,
num_fewshot_seeds=1,
max_samples=lighteval_config.tasks.max_samples,
use_chat_template=False,
@@ -92,7 +91,7 @@
tasks=lighteval_config.tasks.tasks,
pipeline_parameters=pipeline_parameters,
evaluation_tracker=evaluation_tracker,
model_config=model_config,
model_config=nanotron_config,
)

pipeline.evaluate()
7 changes: 0 additions & 7 deletions src/lighteval/models/base_model.py
@@ -73,7 +73,6 @@ def __init__(
"""Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation."""
self._config = config.init_configs(env_config)
self.accelerator = config.accelerator
self._batch_size = config.batch_size
self._max_length = self._init_max_length(config.max_length)
self.use_chat_template = config.use_chat_template

@@ -285,12 +284,6 @@ def _init_max_length(self, max_length) -> int:
# or no max length config setting is found in the model or tokenizer.
return 2048

@property
def batch_size(self) -> int:
if self._batch_size >= 0:
self._batch_size = self._get_batch_size(max_input_length=self.max_length)
return self._batch_size # * gpus

@property
def device(self) -> Union[int, str, torch.device]:
return self._device
40 changes: 16 additions & 24 deletions src/lighteval/models/nanotron_model.py
@@ -34,6 +34,7 @@
from tqdm import tqdm
from transformers import AutoTokenizer, BatchEncoding

from lighteval.config.lighteval_config import FullNanotronConfig
from lighteval.data import (
GenDistributedSampler,
GenerativeTaskDatasetNanotron,
@@ -55,18 +56,16 @@
)
from lighteval.utils.imports import is_nanotron_available
from lighteval.utils.parallelism import find_executable_batch_size
from lighteval.utils.utils import EnvConfig, as_list, boolstring_to_bool
from lighteval.utils.utils import EnvConfig, as_list


os.environ["TOKENIZERS_PARALLELISM"] = "false"

TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]

if is_nanotron_available():
import nanotron
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import LightEvalConfig, ModelArgs, TokenizerArgs
from nanotron.generation.decode import decode_tokenized
from nanotron.logging import human_format, log_rank
from nanotron.models import build_model
@@ -90,7 +89,7 @@ class NanotronLightevalModel(LightevalModel):
def __init__(
self,
checkpoint_path: str,
nanotron_config: nanotron.config.Config,
nanotron_config: FullNanotronConfig,
parallel_context: ParallelContext,
max_gen_toks: Optional[int] = 256,
max_length: Optional[int] = None,
@@ -104,12 +103,11 @@ def __init__(
"""Initializes a nanotron model for evaluation.
Args:
"""
model_args: ModelArgs = nanotron_config.model
tokenizer: TokenizerArgs = nanotron_config.tokenizer
lighteval_config: LightEvalConfig = nanotron_config.lighteval
parallel_config: ParallelContext = nanotron_config.lighteval.parallelism
model_args = nanotron_config.nanotron_config.model
tokenizer = nanotron_config.nanotron_config.tokenizer
lighteval_config = nanotron_config.lighteval_config
parallel_config = nanotron_config.lighteval_config.parallelism

self._batch_size = lighteval_config.batch_size
self._max_gen_toks = max_gen_toks
self._max_length = max_length
self.parallel_config = parallel_config
@@ -120,9 +118,7 @@ def __init__(
raise ValueError("PP parallelism is not supported yet")

# multichoice_continuations_start_space can be True (forcing space), False (forcing no space) or None (no forcing)
multichoice_continuations_start_space = boolstring_to_bool(
lighteval_config.tasks.multichoice_continuations_start_space
)
multichoice_continuations_start_space = lighteval_config.tasks.multichoice_continuations_start_space

self.generation_config = lighteval_config.generation
if isinstance(self.generation_config, dict):
@@ -217,7 +213,9 @@ def __init__(

self.multichoice_continuations_start_space = multichoice_continuations_start_space

self.model_info = ModelInfo(model_name=f"{nanotron_config.general.run}/{nanotron_config.general.step}")
self.model_info = ModelInfo(
model_name=f"{nanotron_config.nanotron_config.general.run}/{nanotron_config.nanotron_config.general.step}"
)

@property
def tokenizer(self):
Expand Down Expand Up @@ -299,12 +297,6 @@ def max_length(self) -> int:
return self.tokenizer.model_max_length
return self._DEFAULT_MAX_LENGTH

@property
def batch_size(self) -> int:
if self._batch_size >= 0:
self._batch_size = self._get_batch_size(max_input_length=self.max_length)
return self._batch_size # * gpus

@property
def device(self) -> Union[int, str, torch.device]:
return "cuda"
@@ -415,7 +407,7 @@ def _check_continuations_start_space(self, continuation: str) -> str:
return continuation

def loglikelihood_single_token(
self, requests: List[Tuple[str, dict]], override_bs=None
self, requests: List[Tuple[str, dict]], override_bs=0
) -> List[LoglikelihoodSingleTokenResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
@@ -475,7 +467,7 @@ def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None)
)

def loglikelihood_rolling(
self, requests: List[LoglikelihoodRollingRequest], override_bs=None
self, requests: List[LoglikelihoodRollingRequest], override_bs: int = 0
) -> List[LoglikelihoodResponse]:
"""This function is used to compute the log likelihood of the context for perplexity metrics."""
for request in tqdm(
@@ -652,7 +644,7 @@ def _get_subsets(self, dataset, num_dataset_splits):

@torch.inference_mode()
def _loglikelihood_single_token(
self, requests, disable_tqdm: bool = False, override_bs: int = -1, num_dataset_splits: int = 1
self, requests, disable_tqdm: bool = False, override_bs: int = 0, num_dataset_splits: int = 1
) -> List[LoglikelihoodSingleTokenResponse]:
dataset = LoglikelihoodSingleTokenDataset(requests=requests)
res = []
@@ -1115,7 +1107,7 @@ def greedy_until(
self,
requests: List[GreedyUntilRequest],
disable_tqdm: bool = False,
override_bs=None,
override_bs: int = -1,
num_dataset_splits: int = 1,
) -> List[GenerativeResponse]:
"""Greedy generation until a stop token is generated."""
@@ -1155,7 +1147,7 @@ def greedy_until(
max_input_length = min(len(context_enc) + max_gen, self.max_length)

batch_size = self._get_batch_size(
override_bs=self._batch_size,
override_bs=override_bs,
max_input_length=max_input_length,
starting_batch_size=starting_batch_size,
)