fixing the code doc-strings
isaacmg committed Sep 30, 2024
1 parent 395b9e8 commit edbd8d1
Showing 5 changed files with 37 additions and 16 deletions.
27 changes: 19 additions & 8 deletions flood_forecast/multi_models/crossvivit.py
@@ -210,6 +210,7 @@ def __init__(
:param dim: The embedding dimension. The authors generally use a dimension of 384 for training the large models.
:type dim: int
"""
super().__init__()
self.image_size = image_size
@@ -244,17 +245,28 @@ def forward(
src_pos_emb: torch.Tensor,
tgt_pos_emb: torch.Tensor,
):
"""Performs the following computation in each layer:
"""
:param src: Source sequence of shape [B, N, D]. In the case of CrossVIVIT, src is the encoded video_ctx, where
B is the batch_size*forecast_history, N is the number of patches after random masking is applied, and D is the
dimension of the model. In other use cases this might differ.
:type src: torch.Tensor
:param tgt: Target sequence of shape [B, M, D]. In the case of CrossVIVIT, tgt is the encoded_timeseries, where
B is the batch_size*forecast_history, M is usually one, and D is the dimension of the model. In other use cases
this might differ.
:type tgt: torch.Tensor
:param src_pos_emb: Positional embedding of source sequence's tokens of shape [B, N, D]
:type src_pos_emb: torch.Tensor
:param tgt_pos_emb: Positional embedding of target sequence's tokens of shape [B, M, D]
:type tgt_pos_emb: torch.Tensor
:return: Tuple of (tgt, attention_scores)
:rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]]
Performs the following computation in each layer:
1. Self-Attention on the source sequence
2. FFN on the source sequence
3. Cross-Attention between target and source sequence
4. FFN on the target sequence
Args:
src: Source sequence of shape [B, N, D]
tgt: Target sequence of shape [B, M, D]
src_pos_emb: Positional embedding of source sequence's tokens of shape [B, N, D]
tgt_pos_emb: Positional embedding of target sequence's tokens of shape [B, M, D]
"""
attention_scores = {}
for i in range(len(self.cross_layers)):
@@ -263,7 +275,6 @@ def forward(
attention_scores["cross_attention"] = cattn_scores
tgt = out + tgt
tgt = cff(tgt) + tgt

return tgt, attention_scores
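
Below the hunk, a minimal sketch of the per-layer computation the docstring above describes (steps 1-4, each with a residual connection). It uses plain nn.MultiheadAttention as a stand-in for the repository's attention blocks and omits the positional embeddings (src_pos_emb / tgt_pos_emb) and the attention-score bookkeeping, so it is an assumption-laden illustration rather than the actual CrossVIVIT layer.

```python
import torch
import torch.nn as nn

dim, heads = 384, 6  # 384 matches the dim mentioned in the __init__ docstring above
self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
src_ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
tgt_ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

src = torch.randn(8, 196, dim)  # [B, N, D]: encoded video context patches
tgt = torch.randn(8, 1, dim)    # [B, M, D]: encoded time series token(s)

# 1. Self-attention on the source sequence
src = self_attn(src, src, src, need_weights=False)[0] + src
# 2. FFN on the source sequence
src = src_ff(src) + src
# 3. Cross-attention: the target queries attend to the source
tgt = cross_attn(tgt, src, src, need_weights=False)[0] + tgt
# 4. FFN on the target sequence
tgt = tgt_ff(tgt) + tgt
```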


@@ -581,6 +592,7 @@ def forward(

# Apply masking to video context if specified
# (Likely discussed in Section 3.2, subsection on regularization techniques)
# Prior to masking, embedded_video_context has shape [batch_size*forecast_history, num_patches, dim].
if self.ctx_masking_ratio > 0 and apply_masking:
mask_ratio = self.ctx_masking_ratio * torch.rand(1).item()
embedded_video_context, _, _, keep_indices = self.random_masking(
@@ -667,5 +679,4 @@ def forward(
quantile_mask = self.quantile_masker(
rearrange(transformed_timeseries.detach(), "b t c -> b c t")
)

return outputs, quantile_mask, self_attention_scores, cross_attention_scores
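
The masking hunks above draw a fresh mask ratio per forward pass and keep only a random subset of context patches. A minimal MAE-style sketch of what a random_masking helper like the one called above typically does follows; the return order (masked tokens, mask, restore indices, keep indices) is inferred from the call site, and the details may differ from the repository's implementation.

```python
import torch

def random_masking(x: torch.Tensor, mask_ratio: float):
    """Keep a random (1 - mask_ratio) fraction of tokens; x has shape [B, N, D]."""
    B, N, D = x.shape
    len_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N, device=x.device)       # one score per token
    ids_shuffle = torch.argsort(noise, dim=1)       # tokens with the lowest scores are kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    mask = torch.ones(B, N, device=x.device)        # 1 marks a removed token
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore, ids_keep
```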
4 changes: 3 additions & 1 deletion flood_forecast/preprocessing/pytorch_loaders.py
@@ -263,9 +263,11 @@ def __init__(
**kwargs
):
"""
A data loader for the test data and plotting code; it is a subclass of CSVDataLoader.
:param str df_path: The path to the CSV file you want to use (GCS compatible) or a Pandas DataFrame
A data loader for the test data.
:type df_path: str
:param int forecast_total: The total length of the forecast
:type forecast_total: int
"""
if "file_path" not in kwargs:
kwargs["file_path"] = df_path
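
For context on the documented parameters, a hedged usage sketch: df_path and forecast_total come from the docstring above, while the remaining keyword arguments are illustrative guesses at typical CSVDataLoader options and may not match the loader's exact signature.

```python
from flood_forecast.preprocessing.pytorch_loaders import CSVTestLoader

# df_path is forwarded as kwargs["file_path"] when file_path is not supplied explicitly.
test_loader = CSVTestLoader(
    df_path="gs://my-bucket/river_flow.csv",  # hypothetical GCS path
    forecast_total=336,                       # total length of the forecast
    forecast_history=30,                      # the kwargs below are illustrative only
    forecast_length=1,
    target_col=["cfs"],
    relevant_cols=["cfs", "precip", "temp"],
)
```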
7 changes: 6 additions & 1 deletion flood_forecast/time_model.py
@@ -282,7 +282,12 @@ def make_data_load(


def scaling_function(start_end_params: Dict, dataset_params: Dict) -> Dict:
""""""
"""
Function to scale the data based on the parameters in the dataset_params dict.
:param start_end_params: The start_end_params dictionary
:param dataset_params: The dataset_params dictionary
:return: The start_end_params dictionary
"""
if "scaler" in dataset_params:
in_dataset_params = "scaler"
elif "scaling" in dataset_params:
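
The new docstring notes that scaling_function reads its scaler configuration from dataset_params, which may spell the key as either "scaler" or "scaling". A small standalone sketch of that key resolution (the helper name and the example scaler strings are hypothetical, not part of the repository):

```python
from typing import Dict

def resolve_scaler_key(dataset_params: Dict) -> str:
    # Mirrors the branch shown in the hunk: prefer "scaler", fall back to "scaling".
    if "scaler" in dataset_params:
        return "scaler"
    elif "scaling" in dataset_params:
        return "scaling"
    raise KeyError("dataset_params must define either 'scaler' or 'scaling'")

print(resolve_scaler_key({"scaling": "StandardScaler"}))  # -> "scaling"
print(resolve_scaler_key({"scaler": "MinMaxScaler"}))     # -> "scaler"
```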
6 changes: 5 additions & 1 deletion flood_forecast/transformer_xl/attn.py
@@ -561,14 +561,18 @@ def __init__(

self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)

# Maps the input dimension to the inner dimension
self.to_q = nn.Linear(dim, inner_dim, bias=False)

self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)

self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

def forward(self, src: Float[torch.Tensor, ""], src_pos_emb, tgt, tgt_pos_emb):
"""
Performs the forward pass of the CrossAttention module.
"""
q = self.to_q(tgt)

qkv = (q, *self.to_kv(src).chunk(2, dim=-1))
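
The forward pass of this cross-attention module is truncated by the diff after the q/k/v projections. A minimal sketch of how such a block typically completes (target queries attend to source keys/values); the head count, the scaling, and the omission of the positional embeddings are assumptions rather than the module's actual code.

```python
import torch
from einops import rearrange

def cross_attention(to_q, to_kv, to_out, src, tgt, heads=8):
    """Hypothetical completion of the truncated forward; inner_dim must be divisible by heads."""
    q = to_q(tgt)                          # [B, M, inner_dim]
    k, v = to_kv(src).chunk(2, dim=-1)     # each [B, N, inner_dim]
    q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=heads) for t in (q, k, v))
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)  # [B, heads, M, N]
    out = rearrange(attn @ v, "b h n d -> b n (h d)")
    return to_out(out), attn
```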
9 changes: 4 additions & 5 deletions flood_forecast/transformer_xl/data_embedding.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple, Any

import torch
import torch.nn as nn
@@ -34,11 +34,10 @@ def __init__(self, dim: int, freq_type: str = "lucidrains", **kwargs: dict):
)
self.register_buffer("scales", scales)

def forward(self, coords: torch.Tensor):
def forward(self, coords: Float[torch.Tensor, "batch_size*time_series 2 1 1"]) -> Tuple[Any, Any]:
"""Assumes that coordinates do not change throughout the batches.
Args:
coords (torch.Tensor): Coordinates of shape [B, 2, H, W]
:param coords: The coordinates to embed. We assume these will be of shape batch_size*time_series. The last two dimensions are the x and y coordinates.
:type coords: torch.Tensor
"""
seq_x = coords[:, 0, 0, :]
seq_x = seq_x.unsqueeze(-1)
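
To make the new shape annotation concrete, a short sketch of building a coords tensor and of what the slicing in the hunk extracts; the coordinate values are illustrative, and the y slice is assumed to mirror the x slice shown in the diff.

```python
import torch

batch_times_series, H, W = 16, 1, 1
x_coord, y_coord = 0.25, 0.75  # hypothetical normalized longitude / latitude
coords = torch.stack(
    [torch.full((batch_times_series, H, W), x_coord),
     torch.full((batch_times_series, H, W), y_coord)],
    dim=1,                                   # -> [batch_size*time_series, 2, H, W]
)

seq_x = coords[:, 0, 0, :].unsqueeze(-1)     # x coordinates, shape [B, W, 1]
seq_y = coords[:, 1, 0, :].unsqueeze(-1)     # y coordinates (assumed), shape [B, W, 1]
```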
