import os
import os.path as op
import time

from datasets import load_dataset
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark

from local_dataset_utilities import (
    download_dataset,
    load_dataset_into_to_dataframe,
    partition_dataset,
)
from local_dataset_utilities import IMDBDataset


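# tokenize_text is applied via datasets.map() below; it relies on the
# module-level `tokenizer` created in the __main__ block, so the script
# is meant to be run as a standalone program.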
def tokenize_text(batch):
    return tokenizer(batch["text"], truncation=True, padding=True)


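# LightningModule wrapper around a Hugging Face sequence-classification model:
# it logs loss and accuracy for the train/validation/test loops and sets up
# the optimizer.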
class LightningModel(L.LightningModule):
    def __init__(self, model, learning_rate=5e-5):
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model

        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)

    def forward(self, input_ids, attention_mask, labels):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )
        self.log("train_loss", outputs["loss"])
        with torch.no_grad():
            logits = outputs["logits"]
            predicted_labels = torch.argmax(logits, 1)
            self.train_acc(predicted_labels, batch["label"])
            self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return outputs["loss"]  # this is passed to the optimizer for training

    def validation_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )
        self.log("val_loss", outputs["loss"], prog_bar=True)

        logits = outputs["logits"]
        predicted_labels = torch.argmax(logits, 1)
        self.val_acc(predicted_labels, batch["label"])
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )

        logits = outputs["logits"]
        predicted_labels = torch.argmax(logits, 1)
        self.test_acc(predicted_labels, batch["label"])
        self.log("accuracy", self.test_acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.trainer.model.parameters(), lr=self.learning_rate
        )
        return optimizer

if __name__ == "__main__":
    print(watermark(packages="torch,lightning,transformers", python=True), flush=True)
    print("Torch CUDA available?", torch.cuda.is_available(), flush=True)

    torch.manual_seed(123)

    ##########################
    ### 1 Loading the Dataset
    ##########################
    download_dataset()
    df = load_dataset_into_to_dataframe()
    if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")):
        partition_dataset(df)

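    # load_dataset("csv", ...) returns a Hugging Face DatasetDict with
    # "train", "validation", and "test" splits built from the CSV files above.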
    imdb_dataset = load_dataset(
        "csv",
        data_files={
            "train": "train.csv",
            "validation": "val.csv",
            "test": "test.csv",
        },
    )

    #########################################
    ### 2 Tokenization and Numericalization
    #########################################

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    print("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
    print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)

    print("Tokenizing ...", flush=True)
    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    del imdb_dataset
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    #########################################
    ### 3 Set Up DataLoaders
    #########################################

    train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
    val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=12,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=12,
        num_workers=1,
        drop_last=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=12,
        num_workers=1,
        drop_last=True,
    )

    #########################################
    ### 4 Initializing the Model
    #########################################

    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )

    #########################################
    ### 5 Finetuning
    #########################################

    lightning_model = LightningModel(model)

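    # Keep only the best checkpoint as ranked by validation accuracy; it is
    # restored later via ckpt_path="best" when trainer.test() is called.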
    callbacks = [
        ModelCheckpoint(save_top_k=1, mode="max", monitor="val_acc")  # save top 1 model
    ]
    logger = CSVLogger(save_dir="logs/", name="my-model")

    trainer = L.Trainer(
        max_epochs=3,
        callbacks=callbacks,
        accelerator="gpu",
        devices=1,
        precision="16",  # <-- NEW: train in 16-bit mixed precision
        logger=logger,
        log_every_n_steps=10,
        deterministic=True,
    )

    start = time.time()
    trainer.fit(
        model=lightning_model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )

    end = time.time()
    elapsed = end - start
    print(f"Time elapsed {elapsed/60:.2f} min")

    test_acc = trainer.test(lightning_model, dataloaders=test_loader, ckpt_path="best")
    print(test_acc)

    with open(op.join(trainer.logger.log_dir, "outputs.txt"), "w") as f:
        f.write(f"Time elapsed {elapsed/60:.2f} min\n")
        f.write(f"Test acc: {test_acc}")