
Merge pull request #14 from stochasticai/marcos/training_improvements
feat: PyTorch-Lightning trainer improvements
MarcosRiveraMartinez authored Mar 21, 2023
2 parents 1c9cf85 + 220efa7 commit 0ebf897
Showing 8 changed files with 144 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/turing/engines/gpt2_engine.py
@@ -59,6 +59,10 @@ def validation_step(self, batch):
 
         return acc
 
+    def save(self, saving_path: Union[str, Path]):
+        self.model.save_pretrained(saving_path)
+        self.tokenizer.save_pretrained(saving_path)
+
 
 class GPT2LoraEngine(GPT2Engine):
     config_name: str = "gpt2_lora_engine"
4 changes: 4 additions & 0 deletions src/turing/engines/gptj_engine.py
@@ -58,6 +58,10 @@ def validation_step(self, batch):
 
         return acc
 
+    def save(self, saving_path: Union[str, Path]):
+        self.model.save_pretrained(saving_path)
+        self.tokenizer.save_pretrained(saving_path)
+
 
 class GPTJLoraEngine(GPTJEngine):
     config_name: str = "gptj_lora_engine"
4 changes: 4 additions & 0 deletions src/turing/engines/llama_engine.py
@@ -61,6 +61,10 @@ def validation_step(self, batch):
 
         return acc
 
+    def save(self, saving_path: Union[str, Path]):
+        self.model.save_pretrained(saving_path)
+        self.tokenizer.save_pretrained(saving_path)
+
 
 class LlamaLoraEngine(LLamaEngine):
     config_name: str = "llama_lora_engine"
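All three engines gain the same four-line save method. It assumes Union (from typing) and Path (from pathlib) are already imported in each engine module, and relies only on Hugging Face's save_pretrained, which both the model and tokenizer classes provide. A minimal sketch of the contract the engines now share; MinimalEngine is a hypothetical stand-in for GPT2Engine, GPTJEngine, and LLamaEngine:

from pathlib import Path
from typing import Union

from transformers import AutoModelForCausalLM, AutoTokenizer


class MinimalEngine:
    """Hypothetical stand-in for the engines patched above."""

    def __init__(self, model_name: str):
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def save(self, saving_path: Union[str, Path]):
        # Both calls write into the same directory: config + weights from
        # the model, vocab/merges/special-token files from the tokenizer.
        self.model.save_pretrained(saving_path)
        self.tokenizer.save_pretrained(saving_path)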
2 changes: 1 addition & 1 deletion src/turing/models/gpt2.py
@@ -96,4 +96,4 @@ def generate(
         return outputs
 
     def save(self, path: Union[str, Path]):
-        pass
+        self.engine.save(path)
2 changes: 1 addition & 1 deletion src/turing/models/gptj.py
@@ -98,4 +98,4 @@ def generate(
         return outputs
 
     def save(self, path: Union[str, Path]):
-        pass
+        self.engine.save(path)
2 changes: 1 addition & 1 deletion src/turing/models/llama.py
@@ -99,4 +99,4 @@ def generate(
         return outputs
 
     def save(self, path: Union[str, Path]):
-        pass
+        self.engine.save(path)
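With the engine-level save methods in place, these three one-line changes complete the call chain: model.save(path) now forwards to engine.save(path) instead of silently passing. A round-trip sketch (the checkpoint directory name is illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

from turing.models.gpt2 import GPT2

model = GPT2()
model.save("./gpt2_checkpoint")  # was a no-op before this change

# save_pretrained produces a standard Hugging Face directory, so the
# result can be reloaded with transformers alone, no turing required.
reloaded_model = AutoModelForCausalLM.from_pretrained("./gpt2_checkpoint")
reloaded_tokenizer = AutoTokenizer.from_pretrained("./gpt2_checkpoint")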
31 changes: 30 additions & 1 deletion src/turing/trainers/lightning_trainer.py
@@ -1,4 +1,8 @@
+import datetime
+import os
+import tempfile
+import uuid
 from pathlib import Path
 from typing import Optional
 
 import pytorch_lightning as pl
@@ -50,7 +54,10 @@ def train_dataloader(self):
         return self.train_dl
 
     def training_step(self, batch, batch_idx):
-        return self.model_engine.training_step(batch)
+        loss = self.model_engine.training_step(batch)
+        self.log("loss", loss)
+
+        return loss
 
     def validation_step(self, batch, batch_idx):
         return self.model_engine.validation_step(batch)
@@ -65,18 +72,38 @@ def __init__(
         train_dataset: BaseDataset,
         preprocessor: BasePreprocessor,
         max_epochs: int = 3,
+        max_training_time_in_secs: Optional[int] = None,
     ):
         self.lightning_model = TuringLightningModule(
             model_engine=model_engine,
             train_dataset=train_dataset,
             preprocessor=preprocessor,
         )
 
+        checkpoints_dir_path = Path(tempfile.gettempdir()) / str(uuid.uuid4())
+
+        if not checkpoints_dir_path.exists():
+            checkpoints_dir_path.mkdir(exist_ok=True, parents=True)
+
+        training_callbacks = [
+            callbacks.LearningRateFinder(),
+            callbacks.BatchSizeFinder(),
+            callbacks.ModelCheckpoint(
+                dirpath=str(checkpoints_dir_path),
+                save_top_k=3,
+                monitor="loss",
+                mode="min",  # Best model = min loss
+                every_n_train_steps=200,
+            ),
+        ]
+
+        if max_training_time_in_secs is not None:
+            training_callbacks.append(
+                callbacks.Timer(
+                    duration=datetime.timedelta(seconds=max_training_time_in_secs)
+                )
+            )
+
         self.trainer = Trainer(
             num_nodes=1,
             accelerator="gpu",
@@ -89,6 +116,8 @@ def __init__(
 
     def fit(self):
         self.trainer.fit(self.lightning_model)
+        if self.trainer.checkpoint_callback is not None:
+            self.best_model_path = self.trainer.checkpoint_callback.best_model_path
 
     def engine(self):
         return self.lightning_model.model_engine
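This checkpoint recipe works because training_step now logs the loss: ModelCheckpoint with monitor="loss" reads that logged metric every 200 training steps and keeps the three lowest-loss checkpoints, while the optional Timer callback stops training once the wall-clock budget elapses. After fit() returns, checkpoint_callback.best_model_path points at the best of the saved checkpoints, which is what the fit() change records. A standalone sketch of the same recipe against plain pytorch_lightning; MyLightningModule is a hypothetical module, everything else mirrors the diff:

import datetime
import tempfile
import uuid
from pathlib import Path

import pytorch_lightning as pl
from pytorch_lightning import callbacks

# Fresh temporary checkpoint directory per run, as in the diff above
ckpt_dir = Path(tempfile.gettempdir()) / str(uuid.uuid4())
ckpt_dir.mkdir(exist_ok=True, parents=True)

trainer = pl.Trainer(
    max_epochs=3,
    callbacks=[
        callbacks.ModelCheckpoint(
            dirpath=str(ckpt_dir),
            save_top_k=3,             # keep only the 3 best checkpoints
            monitor="loss",           # the metric logged in training_step
            mode="min",
            every_n_train_steps=200,  # evaluate the monitor every 200 steps
        ),
        callbacks.Timer(duration=datetime.timedelta(seconds=3600)),  # 1 h budget
    ],
)
# trainer.fit(MyLightningModule())  # hypothetical LightningModule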
99 changes: 99 additions & 0 deletions tests/tests.ipynb
@@ -0,0 +1,99 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"glue\", \"mrpc\")\n",
    "dataset['train'] = dataset['train'].remove_columns([\"sentence2\", \"label\", \"idx\"])\n",
    "dataset['train'] = dataset['train'].rename_column(\"sentence1\", \"text\")\n",
    "\n",
    "dataset.save_to_disk(\"./test_dataset\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from turing.datasets.text_dataset import TextDataset\n",
    "\n",
    "text_dataset = TextDataset(\"./test_dataset\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from turing.models.gpt2 import GPT2\n",
    "from turing.datasets.text_dataset import TextDataset\n",
    "\n",
    "text_dataset = TextDataset(\"./test_dataset\")\n",
    "model = GPT2()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.finetune(dataset=text_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"glue\", \"mrpc\")\n",
    "dataset['train'] = dataset['train'].remove_columns([\"label\", \"idx\"])\n",
    "dataset['train'] = dataset['train'].rename_column(\"sentence1\", \"text\")\n",
    "dataset['train'] = dataset['train'].rename_column(\"sentence2\", \"target\")\n",
    "dataset['train'] = dataset['train'].add_column(\"instruction\", dataset['train'][\"target\"])\n",
    "\n",
    "dataset.save_to_disk(\"./test_instruction_dataset\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from turing.models.gpt2 import GPT2\n",
    "from turing.datasets.instruction_dataset import InstructionDataset\n",
    "\n",
    "inst_dataset = InstructionDataset(\"./test_instruction_dataset\")\n",
    "model = GPT2()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.finetune(dataset=inst_dataset)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
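The notebook exercises finetuning on both the text and instruction datasets. A natural follow-up cell, hypothetical and not part of this commit, would exercise the new save path on the finetuned model (output directory name illustrative):

model.save("./gpt2_mrpc_finetuned")  # persists model + tokenizer via engine.save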
