
Commit

re-adding code to ff
isaacmg committed Aug 2, 2024
1 parent 5356a0e commit 0af9583
Showing 4 changed files with 39 additions and 12 deletions.
2 changes: 1 addition & 1 deletion flood_forecast/multi_models/crossvivit.py
@@ -120,7 +120,7 @@ def __init__(
:type dropout: float, optional
:param use_rotary: Whether to use rotary positional embeddings, defaults to True
:type use_rotary: bool, optional
:param use_glu: _description_, defaults to True
:param use_glu: Whether to use gated linear units, defaults to True
:type use_glu: bool, optional
"""

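Note: for readers unfamiliar with the option documented above, a gated linear unit (GLU) feed-forward projects the input to twice the hidden width, splits the result into a value half and a gate half, and multiplies them elementwise. The following is only a minimal sketch of that idea in PyTorch; the class name GLUFeedForward and its sizes are illustrative assumptions, not the layer defined in crossvivit.py.

import torch
from torch import nn

class GLUFeedForward(nn.Module):
    """Illustrative GLU feed-forward block (assumed structure, not the repo's exact layer)."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.proj_in = nn.Linear(dim, hidden_dim * 2)  # value half + gate half
        self.proj_out = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.proj_in(x).chunk(2, dim=-1)
        return self.proj_out(self.dropout(value * torch.sigmoid(gate)))

# Usage sketch: GLUFeedForward(128, 512)(torch.rand(5, 512, 128)) returns shape (5, 512, 128).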
2 changes: 1 addition & 1 deletion flood_forecast/pytorch_training.py
@@ -18,7 +18,7 @@


def multi_crit(crit_multi: List, output, labels, valid=None):
"""_summary_
"""Used for computing the loss when there are multiple criteria.
:param crit_multi: A list of loss criteria (e.g. PyTorch loss modules) to evaluate
:type crit_multi: List
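Note: the updated summary says multi_crit combines several loss criteria, but the body is cut off in this diff. The following is only a plausible sketch, assuming the individual losses are summed and that an optional valid mask restricts the computation; the helper name multi_crit_sketch is hypothetical.

from typing import List
import torch

def multi_crit_sketch(crit_multi: List, output: torch.Tensor, labels: torch.Tensor, valid=None) -> torch.Tensor:
    # Sum the loss from each criterion; if a boolean `valid` mask is given,
    # only the masked elements contribute (assumed semantics, not confirmed by the diff).
    total = 0.0
    for crit in crit_multi:
        if valid is not None:
            total = total + crit(output[valid], labels[valid])
        else:
            total = total + crit(output, labels)
    return total

# Usage sketch: multi_crit_sketch([torch.nn.MSELoss(), torch.nn.L1Loss()], torch.rand(4, 3), torch.rand(4, 3))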
8 changes: 4 additions & 4 deletions flood_forecast/transformer_xl/data_embedding.py
@@ -7,7 +7,7 @@


class AxialRotaryEmbedding(nn.Module):
def __init__(self, dim, freq_type="lucidrains", **kwargs):
def __init__(self, dim: int, freq_type="lucidrains", **kwargs):
super().__init__()
self.dim = dim
self.freq_type = freq_type
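Note: the __init__ above only stores dim and freq_type; the rotation itself, applied later to queries and keys, follows the usual rotary-embedding formula of rotating each (even, odd) feature pair by a position-dependent angle. A hedged sketch of that generic formula follows; it is not the exact AxialRotaryEmbedding.forward in this file, and apply_rotary_sketch is a hypothetical helper name.

import torch

def apply_rotary_sketch(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    # Generic rotary position embedding: rotate each (even, odd) feature pair
    # by the angle encoded in sin/cos, which are broadcast over the last dim // 2 features.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack(
        (x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1
    )
    return rotated.flatten(-2)  # interleave the rotated pairs back to the original width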
@@ -57,11 +57,11 @@ def forward(self, coords: torch.Tensor):

class PositionalEmbedding(nn.Module):
def __init__(self, d_model, max_len=5000):
"""[summary]
"""Create the positional embedding for use in the transformer and attention mechanisms.
:param d_model: [description]
:param d_model: The dimension of the positional embedding.
:type d_model: int
:param max_len: [description], defaults to 5000
:param max_len: The max length of the forecast_history, defaults to 5000
:type max_len: int, optional
"""
super(PositionalEmbedding, self).__init__()
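Note: the docstring above describes the standard sinusoidal positional embedding used in transformers. A minimal sketch of that construction is shown below; it assumes the textbook sine/cosine buffer layout and is not necessarily byte-for-byte the implementation in data_embedding.py.

import math
import torch
from torch import nn

class SinusoidalPositionalEmbedding(nn.Module):
    """Textbook sinusoidal positional embedding (assumed layout)."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # shape (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Return the encodings for the first x.size(1) positions of the sequence.
        return self.pe[:, : x.size(1)]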
39 changes: 33 additions & 6 deletions tests/mult_modal_tests/test_cross_vivit.py
@@ -2,25 +2,52 @@
import torch
from flood_forecast.multi_models.crossvivit import RoCrossViViT, VisionTransformer
from flood_forecast.transformer_xl.attn import SelfAttention
from flood_forecast.transformer_xl.data_embedding import CyclicalEmbedding, NeRF_embedding
from flood_forecast.transformer_xl.data_embedding import CyclicalEmbedding, NeRF_embedding, PositionalEncoding2D


class TestCrossVivVit(unittest.TestCase):
def setUp(self):
self.crossvivit = RoCrossViViT(image_size=(128, 128), patch_size=(8, 8), time_coords_encoder=NeRF_embedding(), **{"max_freq":12})
self.crossvivit = RoCrossViViT(image_size=(128, 128), patch_size=(8, 8), time_coords_encoder=CyclicalEmbedding(), **{"max_freq":12})

def test_positional_encoding_forward(self):
"""
Test the positional encoding forward pass.
"""
positional_encoding = PositionalEncoding2D(128)
coords = torch.rand(5, 2, 32, 32)
output = positional_encoding(coords)
self.assertEqual(output.shape, (5, 32, 32, 128))

def test_vivit_model(self):
self.vivit_model = VisionTransformer(128, 5, 8, 128, 128, [512, 512, 512], dropout=0.1)
self.vivit_model(torch.rand(5, 512, 128), torch.rand(5, 512, 128))
pass
out = self.vivit_model(torch.rand(5, 512, 128), (torch.rand(5, 512, 64), torch.rand(5, 512, 64)))
assert out[0].shape == (5, 512, 128)

def test_forward(self):
x = self.crossvivit(torch.randn(1, 3, 128, 128), torch.randn(1, 3, 128, 128), torch.randn(1, 3, 128, 128), )
"""
ctx (torch.Tensor): Context frames of shape [B, T, C, H, W]
ctx_coords (torch.Tensor): Coordinates of context frames of shape [B, 2, H, W]
ts (torch.Tensor): Station timeseries of shape [B, T, C]
ts_coords (torch.Tensor): Station coordinates of shape [B, 2, 1, 1]
time_coords (torch.Tensor): Time coordinates of shape [B, T, C, H, W]
mask (bool): Whether to mask or not. Useful for inference.
"""
# The context tensor
ctx_tensor = torch.rand(5, 10, 12, 120, 120)
ctx_coords = torch.rand(5, 2, 120, 120)
ts = torch.rand(5, 10, 12)
time_coords = torch.rand(5, 10, 12, 120, 120)
ts_coords = torch.rand(5, 2, 1, 1)
mask = True
x = self.crossvivit(ctx_tensor, ctx_coords, ts, ts_coords, time_coords=time_coords, mask=True)
self.assertEqual(x.shape, (1, 1000))

def test_self_attention_dims(self):
"""
Test the self attention layer with the correct dimensions.
"""
self.self_attention = SelfAttention(dim=128, use_rotary=True)
self.self_attention(torch.rand(5, 512, 128), torch.rand(5,512, 128))
self.self_attention(torch.rand(5, 512, 128), (torch.rand(5, 512, 64), torch.rand(5, 512, 64)))


if __name__ == '__main__':
