feat: add README.md
pouya-parsa committed Jun 8, 2023
1 parent 22e1350 commit cb71a35
Showing 2 changed files with 50 additions and 7 deletions.
44 changes: 44 additions & 0 deletions README.md
@@ -0,0 +1,44 @@
# Transformer Encoder for Language Modeling
Transformers rose to prominence by addressing the limitations of earlier approaches
to language modeling, namely Word2Vec and RNNs. Word2Vec assigns a single fixed vector to
each word, ignoring its contextual dependencies. RNNs, on the other hand, are slow to train
because they process tokens sequentially, and they are unidirectional, conditioning only on the words
preceding a given word. In contrast, transformers are bidirectional and, despite the O(N^2)
complexity of self-attention, modern hardware allows for fast parallel computation. Crucially,
transformers vectorize words based on their surrounding context, so the same word can have
different representations in different sentences.
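This context sensitivity comes from self-attention. As a toy illustration (not this repository's code: it uses identity projections in place of learned Q/K/V weights and a made-up ten-word vocabulary), the same word's static embedding is identical across sentences, while its attention output differs with the surrounding words:

```python
import torch

def self_attention(x):
    # Toy single-head self-attention with identity Q/K/V projections:
    # each output vector is a context-weighted mixture of all input vectors.
    scores = x @ x.transpose(-2, -1) / x.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
embed = torch.nn.Embedding(10, 4)   # made-up 10-word vocabulary
sent_a = torch.tensor([[1, 3, 2]])  # word 3 in context A
sent_b = torch.tensor([[4, 3, 5]])  # word 3 in context B
out_a = self_attention(embed(sent_a))
out_b = self_attention(embed(sent_b))
# The static (Word2Vec-style) embedding of word 3 is identical in both
# sentences, but its attention output differs because the context differs.
same_static = torch.equal(embed(sent_a)[0, 1], embed(sent_b)[0, 1])
different_contextual = not torch.allclose(out_a[0, 1], out_b[0, 1])
```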

## Usage

First, initialize the training and validation dataloaders:
```python
dataloader_builder = DataloaderBulder()
vocab_size = dataloader_builder.vocab_size
train_dataloader, val_dataloader = dataloader_builder.get_loaders()
```
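`DataloaderBulder` is this repository's helper (see `dataset/build_dataloader.py`), so its internals are not shown here. As a rough stand-in, an equivalent pair of loaders could be assembled from torch primitives; the vocabulary size, sequence length, and split sizes below are illustrative assumptions:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in for DataloaderBulder: 64 fixed-length token
# sequences with one target each, split 48/16 into train and validation.
vocab_size, max_seq_length = 100, 16
tokens = torch.randint(0, vocab_size, (64, max_seq_length))
targets = torch.randint(0, vocab_size, (64,))
train_set, val_set = random_split(TensorDataset(tokens, targets), [48, 16])
train_dataloader = DataLoader(train_set, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=8)
```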

Then initialize the model:
```python
model = Predictor(max_seq_length, vocab_size, embed_dim, 6)
model.to(device)
```
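`Predictor` is defined in this repository's `predictor.py` and is not reproduced in the README; judging by the call signature, the final argument is presumably the number of encoder layers. A hypothetical model with the same interface might look like the sketch below, ending in log-probabilities so that it pairs with the `nn.NLLLoss` criterion used next:

```python
import torch
import torch.nn as nn

class EncoderLM(nn.Module):
    """Hypothetical stand-in for Predictor: token + position embeddings,
    a stack of transformer encoder layers, mean pooling, and a
    log-softmax head (log-probabilities, as nn.NLLLoss expects)."""
    def __init__(self, max_seq_length, vocab_size, embed_dim, num_layers):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, embed_dim)
        self.pos = nn.Embedding(max_seq_length, embed_dim)
        layer = nn.TransformerEncoderLayer(embed_dim, nhead=2, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):  # x: (batch, seq) token ids
        positions = torch.arange(x.size(1), device=x.device)
        h = self.encoder(self.tok(x) + self.pos(positions))
        return torch.log_softmax(self.head(h.mean(dim=1)), dim=-1)

model = EncoderLM(max_seq_length=16, vocab_size=50, embed_dim=8, num_layers=6)
out = model(torch.randint(0, 50, (4, 16)))  # (batch, vocab) log-probabilities
```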

Then initialize the criterion, optimizer, and scheduler:
```python
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
```
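A note on these choices: `nn.NLLLoss` expects log-probabilities, so the model's output layer should end in a log-softmax, and `StepLR` with a step size of 1 and `gamma=0.1` shrinks the learning rate tenfold each time `scheduler.step()` is called, typically once per epoch. A standalone illustration of the decay, with a made-up parameter and learning rate:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))  # made-up model parameter
optimizer = torch.optim.SGD([param], lr=5.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

lrs = []
for epoch in range(3):
    optimizer.step()   # (training for one epoch would happen here)
    scheduler.step()   # decay: lr -> lr * gamma after every epoch
    lrs.append(optimizer.param_groups[0]["lr"])
# lrs is now approximately [0.5, 0.05, 0.005]: a tenfold decay per epoch
```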

Train the model for one epoch:
```python
train_epoch_acc, train_epoch_loss = train(model, optimizer, criterion, train_dataloader)
```
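The `train` function lives in this repository's `train.py` and is not shown in the README. As an assumption-laden sketch, a function matching that call signature and the `(accuracy, loss)` return order might look like this, demonstrated on a made-up toy model and dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def train(model, optimizer, criterion, dataloader):
    # Hypothetical sketch of this repo's train(): one pass over the data,
    # returning (accuracy, mean loss) to match the README's unpacking order.
    model.train()
    correct, total, losses = 0, 0, []
    for tokens, targets in dataloader:
        optimizer.zero_grad()
        log_probs = model(tokens)  # (batch, vocab) log-probabilities
        loss = criterion(log_probs, targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        correct += (log_probs.argmax(dim=-1) == targets).sum().item()
        total += targets.size(0)
    return correct / total, sum(losses) / len(losses)

# Demo on a made-up toy model and dataset:
torch.manual_seed(0)
net = torch.nn.Sequential(
    torch.nn.Embedding(10, 4), torch.nn.Flatten(),
    torch.nn.Linear(12, 10), torch.nn.LogSoftmax(dim=-1),
)
ds = TensorDataset(torch.randint(0, 10, (16, 3)), torch.randint(0, 10, (16,)))
acc, mean_loss = train(net, torch.optim.SGD(net.parameters(), lr=0.1),
                       torch.nn.NLLLoss(), DataLoader(ds, batch_size=4))
```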

Evaluate the model on the validation set:
```python
accu_val, loss_val = evaluate(model, optimizer, criterion, val_dataloader)
```
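Like `train`, the `evaluate` function is defined elsewhere in the repository (`evaluate.py`). A hypothetical counterpart to the training sketch, mirroring the README's call signature and demonstrated on a made-up toy model and dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, optimizer, criterion, dataloader):
    # Hypothetical sketch of this repo's evaluate(): the optimizer argument
    # is unused but kept to mirror the call signature in the README.
    model.eval()
    correct, total, losses = 0, 0, []
    with torch.no_grad():  # no gradients needed for validation
        for tokens, targets in dataloader:
            log_probs = model(tokens)
            losses.append(criterion(log_probs, targets).item())
            correct += (log_probs.argmax(dim=-1) == targets).sum().item()
            total += targets.size(0)
    return correct / total, sum(losses) / len(losses)

# Demo on a made-up toy model and dataset:
torch.manual_seed(0)
net = torch.nn.Sequential(
    torch.nn.Embedding(10, 4), torch.nn.Flatten(),
    torch.nn.Linear(12, 10), torch.nn.LogSoftmax(dim=-1),
)
ds = TensorDataset(torch.randint(0, 10, (16, 3)), torch.randint(0, 10, (16,)))
accu_val, loss_val = evaluate(net, None, torch.nn.NLLLoss(),
                              DataLoader(ds, batch_size=4))
```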

13 changes: 6 additions & 7 deletions main.py
```diff
@@ -3,28 +3,27 @@
 import torch.nn as nn
 from torch.utils.data.dataset import random_split
 from torchtext.data.functional import to_map_style_dataset
-from predictor import Predictor
-from settings import max_seq_length, embed_dim, device, epochs, lr
-from predictor import Predictor
 from dataset.build_dataloader import DataloaderBulder
+from predictor import Predictor
 from train import train
 from evaluate import evaluate
+from settings import max_seq_length, embed_dim, device, epochs, lr
 
 dataloader_builder = DataloaderBulder()
 vocab_size = dataloader_builder.vocab_size
 train_dataloader, val_dataloader = dataloader_builder.get_loaders()
 
 model = Predictor(max_seq_length, vocab_size, embed_dim, 6)
 model.to(device)
-input_data = torch.randint(low=0, high=10, size=(8, 100))
-input_data = input_data.to(device)
-out_dist = model(input_data)
+# input_data = torch.randint(low=0, high=10, size=(8, 100))
+# input_data = input_data.to(device)
+# out_dist = model(input_data)
 
 criterion = nn.NLLLoss()
 optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
-total_accu = None
 
+total_accu = None
 train_loss_list = []
 train_acc_list = []
 val_loss_list = []
```
