Skip to content

Commit

Permalink
update tutorial with new features
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed May 29, 2023
1 parent 69c8902 commit 651ccdf
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 23 deletions.
2 changes: 0 additions & 2 deletions tutorial/1_quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def training_step(self, batch):
optimizer=reweight_optimizer,
train_data_loader=reweight_dataloader,
config=reweight_config,
device=device,
)

####################
Expand Down Expand Up @@ -138,7 +137,6 @@ def training_step(self, batch):
scheduler=classifier_scheduler,
train_data_loader=classifier_dataloader,
config=classifier_config,
device=device,
)


Expand Down
8 changes: 2 additions & 6 deletions tutorial/2_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def build_dataset(reweight_size=1000, imbalanced_factor=100):
valid_dataset = MNIST(root="./data", train=False, transform=transform)
valid_dataloader = DataLoader(valid_dataset, batch_size=100, pin_memory=True)

device = "cuda" if torch.cuda.is_available() else "cpu"


####################
##### Reweight #####
Expand Down Expand Up @@ -102,7 +100,6 @@ def training_step(self, batch):
optimizer=reweight_optimizer,
train_data_loader=reweight_dataloader,
config=reweight_config,
device=device,
)

####################
Expand Down Expand Up @@ -139,7 +136,6 @@ def training_step(self, batch):
scheduler=classifier_scheduler,
train_data_loader=classifier_dataloader,
config=classifier_config,
device=device,
)


Expand All @@ -151,15 +147,15 @@ def validation(self):
if not hasattr(self, "best_acc"):
self.best_acc = -1
for x, target in valid_dataloader:
x, target = x.to(device), target.to(device)
x, target = x.to(self.device), target.to(self.device)
out = self.classifier(x)
correct += (out.argmax(dim=1) == target).sum().item()
total += x.size(0)
acc = correct / total * 100
if self.best_acc < acc:
self.best_acc = acc

print("acc:", acc, "best_acc:", self.best_acc)
return {"acc": acc, "best_acc": self.best_acc}


engine_config = EngineConfig(train_iters=3000, valid_step=100)
Expand Down
6 changes: 1 addition & 5 deletions tutorial/3_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def build_dataset(reweight_size=1000, imbalanced_factor=100):
valid_dataset = MNIST(root="./data", train=False, transform=transform)
valid_dataloader = DataLoader(valid_dataset, batch_size=100, pin_memory=True)

device = "cuda" if torch.cuda.is_available() else "cpu"


####################
##### Reweight #####
Expand Down Expand Up @@ -103,7 +101,6 @@ def training_step(self, batch):
optimizer=reweight_optimizer,
train_data_loader=reweight_dataloader,
config=reweight_config,
device=device,
)

####################
Expand Down Expand Up @@ -140,7 +137,6 @@ def training_step(self, batch):
scheduler=classifier_scheduler,
train_data_loader=classifier_dataloader,
config=classifier_config,
device=device,
)


Expand All @@ -152,7 +148,7 @@ def validation(self):
if not hasattr(self, "best_acc"):
self.best_acc = -1
for x, target in valid_dataloader:
x, target = x.to(device), target.to(device)
x, target = x.to(self.device), target.to(self.device)
out = self.classifier(x)
correct += (out.argmax(dim=1) == target).sum().item()
total += x.size(0)
Expand Down
8 changes: 3 additions & 5 deletions tutorial/4_memory_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from betty.configs import Config, EngineConfig


fp16 = True
precision = "fp16"
distributed = True


Expand Down Expand Up @@ -112,14 +112,13 @@ def training_step(self, batch):
return {"loss": loss, "acc": acc}


reweight_config = Config(log_step=100, fp16=fp16)
reweight_config = Config(log_step=100, precision=precision)
reweight = Reweight(
name="reweight",
module=reweight_module,
optimizer=reweight_optimizer,
train_data_loader=reweight_dataloader,
config=reweight_config,
device=device,
)

####################
Expand Down Expand Up @@ -223,15 +222,14 @@ def training_step(self, batch):
return torch.mean(weight * loss_reshape)


classifier_config = Config(type="darts", unroll_steps=1, fp16=fp16)
classifier_config = Config(type="darts", unroll_steps=1, precision=precision)
classifier = Classifier(
name="classifier",
module=classifier_module,
optimizer=classifier_optimizer,
scheduler=classifier_scheduler,
train_data_loader=classifier_dataloader,
config=classifier_config,
device=device,
)


Expand Down
8 changes: 3 additions & 5 deletions tutorial/5_distributed_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ def build_dataset(reweight_size=1000, imbalanced_factor=100):
valid_dataset = CIFAR10(root="./data", train=False, transform=valid_transform)
valid_dataloader = DataLoader(valid_dataset, batch_size=50, pin_memory=True)

device = "cuda" if torch.cuda.is_available() else "cpu"


####################
##### Reweight #####
Expand All @@ -114,7 +112,7 @@ def training_step(self, batch):
return {"loss": loss, "acc": acc}


reweight_config = Config(log_step=100, fp16=True)
reweight_config = Config(log_step=100, precision="fp16")
reweight = Reweight(
name="reweight",
module=reweight_module,
Expand Down Expand Up @@ -224,7 +222,7 @@ def training_step(self, batch):
return torch.mean(weight * loss_reshape)


classifier_config = Config(type="darts", unroll_steps=1, fp16=True)
classifier_config = Config(type="darts", unroll_steps=1, precision="fp16")
classifier = Classifier(
name="classifier",
module=classifier_module,
Expand All @@ -243,7 +241,7 @@ def validation(self):
if not hasattr(self, "best_acc"):
self.best_acc = -1
for x, target in valid_dataloader:
x, target = x.to(device), target.to(device)
x, target = x.to(self.device), target.to(self.device)
out = self.classifier(x)
correct += (out.argmax(dim=1) == target).sum().item()
total += x.size(0)
Expand Down

0 comments on commit 651ccdf

Please sign in to comment.