Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Make IceVision requirements fully optional #742

Merged
merged 1 commit into from
Sep 7, 2021
Rate limit · GitHub

Access has been restricted

You have triggered a rate limit.

Please wait a few minutes before you try again;
in some cases this may take up to an hour.

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@
)
from flash.core.data.process import Deserializer, Postprocess, Preprocess
from flash.core.data.properties import ProcessState
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires_extras
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires

if _AUDIO_AVAILABLE:
import librosa
@@ -155,7 +155,7 @@ def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:


class SpeechRecognitionPreprocess(Preprocess):
@requires_extras("audio")
@requires("audio")
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
@@ -197,7 +197,7 @@ class SpeechRecognitionBackboneState(ProcessState):


class SpeechRecognitionPostprocess(Postprocess):
@requires_extras("audio")
@requires("audio")
def __init__(self):
super().__init__()

6 changes: 3 additions & 3 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
from torch import nn

from flash.core.data.data_source import DefaultDataKeys
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires_extras
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires

if _ICEVISION_AVAILABLE:
from icevision.core import tasks
@@ -206,15 +206,15 @@ def forward(self, x):
return from_icevision_record(record)


@requires_extras("image")
@requires(["image", "icevision"])
def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
"""The default transforms from IceVision."""
return {
"pre_tensor_transform": IceVisionTransformAdapter([*A.resize_and_pad(image_size), A.Normalize()]),
}


@requires_extras("image")
@requires(["image", "icevision"])
def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
"""The default augmentations from IceVision."""
return {
12 changes: 6 additions & 6 deletions flash/core/model.py
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.serve import Composition
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import requires_extras
from flash.core.utilities.imports import requires


class ModuleWrapperBase:
@@ -258,11 +258,11 @@ class CheckDependenciesMeta(ABCMeta):
def __new__(mcs, *args, **kwargs):
result = ABCMeta.__new__(mcs, *args, **kwargs)
if result.required_extras is not None:
result.__init__ = requires_extras(result.required_extras)(result.__init__)
result.__init__ = requires(result.required_extras)(result.__init__)
load_from_checkpoint = getattr(result, "load_from_checkpoint", None)
if load_from_checkpoint is not None:
result.load_from_checkpoint = classmethod(
requires_extras(result.required_extras)(result.load_from_checkpoint.__func__)
requires(result.required_extras)(result.load_from_checkpoint.__func__)
)
return result

@@ -282,7 +282,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check

schedulers: FlashRegistry = _SCHEDULERS_REGISTRY

required_extras: Optional[str] = None
required_extras: Optional[Union[str, List[str]]] = None

def __init__(
self,
@@ -826,7 +826,7 @@ def configure_callbacks(self):
if flash._IS_TESTING and torch.cuda.is_available():
return [BenchmarkConvergenceCI()]

@requires_extras("serve")
@requires("serve")
def run_serve_sanity_check(self):
if not self.is_servable:
raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")
@@ -846,7 +846,7 @@ def run_serve_sanity_check(self):
resp = tc.post("http://0.0.0.0:8000/predict", json=body)
print(f"Sanity check response: {resp.json()}")

@requires_extras("serve")
@requires("serve")
def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> "Composition":
if not self.is_servable:
raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.")
4 changes: 2 additions & 2 deletions flash/core/serve/component.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@

from flash.core.serve.core import ParameterContainer, Servable
from flash.core.serve.decorators import BoundMeta, UnboundMeta
from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_AVAILABLE, requires_extras
from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_AVAILABLE, requires

if _CYTOOLZ_AVAILABLE:
from cytoolz import first, isiterable, valfilter
@@ -145,7 +145,7 @@ def _validate_config_args(config: Optional[Dict[str, Union[str, int, float, byte
class FlashServeMeta(type):
"""We keep a mapping of externally used names to classes."""

@requires_extras("serve")
@requires("serve")
def __new__(cls, name, bases, namespace):
# create new instance of cls in order to apply any @expose class decorations.
_tmp_cls = super().__new__(cls, name, bases, namespace)
4 changes: 2 additions & 2 deletions flash/core/serve/core.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@

from flash.core.serve.types.base import BaseType
from flash.core.serve.utils import download_file
from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, requires_extras
from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, requires

if _PYDANTIC_AVAILABLE:
from pydantic import FilePath, HttpUrl, parse_obj_as, ValidationError
@@ -102,7 +102,7 @@ class Servable:
* How to handle ``__init__`` args not recorded in hparams of ``pl.LightningModule``
"""

@requires_extras("serve")
@requires("serve")
def __init__(
self,
*args: ServableValidArgs_T,
43 changes: 21 additions & 22 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
import operator
import types
from importlib.util import find_spec
from typing import Callable, List, Union
from typing import List, Union
from warnings import warn

from pkg_resources import DistributionNotFound
@@ -142,8 +142,6 @@ class Image(metaclass=MetaImage):
_KORNIA_AVAILABLE,
_PYSTICHE_AVAILABLE,
_SEGMENTATION_MODELS_AVAILABLE,
_ICEVISION_AVAILABLE,
_ICEDATA_AVAILABLE,
]
)
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
@@ -163,22 +161,33 @@ class Image(metaclass=MetaImage):
}


def _requires(
module_paths: Union[str, List],
module_available: Callable[[str], bool],
formatter: Callable[[List[str]], str],
):
def requires(module_paths: Union[str, List]):

if not isinstance(module_paths, list):
module_paths = [module_paths]

def decorator(func):
if not all(module_available(module_path) for module_path in module_paths):
available = True
extras = []
modules = []
for module_path in module_paths:
if module_path in _EXTRAS_AVAILABLE:
extras.append(module_path)
if not _EXTRAS_AVAILABLE[module_path]:
available = False
else:
modules.append(module_path)
if not _module_available(module_path):
available = False

if not available:
modules = [f"'{module}'" for module in modules]
modules.append(f"'lightning-flash[{','.join(extras)}]'")

@functools.wraps(func)
def wrapper(*args, **kwargs):
raise ModuleNotFoundError(
f"Required dependencies not available. Please run: pip install {formatter(module_paths)}"
f"Required dependencies not available. Please run: pip install {' '.join(modules)}"
)

return wrapper
@@ -188,18 +197,8 @@ def wrapper(*args, **kwargs):
return decorator


def requires(module_paths: Union[str, List]):
return _requires(module_paths, _module_available, lambda module_paths: " ".join(module_paths))


def requires_extras(extras: Union[str, List]):
return _requires(
extras, lambda extras: _EXTRAS_AVAILABLE[extras], lambda extras: f"'lightning-flash[{','.join(extras)}]'"
)


def example_requires(extras: Union[str, List[str]]):
return requires_extras(extras)(lambda: None)()
def example_requires(module_paths: Union[str, List[str]]):
return requires(module_paths)(lambda: None)()


def lazy_import(module_name, callback=None):
4 changes: 2 additions & 2 deletions flash/graph/data.py
Original file line number Diff line number Diff line change
@@ -16,15 +16,15 @@
from torch.utils.data import Dataset

from flash.core.data.data_source import DatasetDataSource
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires_extras
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires

if _GRAPH_AVAILABLE:
from torch_geometric.data import Data
from torch_geometric.data import Dataset as TorchGeometricDataset


class GraphDatasetDataSource(DatasetDataSource):
@requires_extras("graph")
@requires("graph")
def load_data(self, data: Dataset, dataset: Any = None) -> Dataset:
data = super().load_data(data, dataset=dataset)
if not self.predicting:
6 changes: 3 additions & 3 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LoaderDataFrameDataSource
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires, requires_extras
from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires
from flash.image.classification.transforms import default_transforms, train_default_transforms
from flash.image.data import (
image_loader,
@@ -45,7 +45,7 @@ class ImageClassificationDataFrameDataSource(LoaderDataFrameDataSource):
def __init__(self):
super().__init__(image_loader)

@requires_extras("image")
@requires("image")
def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
sample = super().load_sample(sample, dataset)
w, h = sample[DefaultDataKeys.INPUT].size # WxH
@@ -315,7 +315,7 @@ class MatplotlibVisualization(BaseVisualization):
block_viz_window: bool = True # parameter to allow user to block visualisation windows

@staticmethod
@requires_extras("image")
@requires("image")
def _to_numpy(img: Union[np.ndarray, torch.Tensor, Image.Image]) -> np.ndarray:
out: np.ndarray
if isinstance(img, np.ndarray):
6 changes: 3 additions & 3 deletions flash/image/data.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@
TensorDataSource,
)
from flash.core.data.process import Deserializer
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires_extras
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
@@ -55,7 +55,7 @@ def image_loader(filepath: str):


class ImageDeserializer(Deserializer):
@requires_extras("image")
@requires("image")
def deserialize(self, data: str) -> Dict:
encoded_with_padding = (data + "===").encode("ascii")
img = base64.b64decode(encoded_with_padding)
@@ -75,7 +75,7 @@ class ImagePathsDataSource(PathsDataSource):
def __init__(self):
super().__init__(loader=image_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS)

@requires_extras("image")
@requires("image")
def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
sample = super().load_sample(sample, dataset)
w, h = sample[DefaultDataKeys.INPUT].size # WxH
2 changes: 1 addition & 1 deletion flash/image/detection/model.py
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ class ObjectDetector(AdapterTask):

heads: FlashRegistry = OBJECT_DETECTION_HEADS

required_extras: str = "image"
required_extras: List[str] = ["image", "icevision", "effdet"]

def __init__(
self,
4 changes: 2 additions & 2 deletions flash/image/instance_segmentation/cli.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
from typing import Callable, Optional

from flash.core.utilities.flash_cli import FlashCLI
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires
from flash.image import InstanceSegmentation, InstanceSegmentationData

if _ICEDATA_AVAILABLE:
@@ -24,7 +24,7 @@
__all__ = ["instance_segmentation"]


@requires_extras("image")
@requires(["image", "icedata"])
def from_pets(
val_split: float = 0.1,
batch_size: int = 4,
2 changes: 1 addition & 1 deletion flash/image/instance_segmentation/model.py
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ class InstanceSegmentation(AdapterTask):

heads: FlashRegistry = INSTANCE_SEGMENTATION_HEADS

required_extras: str = "image"
required_extras: List[str] = ["image", "icevision"]

def __init__(
self,
4 changes: 2 additions & 2 deletions flash/image/keypoint_detection/cli.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
from typing import Callable, Optional

from flash.core.utilities.flash_cli import FlashCLI
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires
from flash.image import KeypointDetectionData, KeypointDetector

if _ICEDATA_AVAILABLE:
@@ -23,7 +23,7 @@
__all__ = ["keypoint_detection"]


@requires_extras("image")
@requires("image")
def from_biwi(
val_split: float = 0.1,
batch_size: int = 4,
2 changes: 1 addition & 1 deletion flash/image/keypoint_detection/model.py
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ class KeypointDetector(AdapterTask):

heads: FlashRegistry = KEYPOINT_DETECTION_HEADS

required_extras: str = "image"
required_extras: List[str] = ["image", "icevision"]

def __init__(
self,
3 changes: 1 addition & 2 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
@@ -42,7 +42,6 @@
Image,
lazy_import,
requires,
requires_extras,
)
from flash.image.data import ImageDeserializer, IMG_EXTENSIONS
from flash.image.segmentation.serialization import SegmentationLabels
@@ -459,7 +458,7 @@ def __init__(self, labels_map: Dict[int, Tuple[int, int, int]]):
self.labels_map: Dict[int, Tuple[int, int, int]] = labels_map

@staticmethod
@requires_extras("image")
@requires("image")
def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray:
out: np.ndarray
if isinstance(img, Image.Image):
3 changes: 1 addition & 2 deletions flash/image/segmentation/serialization.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,6 @@
_MATPLOTLIB_AVAILABLE,
lazy_import,
requires,
requires_extras,
)

Segmentation = None
@@ -56,7 +55,7 @@ class SegmentationLabels(Serializer):
visualize: Wether to visualize the image labels.
"""

@requires_extras("image")
@requires("image")
def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False):
super().__init__()
self.labels_map = labels_map
Rate limit · GitHub

Access has been restricted

You have triggered a rate limit.

Please wait a few minutes before you try again;
in some cases this may take up to an hour.