import os
import os.path as op
import time

from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark
from accelerate import Accelerator

from local_dataset_utilities import (
    download_dataset,
    load_dataset_into_to_dataframe,
    partition_dataset,
)
from local_dataset_utilities import IMDBDataset

def tokenize_text(batch):
    # Truncate to the model's maximum input length and pad to the longest
    # sequence in the batch; `tokenizer` is created in the __main__ block below.
    return tokenizer(batch["text"], truncation=True, padding=True)

def train(num_epochs, model, optimizer, train_loader, val_loader, device):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)

        for batch_idx, batch in enumerate(train_loader):
            model.train()
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)

            ### FORWARD AND BACK PROP
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            optimizer.zero_grad()
            outputs["loss"].backward()

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 300:
                print(
                    f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {outputs['loss']:.4f}"
                )

            # Track training accuracy from the logits of the forward pass above,
            # without building an additional computation graph.
            model.eval()
            with torch.no_grad():
                predicted_labels = torch.argmax(outputs["logits"], 1)
                train_acc.update(predicted_labels, batch["label"])

        ### MORE LOGGING
        # Evaluate on the validation set at the end of each epoch.
        with torch.no_grad():
            model.eval()
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
            for batch in val_loader:
                for s in ["input_ids", "attention_mask", "label"]:
                    batch[s] = batch[s].to(device)
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"],
                )
                predicted_labels = torch.argmax(outputs["logits"], 1)
                val_acc.update(predicted_labels, batch["label"])

            print(
                f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%"
            )
        train_acc.reset()
        val_acc.reset()

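# Note on distributed runs: train() computes metrics per process. If this script is
# launched with more than one process (for example via `accelerate launch` on several
# GPUs), each process only sees its own shard of the data, so globally correct metrics
# would require gathering predictions and labels first (e.g. with
# accelerator.gather_for_metrics(...)) before calling .update(). On a single device
# the code below works as-is.
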
if __name__ == "__main__":
    print(watermark(packages="torch,lightning,transformers", python=True))
    print("Torch CUDA available?", torch.cuda.is_available())
    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
    accelerator = Accelerator()
    device = accelerator.device

    torch.manual_seed(123)

    ##########################
    ### 1 Loading the Dataset
    ##########################
    download_dataset()
    df = load_dataset_into_to_dataframe()
    if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")):
        partition_dataset(df)

    imdb_dataset = load_dataset(
        "csv",
        data_files={
            "train": "train.csv",
            "validation": "val.csv",
            "test": "test.csv",
        },
    )

    #########################################
    ### 2 Tokenization and Numericalization
    #########################################

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    print("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
    print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)

    print("Tokenizing ...", flush=True)
    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    del imdb_dataset
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    # Avoid the tokenizers fork/parallelism warning once DataLoader workers are spawned.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    #########################################
    ### 3 Set Up DataLoaders
    #########################################

    train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
    val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=12,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=12,
        num_workers=4,
        drop_last=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )

    #########################################
    ### 4 Initializing the Model
    #########################################

    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

    #########################################
    ### 5 Finetuning
    #########################################

    model, optimizer, train_loader, val_loader, test_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader, test_loader
    )
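    # accelerator.prepare() moves the model to the active device and wraps the
    # optimizer and dataloaders for the current (possibly distributed) setup, so the
    # explicit .to(device) calls above and inside train() are redundant on a single
    # GPU, though harmless. With mixed precision or multi-GPU runs, replacing
    # loss.backward() with accelerator.backward(loss) in the train loop is preferred.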

    start = time.time()
    train(
        num_epochs=3,
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
    )

    end = time.time()
    elapsed = end - start
    print(f"Time elapsed {elapsed/60:.2f} min")

    with torch.no_grad():
        model.eval()
        test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
        for batch in test_loader:
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            predicted_labels = torch.argmax(outputs["logits"], 1)
            test_acc.update(predicted_labels, batch["label"])

    print(f"Test accuracy {test_acc.compute()*100:.2f}%")