Skip to content

Commit ba16733

Browse files
committed
Fix mypy issues
1 parent c218531 commit ba16733

File tree

7 files changed

+25
-16
lines changed

7 files changed

+25
-16
lines changed

code2seq/data/path_context.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import List, Iterable, Tuple, Optional
2+
from typing import Iterable, Tuple, Optional, Sequence
33

44
import torch
55

@@ -14,12 +14,12 @@ class Path:
1414
@dataclass
1515
class LabeledPathContext:
1616
label: torch.Tensor # [max label parts]
17-
path_contexts: List[Path]
17+
path_contexts: Sequence[Path]
1818

1919

2020
class BatchedLabeledPathContext:
21-
def __init__(self, samples: List[Optional[LabeledPathContext]]):
22-
samples = [s for s in samples if s is not None]
21+
def __init__(self, all_samples: Sequence[Optional[LabeledPathContext]]):
22+
samples = [s for s in all_samples if s is not None]
2323

2424
# [max label parts; batch size]
2525
self.labels = torch.cat([s.label.unsqueeze(1) for s in samples], dim=1)
@@ -59,12 +59,13 @@ class TypedPath(Path):
5959

6060
@dataclass
6161
class LabeledTypedPathContext(LabeledPathContext):
62-
path_contexts: List[TypedPath]
62+
path_contexts: Sequence[TypedPath]
6363

6464

6565
class BatchedLabeledTypedPathContext(BatchedLabeledPathContext):
66-
def __init__(self, samples: List[Optional[LabeledTypedPathContext]]):
67-
super().__init__(samples)
66+
def __init__(self, all_samples: Sequence[Optional[LabeledTypedPathContext]]):
67+
super().__init__(all_samples)
68+
samples = [s for s in all_samples if s is not None]
6869
# [max type parts; n contexts]
6970
self.from_type = torch.cat([path.from_type.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
7071
# [max type parts; n contexts]

code2seq/data/path_context_data_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def collate_wrapper(batch: List[Optional[LabeledPathContext]]) -> BatchedLabeled
5353
return BatchedLabeledPathContext(batch)
5454

5555
def _create_dataset(self, holdout_file: str, random_context: bool) -> PathContextDataset:
56+
if self._vocabulary is None:
57+
raise RuntimeError(f"Setup vocabulary before creating data loaders")
5658
return PathContextDataset(holdout_file, self._config, self._vocabulary, random_context)
5759

5860
def _shared_dataloader(self, holdout: str) -> DataLoader:

code2seq/data/typed_path_context_data_module.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@ def __init__(self, data_dir: str, config: DictConfig):
1717
super().__init__(data_dir, config)
1818

1919
@staticmethod
20-
def collate_wrapper(batch: List[Optional[LabeledTypedPathContext]]) -> BatchedLabeledTypedPathContext:
20+
def collate_wrapper( # type: ignore[override]
21+
batch: List[Optional[LabeledTypedPathContext]],
22+
) -> BatchedLabeledTypedPathContext:
2123
return BatchedLabeledTypedPathContext(batch)
2224

2325
def _create_dataset(self, holdout_file: str, random_context: bool) -> TypedPathContextDataset:
26+
if self._vocabulary is None:
27+
raise RuntimeError(f"Setup vocabulary before creating data loaders")
2428
return TypedPathContextDataset(holdout_file, self._config, self._vocabulary, random_context)
2529

2630
def setup(self, stage: Optional[str] = None):

code2seq/data/typed_path_context_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
class TypedPathContextDataset(PathContextDataset):
1111
def __init__(self, data_file: str, config: DictConfig, vocabulary: TypedVocabulary, random_context: bool):
1212
super().__init__(data_file, config, vocabulary, random_context)
13-
self._vocab = vocabulary
13+
self._vocab: TypedVocabulary = vocabulary
1414

1515
def _get_path(self, raw_path: List[str]) -> TypedPath:
1616
return TypedPath(

code2seq/model/code2class.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from commode_utils.modules import Classifier
55
from omegaconf import DictConfig
66
from pytorch_lightning import LightningModule
7+
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
78
from torch.optim import Optimizer
89
from torch.optim.lr_scheduler import _LRScheduler
910
from torchmetrics import Metric, Accuracy, MetricCollection
@@ -77,18 +78,19 @@ def test_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict:
7778

7879
# ========== ON EPOCH END ==========
7980

80-
def _shared_epoch_end(self, outputs: List[Dict], step: str):
81+
def _shared_epoch_end(self, outputs: EPOCH_OUTPUT, step: str):
82+
assert isinstance(outputs, dict)
8183
with torch.no_grad():
8284
mean_loss = torch.stack([out[f"{step}/loss"] for out in outputs]).mean()
8385
accuracy = self.__metrics[f"{step}_acc"].compute()
8486
log = {f"{step}/loss": mean_loss, f"{step}/accuracy": accuracy}
8587
self.log_dict(log, on_step=False, on_epoch=True)
8688

87-
def training_epoch_end(self, outputs: List[Dict]):
89+
def training_epoch_end(self, outputs: EPOCH_OUTPUT):
8890
self._shared_epoch_end(outputs, "train")
8991

90-
def validation_epoch_end(self, outputs: List[Dict]):
92+
def validation_epoch_end(self, outputs: EPOCH_OUTPUT):
9193
self._shared_epoch_end(outputs, "val")
9294

93-
def test_epoch_end(self, outputs: List[Dict]):
95+
def test_epoch_end(self, outputs: EPOCH_OUTPUT):
9496
self._shared_epoch_end(outputs, "test")

code2seq/model/modules/typed_path_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def _calculate_concat_size(embedding_size: int, rnn_size: int, num_directions: i
2727
def _type_embedding(self, types: torch.Tensor) -> torch.Tensor:
2828
return self.type_embedding(types).sum(0)
2929

30-
def forward(
30+
def forward( # type: ignore
3131
self,
3232
from_type: torch.Tensor,
3333
from_token: torch.Tensor,

code2seq/model/typed_code2seq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(
1818
teacher_forcing: float = 0.0,
1919
):
2020
super().__init__(model_config, optimizer_config, vocabulary, teacher_forcing)
21-
self._vocabulary = vocabulary
21+
self._vocabulary: TypedVocabulary = vocabulary
2222

2323
def _get_encoder(self, config: DictConfig) -> PathEncoder:
2424
return TypedPathEncoder(
@@ -46,7 +46,7 @@ def forward( # type: ignore
4646
output_logits = self._decoder(encoded_paths, contexts_per_label, output_length, target_sequence)
4747
return output_logits
4848

49-
def logits_from_batch(
49+
def logits_from_batch( # type: ignore[override]
5050
self, batch: BatchedLabeledTypedPathContext, target_sequence: Optional[torch.Tensor]
5151
) -> torch.Tensor:
5252
return self(

0 commit comments

Comments
 (0)