
Commit db999ee

experimental code

1 parent 6fd2dab commit db999ee

File tree

1 file changed: +191 −0 lines changed

.other/try-accelerate.py

Lines changed: 191 additions & 0 deletions

@@ -0,0 +1,191 @@
import os
import os.path as op
import time

from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark
from accelerate import Accelerator

from local_dataset_utilities import (
    download_dataset,
    load_dataset_into_to_dataframe,
    partition_dataset,
)
from local_dataset_utilities import IMDBDataset


def tokenize_text(batch):
    # Uses the module-level `tokenizer`, which __main__ defines before calling .map()
    return tokenizer(batch["text"], truncation=True, padding=True)


def train(num_epochs, model, optimizer, train_loader, val_loader, device):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)

        for batch_idx, batch in enumerate(train_loader):
            model.train()
            # After accelerator.prepare(), batches already arrive on the right
            # device, so these explicit moves are redundant no-ops.
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)

            ### FORWARD AND BACK PROP
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            optimizer.zero_grad()
            outputs["loss"].backward()

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 300:
                print(
                    f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {outputs['loss']:.4f}"
                )

            # Track training accuracy on the logits from this batch's forward pass
            model.eval()
            with torch.no_grad():
                predicted_labels = torch.argmax(outputs["logits"], 1)
                train_acc.update(predicted_labels, batch["label"])

        ### MORE LOGGING
        with torch.no_grad():
            model.eval()
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
            for batch in val_loader:
                for s in ["input_ids", "attention_mask", "label"]:
                    batch[s] = batch[s].to(device)
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"],
                )
                predicted_labels = torch.argmax(outputs["logits"], 1)
                val_acc.update(predicted_labels, batch["label"])

        print(
            f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%"
        )
        train_acc.reset()
        val_acc.reset()
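

# NOTE (sketch, not part of the original commit): outputs["loss"].backward()
# above bypasses Accelerate's gradient handling, which matters once mixed
# precision or gradient accumulation is enabled. Accelerate's documented
# pattern is accelerator.backward(loss). If the Accelerator instance were
# passed into train() as an extra `accelerator` argument, the step would read:
#
#     optimizer.zero_grad()
#     accelerator.backward(outputs["loss"])  # replaces outputs["loss"].backward()
#     optimizer.step()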


if __name__ == "__main__":
    print(watermark(packages="torch,lightning,transformers", python=True))
    print("Torch CUDA available?", torch.cuda.is_available())

    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
    accelerator = Accelerator()
    device = accelerator.device

    torch.manual_seed(123)

    ##########################
    ### 1 Loading the Dataset
    ##########################
    download_dataset()
    df = load_dataset_into_to_dataframe()
    if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")):
        partition_dataset(df)

    imdb_dataset = load_dataset(
        "csv",
        data_files={
            "train": "train.csv",
            "validation": "val.csv",
            "test": "test.csv",
        },
    )

    #########################################
    ### 2 Tokenization and Numericalization
    #########################################

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    print("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
    print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)

    print("Tokenizing ...", flush=True)
    # batch_size=None maps over each split in one batch, so padding is uniform per split
    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    del imdb_dataset
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    # Avoid tokenizer-parallelism warnings once the DataLoader workers fork
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    #########################################
    ### 3 Set Up DataLoaders
    #########################################

    train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
    val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=12,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=12,
        num_workers=4,
        drop_last=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )
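
    # For reference: IMDBDataset lives in local_dataset_utilities, which is not
    # part of this diff. A plausible minimal sketch of such a wrapper, assuming
    # it just indexes one split of the tokenized DatasetDict (hypothetical
    # reconstruction, not the repo's actual helper):
    #
    #     class IMDBDataset(torch.utils.data.Dataset):
    #         def __init__(self, dataset_dict, partition_key="train"):
    #             self.partition = dataset_dict[partition_key]
    #
    #         def __getitem__(self, index):
    #             return self.partition[index]  # dict of input_ids, attention_mask, label
    #
    #         def __len__(self):
    #             return self.partition.num_rows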

    #########################################
    ### 4 Initializing the Model
    #########################################

    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )

    # accelerator.prepare() below also handles device placement; the explicit
    # move is redundant but harmless
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

    #########################################
    ### 5 Finetuning
    #########################################

    # Wrap model, optimizer, and dataloaders for the hardware selected via
    # `accelerate config` (single GPU, multi-GPU, TPU, ...)
    model, optimizer, train_loader, val_loader, test_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader, test_loader
    )

    start = time.time()
    train(
        num_epochs=3,
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
    )

    end = time.time()
    elapsed = end - start
    print(f"Time elapsed {elapsed/60:.2f} min")

    with torch.no_grad():
        model.eval()
        test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
        for batch in test_loader:
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            predicted_labels = torch.argmax(outputs["logits"], 1)
            test_acc.update(predicted_labels, batch["label"])

    print(f"Test accuracy {test_acc.compute()*100:.2f}%")
