Help for adversarial learning with pytorch lighting #5795
Unanswered
yang9112
asked this question in
code help: CV
Replies: 3 comments 3 replies
-
@yang9112 The GAN example might help you: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/generative_adversarial_net.py |
Beta Was this translation helpful? Give feedback.
2 replies
-
Hi, You should use
|
Beta Was this translation helpful? Give feedback.
1 reply
-
I write a solution, maybe works from typing import Any, Dict, Iterator, List, Optional, cast
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor
from torch.optim.optimizer import Optimizer
def _get_params(optimizer: Optimizer) -> Iterator[Tensor]:
for param_group in cast(List[Dict[Any, Any]], optimizer.param_groups):
for param in param_group["params"]:
if not isinstance(param, Tensor):
raise TypeError(f"expected Tensor, but got: {type(param)}")
yield param
def _get_loss(step_output: STEP_OUTPUT) -> Optional[Tensor]:
if step_output is None:
return None
if isinstance(step_output, Tensor):
return step_output
return step_output.get("loss")
class FGM(torch.nn.Module):
def __init__(self, model, emb_name="word_embeddings"):
super().__init__()
self.model = model
self.backup = {}
self.emb_name = emb_name
def forward(self ,x):
return self.model(x)
def attack(self, epsilon=1.):
# emb_name这个参数要换成你模型中embedding的参数名
for name, param in self.model.named_parameters():
if param.requires_grad and self.emb_name in name:
self.backup[name] = param.data.clone()
norm = torch.norm(param.grad)
if norm != 0 and not torch.isnan(norm):
r_at = epsilon * param.grad / norm
param.data.add_(r_at)
def restore(self):
# emb_name这个参数要换成你模型中embedding的参数名
for name, param in self.model.named_parameters():
if param.requires_grad and self.emb_name in name:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
class FGMCallBack(Callback):
_batch: Any
_batch_idx: int
def __init__(self, emb_name="word_embeddings") -> None:
super().__init__()
self.emb_name = emb_name
def on_train_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Any,
batch_idx: int,
unused: int = 0,
) -> None:
self._batch = batch
self._batch_idx = batch_idx
@torch.no_grad()
def on_before_optimizer_step(
self,
trainer: Trainer,
pl_module: LightningModule,
optimizer: Optimizer,
opt_idx: int,
) -> None:
# org_weights = self._first_step(optimizer)
with torch.enable_grad():
step_output = pl_module.training_step(self._batch, self._batch_idx)
loss = _get_loss(step_output)
if loss is not None:
trainer.accelerator.backward(
loss, optimizer=optimizer, optimizer_idx=opt_idx
)
pl_module.model.attack()
adv_loss = _get_loss(pl_module.training_step(self._batch, self._batch_idx))
trainer.accelerator.backward(
adv_loss, optimizer=optimizer, optimizer_idx=opt_idx
)
pl_module.model.restore()
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
pl_module.model = FGM(model=pl_module.model)
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
pl_module.model = pl_module.model.model
if __name__ == "__main__":
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
class RandomDataset(Dataset):
def __init__(self, size, num_samples):
self.data = torch.randn(num_samples, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
class Net(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.word_embeddings = torch.nn.Linear(32, 2)
self.mlp = torch.nn.Linear(2,2)
def forward(self, x):
return self.mlp(self.word_embeddings(x))
class BoringModel(LightningModule):
def __init__(self):
super().__init__()
self.model = Net()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
return self(batch).mean()
def configure_optimizers(self):
return torch.optim.SGD(self.model.parameters(), lr=0.1)
model = BoringModel()
# trainer = Trainer(max_epochs=3, callbacks=[FGMCallBack()])
trainer = Trainer(max_epochs=3, callbacks=[])
trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 64), batch_size=2)) |
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
-
Help for adversarial learning with pytorch lighting
What is your question?
Code
the old method for adversarial learning is like this:
but I don't know how to deal it by pytorch lighting
What have you tried?
It seems to work, but I‘m not sure if it is correct or there is a better way to achieve it.
What's your environment?
Beta Was this translation helpful? Give feedback.
All reactions