Merge pull request #442 from datamol-io/fix-typing-py38
Fix an isinstance check that raised TypeError on Python < 3.10 and broaden the CI test matrix to cover Python 3.8–3.10.
cwognum committed Aug 22, 2023
2 parents 4adaaf7 + 7a6a722 commit f96c1b6
Showing 2 changed files with 4 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -16,7 +16,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.10"]
+        python-version: ["3.8", "3.9", "3.10"]
         pytorch-version: ["2.0"]

     runs-on: "ubuntu-latest"
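Why the broader matrix matters (a minimal, stdlib-only sketch, not part of the diff): isinstance() only accepts a typing.Union at runtime on Python 3.10+, so a matrix that tested only 3.10 could never surface the failure this PR fixes.

    import sys
    from typing import Union

    # isinstance() against a typing.Union works on Python 3.10+ only;
    # on 3.8 and 3.9 the same call raises TypeError at runtime.
    try:
        print(isinstance(1, Union[int, str]))  # True on 3.10+
    except TypeError as err:
        print(f"Python {sys.version_info[:2]}: {err}")  # hit on 3.8/3.9

Running this under each interpreter in the new matrix shows the split: 3.10 prints True, while 3.8 and 3.9 raise TypeError.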
18 changes: 3 additions & 15 deletions graphium/finetuning/finetuning_architecture.py
@@ -1,25 +1,13 @@
-from typing import Iterable, List, Dict, Tuple, Union, Callable, Any, Optional, Type
-
-from copy import deepcopy
-
-from loguru import logger
+from typing import Any, Dict, Optional, Union
 
 import torch
 import torch.nn as nn
 
-from torch import Tensor
 from torch_geometric.data import Batch
 
-from graphium.data.utils import get_keys
-from graphium.nn.base_graph_layer import BaseGraphStructure
-from graphium.nn.architectures.encoder_manager import EncoderManager
-from graphium.nn.architectures import FullGraphMultiTaskNetwork, FeedForwardNN, FeedForwardPyg, TaskHeads
-from graphium.nn.architectures.global_architectures import FeedForwardGraph
-from graphium.trainer.predictor_options import ModelOptions
 from graphium.nn.utils import MupMixin
-
 from graphium.trainer.predictor import PredictorModule
-from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT, FINETUNING_HEADS_DICT
+from graphium.utils.spaces import FINETUNING_HEADS_DICT, GRAPHIUM_PRETRAINED_MODELS_DICT
 
 
 class FullGraphFinetuningNetwork(nn.Module, MupMixin):
@@ -318,7 +306,7 @@ def __init__(self, finetuning_head_kwargs: Dict[str, Any]):
         self.net = net(**finetuning_head_kwargs)
 
     def forward(self, g: Union[Dict[str, Union[torch.Tensor, Batch]], torch.Tensor, Batch]):
-        if isinstance(g, Union[torch.Tensor, Batch]):
+        if isinstance(g, (torch.Tensor, Batch)):
             pass
         elif isinstance(g, Dict) and len(g) == 1:
             g = list(g.values())[0]
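For illustration, a minimal sketch of the patched input handling; unwrap is a hypothetical stand-in for FinetuningHead.forward, not an API from the repository. The tuple form of isinstance() is valid on every Python version in the new CI matrix, whereas the old Union form raised TypeError on 3.8 and 3.9.

    from typing import Dict, Union

    import torch
    from torch_geometric.data import Batch  # same dependency the patched module keeps

    def unwrap(g: Union[Dict[str, Union[torch.Tensor, Batch]], torch.Tensor, Batch]):
        # The fix: a tuple of types instead of typing.Union, which only
        # Python 3.10+ accepts as the second argument to isinstance().
        if isinstance(g, (torch.Tensor, Batch)):
            return g
        elif isinstance(g, dict) and len(g) == 1:
            return list(g.values())[0]
        raise TypeError("expected a Tensor/Batch or a single-entry dict")

    print(unwrap(torch.zeros(2)))            # tensor([0., 0.])
    print(unwrap({"graph": torch.ones(2)}))  # tensor([1., 1.])

Note that the annotation on g can keep using Union: annotations are not passed to isinstance(), so only the runtime check needed to change.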
