From 930c4a3fa6a474da98398bccc823eb3a1b77b1f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Faruk=20=C3=96zdemir?= Date: Thu, 28 Apr 2022 11:12:40 +0300 Subject: [PATCH] clean-deprecated-parameters (#1090) * clean-deprecated-parameters - resolve conflicts * clean-deprecated-parameters - fix tests * clean-deprecated-parameters - fix tests * clean-deprecated-parameters - resolve conflicts * clean-deprecated-parameters - refactor blocks.py * clean-deprecated-parameters - remove capture_session support * clean-deprecated-parameters - fix get_component_instance * clean-deprecated-parameters - fix tests * clean-deprecated-parameters - reformat * clean-deprecated-parameters - fix tests * clean-deprecated-parameters - fix tests * clean-deprecated-parameters - resolve conflicts * clean-deprecated-parameters - resolve conflicts * Update gradio/deprecation.py * Update gradio/deprecation.py * Update gradio/deprecation.py * removed some incorrect kwargs * formatting Co-authored-by: Abubakar Abid --- gradio/__init__.py | 3 +- gradio/blocks.py | 22 +- gradio/components.py | 90 ++--- gradio/deprecation.py | 42 ++ gradio/interface.py | 112 +----- gradio/interpretation.py | 7 +- gradio/outputs.py | 4 +- gradio/routes.py | 14 - test/test_components.py | 6 +- test/test_inputs.py | 812 --------------------------------------- test/test_interfaces.py | 9 + test/test_outputs.py | 575 --------------------------- test/test_routes.py | 6 +- 13 files changed, 113 insertions(+), 1589 deletions(-) create mode 100644 gradio/deprecation.py delete mode 100644 test/test_inputs.py delete mode 100644 test/test_outputs.py diff --git a/gradio/__init__.py b/gradio/__init__.py index bff966f16dfbe..cfe5cdb384194 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -40,9 +40,8 @@ HuggingFaceDatasetSaver, SimpleCSVLogger, ) -from gradio.interface import Interface, TabbedInterface, close_all, reset_all +from gradio.interface import Interface, TabbedInterface, close_all from gradio.mix import Parallel, Series -from gradio.routes import get_state, set_state current_pkg_version = pkg_resources.require("gradio")[0].version __version__ = current_pkg_version diff --git a/gradio/blocks.py b/gradio/blocks.py index 719657852675a..178f3d7869697 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1,15 +1,16 @@ from __future__ import annotations -import enum import getpass import os import sys import time +import warnings import webbrowser from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple from gradio import encryptor, networking, queueing, strings, utils from gradio.context import Context +from gradio.deprecation import check_deprecated_parameters if TYPE_CHECKING: # Only import for type checking (is False at runtime). from fastapi.applications import FastAPI @@ -19,11 +20,13 @@ class Block: - def __init__(self, without_rendering=False, css=None): + def __init__(self, without_rendering=False, css=None, **kwargs): + self._id = None self.css = css if css is not None else {} if without_rendering: return self.render() + check_deprecated_parameters(self.__class__.__name__, **kwargs) def render(self): """ @@ -35,7 +38,6 @@ def render(self): Context.block.children.append(self) if Context.root_block is not None: Context.root_block.blocks[self._id] = self - self.events = [] def get_block_name(self) -> str: """ @@ -96,13 +98,15 @@ def set_event_trigger( class BlockContext(Block): - def __init__(self, visible: bool = True, css: Optional[Dict[str, str]] = None): + def __init__( + self, visible: bool = True, css: Optional[Dict[str, str]] = None, **kwargs + ): """ css: Css rules to apply to block. """ self.children = [] self.visible = visible - super().__init__(css=css) + super().__init__(css=css, **kwargs) def __enter__(self): self.parent = Context.block @@ -204,6 +208,7 @@ def __init__( theme: str = "default", analytics_enabled: Optional[bool] = None, mode: str = "blocks", + **kwargs, ): # Cleanup shared parameters with Interface #TODO: is this part still necessary after Interface with Blocks? @@ -221,7 +226,7 @@ def __init__( else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True" ) - super().__init__() + super().__init__(**kwargs) self.blocks = {} self.fns: List[BlockFunction] = [] self.dependencies = [] @@ -431,9 +436,8 @@ def launch( self.width = width self.favicon_path = favicon_path - if hasattr(self, "encrypt") and self.encrypt is None: - self.encrypt = encrypt - if hasattr(self, "encrypt") and self.encrypt: + self.encrypt = encrypt + if self.encrypt: self.encryption_key = encryptor.get_key( getpass.getpass("Enter key for encryption: ") ) diff --git a/gradio/components.py b/gradio/components.py index 81070e9732f9a..c645190256946 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -37,12 +37,7 @@ def __init__( without_rendering: bool = False, **kwargs, ): - if "optional" in kwargs: - warnings.warn( - "Usage of optional is deprecated, and it has no effect", - DeprecationWarning, - ) - super().__init__(without_rendering=without_rendering, css=css) + super().__init__(without_rendering=without_rendering, css=css, **kwargs) def __str__(self): return self.__repr__() @@ -269,16 +264,6 @@ def __init__( placeholder (str): placeholder hint to provide behind textarea. label (str): component name in interface. """ - if "numeric" in kwargs: - warnings.warn( - "The 'numeric' type has been deprecated. Use the Number component instead.", - DeprecationWarning, - ) - if "type" in kwargs: - warnings.warn( - "The 'type' parameter has been deprecated. Use the Number component instead if you need it.", - DeprecationWarning, - ) default_value = str(default_value) self.lines = lines self.max_lines = max_lines @@ -1135,15 +1120,7 @@ def __init__( type (str): The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "file" produces a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly. label (str): component name in interface. """ - if "plot" in kwargs: - warnings.warn( - "The 'plot' parameter has been deprecated. Use the new Plot() component instead", - DeprecationWarning, - ) - self.type = "plot" - else: - self.type = type - + self.type = type self.default_value = ( processing_utils.encode_url_or_file_to_base64(default_value) if default_value @@ -1954,8 +1931,6 @@ def __init__( type (str): Type of value to be returned by component. "file" returns a temporary file object whose path can be retrieved by file_obj.name, "binary" returns an bytes object. label (str): component name in interface. """ - if "keep_filename" in kwargs: - warnings.warn("keep_filename is deprecated", DeprecationWarning) self.default_value = ( processing_utils.encode_url_or_file_to_base64(default_value) if default_value @@ -3144,7 +3119,6 @@ def __init__( self, default_value: str = "", *, - label: Optional[str] = None, css: Optional[Dict] = None, **kwargs, ): @@ -3154,7 +3128,7 @@ def __init__( label (str): component name css (dict): optional css parameters for the component """ - super().__init__(label=label, css=css, **kwargs) + super().__init__(css=css, **kwargs) self.default_value = default_value def get_template_context(self): @@ -3315,10 +3289,37 @@ def get_template_context(self): } +class StatusTracker(Component): + """ + Used to indicate status of a function call. Event listeners can bind to a StatusTracker with 'status=' keyword argument. + """ + + def __init__( + self, + *, + cover_container: bool = False, + css: Optional[Dict] = None, + **kwargs, + ): + """ + Parameters: + cover_container (bool): If True, will expand to cover parent container while function pending. + label (str): component name + css (dict): optional css parameters for the component + """ + super().__init__(css=css, **kwargs) + self.cover_container = cover_container + + def get_template_context(self): + return { + "cover_container": self.cover_container, + **super().get_template_context(), + } + + def component(cls_name: str): """ Returns a component or template with the given class name, or raises a ValueError if not found. - @param cls_name: lower-case string class name of a component @return cls: the component class """ @@ -3355,32 +3356,3 @@ def get_component_instance(comp: str | dict | Component): raise ValueError( f"Component must provided as a `str` or `dict` or `Component` but is {comp}" ) - - -class StatusTracker(Component): - """ - Used to indicate status of a function call. Event listeners can bind to a StatusTracker with 'status=' keyword argument. - """ - - def __init__( - self, - *, - cover_container: bool = False, - label: Optional[str] = None, - css: Optional[Dict] = None, - **kwargs, - ): - """ - Parameters: - cover_container (bool): If True, will expand to cover parent container while function pending. - label (str): component name - css (dict): optional css parameters for the component - """ - super().__init__(label=label, css=css, **kwargs) - self.cover_container = cover_container - - def get_template_context(self): - return { - "cover_container": self.cover_container, - **super().get_template_context(), - } diff --git a/gradio/deprecation.py b/gradio/deprecation.py new file mode 100644 index 0000000000000..25bfd38e41a57 --- /dev/null +++ b/gradio/deprecation.py @@ -0,0 +1,42 @@ +import warnings + + +def simple_deprecated_notice(term: str) -> str: + return f"`{term}` parameter is deprecated, and it has no effect" + + +def use_in_launch(term: str) -> str: + return f"`{term}` is deprecated in `Interface()`, please use it within `launch()` instead." + + +DEPRECATION_MESSAGE = { + "optional": simple_deprecated_notice("optional"), + "keep_filename": simple_deprecated_notice("keep_filename"), + "numeric": simple_deprecated_notice("numeric"), + "verbose": simple_deprecated_notice("verbose"), + "allow_screenshot": simple_deprecated_notice("allow_screenshot"), + "capture_session": simple_deprecated_notice("capture_session"), + "api_mode": simple_deprecated_notice("api_mode"), + "show_tips": use_in_launch("show_tips"), + "encrypt": use_in_launch("encrypt"), + "enable_queue": use_in_launch("enable_queue"), + "server_name": use_in_launch("server_name"), + "server_port": use_in_launch("server_port"), + "width": use_in_launch("width"), + "height": use_in_launch("height"), + "plot": "The 'plot' parameter has been deprecated. Use the new Plot component instead", + "type": "The 'type' parameter has been deprecated. Use the Number component instead.", +} + + +def check_deprecated_parameters(cls: str, **kwargs) -> None: + for key, value in DEPRECATION_MESSAGE.items(): + if key in kwargs: + kwargs.pop(key) + # Interestingly, using DeprecationWarning causes warning to not appear. + warnings.warn(value) + + if len(kwargs) != 0: + warnings.warn( + f"You have unused kwarg parameters in {cls}, please remove them: {kwargs}" + ) diff --git a/gradio/interface.py b/gradio/interface.py index bbdb0d958e134..2f86a5be95226 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -105,7 +105,6 @@ def __init__( fn: Callable | List[Callable], inputs: str | Component | List[str | Component] = None, outputs: str | Component | List[str | Component] = None, - verbose: bool = False, examples: Optional[List[Any] | List[List[Any]] | str] = None, cache_examples: Optional[bool] = None, examples_per_page: int = 10, @@ -113,7 +112,6 @@ def __init__( layout: str = "unaligned", show_input: bool = True, show_output: bool = True, - capture_session: Optional[bool] = None, interpretation: Optional[Callable | str] = None, num_shap: float = 2.0, theme: Optional[str] = None, @@ -123,27 +121,18 @@ def __init__( article: Optional[str] = None, thumbnail: Optional[str] = None, css: Optional[str] = None, - height=None, - width=None, - allow_screenshot: bool = False, allow_flagging: Optional[str] = None, flagging_options: List[str] = None, - encrypt=None, - show_tips=None, flagging_dir: str = "flagged", analytics_enabled: Optional[bool] = None, - server_name=None, - server_port=None, - enable_queue=None, - api_mode=None, flagging_callback: FlaggingCallback = CSVLogger(), - ): # TODO: (faruk) Let's remove depreceated parameters in the version 3.0.0 + **kwargs, + ): """ Parameters: fn (Union[Callable, List[Callable]]): the function to wrap an interface around. inputs (Union[str, InputComponent, List[Union[str, InputComponent]]]): a single Gradio input component, or list of Gradio input components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn. outputs (Union[str, OutputComponent, List[Union[str, OutputComponent]]]): a single Gradio output component, or list of Gradio output components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn. - verbose (bool): DEPRECATED. Whether to print detailed information during launch. examples (Union[List[List[Any]], str]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs. examples_per_page (int): If examples are provided, how many to display per page. cache_examples(Optional[bool]): @@ -152,7 +141,6 @@ def __init__( The default option elsewhere is False. live (bool): whether the interface should automatically reload on change. layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically. - capture_session (bool): DEPRECATED. If True, captures the default graph and session (needed for Tensorflow 1.x) interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function. num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results. Only applies if interpretation is "shap". title (str): a title for the interface; if provided, appears above the input and output components. @@ -161,18 +149,13 @@ def __init__( thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub theme (str): Theme to use - one of "default", "huggingface", "seafoam", "grass", "peach". Add "dark-" prefix, e.g. "dark-peach" for dark theme (or just "dark" for the default dark theme). css (str): custom css or path to custom css file to use with interface. - allow_screenshot (bool): DEPRECATED if False, users will not see a button to take a screenshot of the interface. allow_flagging (str): one of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every prediction will be automatically flagged. If "manual", samples are flagged when the user clicks flag button. Can be set with environmental variable GRADIO_ALLOW_FLAGGING. flagging_options (List[str]): if provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual". - encrypt (bool): DEPRECATED. If True, flagged data will be encrypted by key provided by creator at launch flagging_dir (str): what to name the dir where flagged data is stored. - show_tips (bool): DEPRECATED. if True, will occasionally show tips about new Gradio features - enable_queue (bool): DEPRECATED. if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. - api_mode (bool): DEPRECATED. If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo) - server_name (str): DEPRECATED. Name of the server to use for serving the interface - pass in launch() instead. - server_port (int): DEPRECATED. Port of the server to use for serving the interface - pass in launch() instead. """ - super().__init__(analytics_enabled=analytics_enabled, mode="interface") + super().__init__( + analytics_enabled=analytics_enabled, mode="interface", **kwargs + ) if inputs is None: inputs = [] @@ -221,50 +204,18 @@ def __init__( else: raise ValueError("Invalid value for parameter: interpretation") + self.api_mode = False self.predict = fn self.predict_durations = [[0, 0]] * len(fn) self.function_names = [func.__name__ for func in fn] self.__name__ = ", ".join(self.function_names) - if verbose: - warnings.warn( - "The `verbose` parameter in the `Interface`" - "is deprecated and has no effect." - ) - if allow_screenshot: - warnings.warn( - "The `allow_screenshot` parameter in the `Interface`" - "is deprecated and has no effect." - ) - self.live = live self.layout = layout self.show_input = show_input self.show_output = show_output self.flag_hash = random.getrandbits(32) - self.capture_session = capture_session - if capture_session is not None: - warnings.warn( - "The `capture_session` parameter in the `Interface`" - " is deprecated and may be removed in the future." - ) - try: - import tensorflow as tf - - self.session = tf.get_default_graph(), tf.keras.backend.get_session() - except (ImportError, AttributeError): - # If they are using TF >= 2.0 or don't have TF, - # just ignore this parameter. - pass - - if server_name is not None or server_port is not None: - raise DeprecationWarning( - "The `server_name` and `server_port` parameters in `Interface`" - "are deprecated. Please pass into launch() instead." - ) - - self.session = None self.title = title CLEANER = re.compile("<.*?>") @@ -324,14 +275,6 @@ def clean_html(raw_html): ) self.theme = theme - self.height = height - self.width = width - if self.height is not None or self.width is not None: - warnings.warn( - "The `height` and `width` parameters in `Interface` " - "are deprecated and should be passed into launch()." - ) - if css is not None and os.path.exists(css): with open(css) as css_file: self.css = css_file.read() @@ -388,7 +331,6 @@ def clean_html(raw_html): self.examples_per_page = examples_per_page self.simple_server = None - self.allow_screenshot = allow_screenshot # For analytics_enabled and allow_flagging: (1) first check for # parameter, (2) check for env variable, (3) default to True/"manual" @@ -434,47 +376,20 @@ def clean_html(raw_html): self.share_url = None self.local_url = None - if show_tips is not None: - warnings.warn( - "The `show_tips` parameter in the `Interface` is deprecated. Please use the `show_tips` parameter in `launch()` instead" - ) - self.requires_permissions = any( [component.requires_permissions for component in self.input_components] ) self.favicon_path = None - self.height = height - self.width = width - if self.height is not None or self.width is not None: - warnings.warn( - "The `width` and `height` parameters in the `Interface` class" - "will be deprecated. Please provide these parameters" - "in `launch()` instead" - ) - - self.encrypt = encrypt - if self.encrypt is not None: - warnings.warn( - "The `encrypt` parameter in the `Interface` class" - "will be deprecated. Please provide this parameter" - "in `launch()` instead" - ) - - if api_mode is not None: - warnings.warn("The `api_mode` parameter in the `Interface` is deprecated.") - self.api_mode = False data = { "fn": fn, "inputs": inputs, "outputs": outputs, "live": live, - "capture_session": capture_session, "ip_address": self.ip_address, "interpretation": interpretation, "allow_flagging": allow_flagging, - "allow_screenshot": allow_screenshot, "custom_css": self.css is not None, "theme": self.theme, } @@ -701,12 +616,7 @@ def run_prediction( output_component_counter = 0 for predict_fn in self.predict: - if self.capture_session and self.session is not None: # For TF 1.x - graph, sess = self.session - with graph.as_default(), sess.as_default(): - prediction = predict_fn(*processed_input) - else: - prediction = predict_fn(*processed_input) + prediction = predict_fn(*processed_input) if len(self.output_components) == len(self.predict): prediction = [prediction] @@ -851,11 +761,3 @@ def __init__( def close_all(verbose: bool = True) -> None: for io in Interface.get_instances(): io.close(verbose) - - -def reset_all() -> None: - warnings.warn( - "The `reset_all()` method has been renamed to `close_all()` " - "and will be deprecated. Please use `close_all()` instead." - ) - close_all() diff --git a/gradio/interpretation.py b/gradio/interpretation.py index 8dedfc46f3233..5f41cfb25aed2 100644 --- a/gradio/interpretation.py +++ b/gradio/interpretation.py @@ -161,12 +161,7 @@ def get_masked_prediction(binary_mask): for i, input_component in enumerate(interface.input_components) ] interpreter = interface.interpretation - if interface.capture_session and interface.session is not None: - graph, sess = interface.session - with graph.as_default(), sess.as_default(): - interpretation = interpreter(*processed_input) - else: - interpretation = interpreter(*processed_input) + interpretation = interpreter(*processed_input) if len(raw_input) == 1: interpretation = [interpretation] return interpretation, [] diff --git a/gradio/outputs.py b/gradio/outputs.py index 19fbb6326a447..cc1b1fa6f0b63 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -60,7 +60,9 @@ def __init__( "Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components", DeprecationWarning, ) - super().__init__(type=type, label=label, plot=plot) + if plot: + type = "plot" + super().__init__(type=type, label=label) class Video(C_Video): diff --git a/gradio/routes.py b/gradio/routes.py index b5b81a2ba988e..ae1e19c5f7bc3 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -312,17 +312,3 @@ def get_types(cls_set: List[Type], component: str): docset.append(doc_lines[-1].split(":")[-1]) types.append(doc_lines[-1].split(")")[0].split("(")[-1]) return docset, types - - -def get_state(): - raise DeprecationWarning( - "This function is deprecated. To create stateful demos, use the Variable" - " component. Please see the getting started for more information." - ) - - -def set_state(*args): - raise DeprecationWarning( - "This function is deprecated. To create stateful demos, use the Variable" - " component. Please see the getting started for more information." - ) diff --git a/test/test_components.py b/test/test_components.py index 8cfc0680d1255..1e538bfed485b 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -51,7 +51,7 @@ def test_component_functions(self): restored = text_input.restore_flagged(tmpdirname, to_save, None) self.assertEqual(restored, "Hello World!") - with self.assertWarns(DeprecationWarning): + with self.assertWarns(Warning): _ = gr.Textbox(type="number") self.assertEqual( @@ -489,7 +489,7 @@ def test_component_functions(self): image_input = gr.Image(invert_colors=True) self.assertIsNotNone(image_input.preprocess(img)) image_input.preprocess(img) - with self.assertWarns(DeprecationWarning): + with self.assertWarns(Warning): file_image = gr.Image(type="file") file_image.preprocess(deepcopy(media_data.BASE64_IMAGE)) file_image = gr.Image(type="filepath") @@ -526,7 +526,7 @@ def test_component_functions(self): "" ) ) - with self.assertWarns(DeprecationWarning): + with self.assertWarns(Warning): plot_output = gr.Image(plot=True) xpoints = np.array([0, 6]) diff --git a/test/test_inputs.py b/test/test_inputs.py deleted file mode 100644 index cbf3905c22c1d..0000000000000 --- a/test/test_inputs.py +++ /dev/null @@ -1,812 +0,0 @@ -import copy -import json -import os -import tempfile -import unittest -from difflib import SequenceMatcher - -import numpy as np -import pandas -import PIL - -import gradio as gr -from gradio import media_data - -os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" - -# TODO: Delete this file after confirming backwards compatibility works well. - - -class TestTextbox(unittest.TestCase): - def test_as_component(self): - text_input = gr.inputs.Textbox() - self.assertEqual(text_input.preprocess("Hello World!"), "Hello World!") - self.assertEqual(text_input.preprocess_example("Hello World!"), "Hello World!") - self.assertEqual(text_input.serialize("Hello World!", True), "Hello World!") - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = text_input.save_flagged( - tmpdirname, "text_input", "Hello World!", None - ) - self.assertEqual(to_save, "Hello World!") - restored = text_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, "Hello World!") - - with self.assertWarns(DeprecationWarning): - _ = gr.inputs.Textbox(type="number") - - self.assertEqual( - text_input.tokenize("Hello World! Gradio speaking."), - ( - ["Hello", "World!", "Gradio", "speaking."], - [ - "World! Gradio speaking.", - "Hello Gradio speaking.", - "Hello World! speaking.", - "Hello World! Gradio", - ], - None, - ), - ) - text_input.interpretation_replacement = "unknown" - self.assertEqual( - text_input.tokenize("Hello World! Gradio speaking."), - ( - ["Hello", "World!", "Gradio", "speaking."], - [ - "unknown World! Gradio speaking.", - "Hello unknown Gradio speaking.", - "Hello World! unknown speaking.", - "Hello World! Gradio unknown", - ], - None, - ), - ) - - self.assertIsInstance(text_input.generate_sample(), str) - - def test_in_interface(self): - iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox") - self.assertEqual(iface.process(["Hello"]), ["olleH"]) - iface = gr.Interface( - lambda sentence: max([len(word) for word in sentence.split()]), - gr.inputs.Textbox(), - "number", - interpretation="default", - ) - scores = iface.interpret( - ["Return the length of the longest word in this sentence"] - )[0]["interpretation"] - self.assertEqual( - scores, - [ - ("Return", 0.0), - (" ", 0), - ("the", 0.0), - (" ", 0), - ("length", 0.0), - (" ", 0), - ("of", 0.0), - (" ", 0), - ("the", 0.0), - (" ", 0), - ("longest", 0.0), - (" ", 0), - ("word", 0.0), - (" ", 0), - ("in", 0.0), - (" ", 0), - ("this", 0.0), - (" ", 0), - ("sentence", 1.0), - (" ", 0), - ], - ) - - -class TestNumber(unittest.TestCase): - def test_as_component(self): - numeric_input = gr.inputs.Number(optional=True) - self.assertEqual(numeric_input.preprocess(3), 3.0) - self.assertEqual(numeric_input.preprocess(None), None) - self.assertEqual(numeric_input.preprocess_example(3), 3) - self.assertEqual(numeric_input.serialize(3, True), 3) - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = numeric_input.save_flagged(tmpdirname, "numeric_input", 3, None) - self.assertEqual(to_save, 3) - restored = numeric_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, 3) - self.assertIsInstance(numeric_input.generate_sample(), float) - numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute") - self.assertEqual( - numeric_input.get_interpretation_neighbors(1), - ([-2.0, -1.0, 0.0, 2.0, 3.0, 4.0], {}), - ) - numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="percent") - self.assertEqual( - numeric_input.get_interpretation_neighbors(1), - ([0.97, 0.98, 0.99, 1.01, 1.02, 1.03], {}), - ) - self.assertEqual( - numeric_input.get_template_context(), - { - "default_value": None, - "name": "number", - "show_label": True, - "label": None, - "css": {}, - "interactive": None, - }, - ) - - def test_in_interface(self): - iface = gr.Interface(lambda x: x**2, "number", "textbox") - self.assertEqual(iface.process([2]), ["4.0"]) - iface = gr.Interface( - lambda x: x**2, "number", "number", interpretation="default" - ) - scores = iface.interpret([2])[0]["interpretation"] - self.assertEqual( - scores, - [ - (1.94, -0.23640000000000017), - (1.96, -0.15840000000000032), - (1.98, -0.07960000000000012), - [2, None], - (2.02, 0.08040000000000003), - (2.04, 0.16159999999999997), - (2.06, 0.24359999999999982), - ], - ) - - -class TestSlider(unittest.TestCase): - def test_as_component(self): - slider_input = gr.inputs.Slider() - self.assertEqual(slider_input.preprocess(3.0), 3.0) - self.assertEqual(slider_input.preprocess_example(3), 3) - self.assertEqual(slider_input.serialize(3, True), 3) - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = slider_input.save_flagged(tmpdirname, "slider_input", 3, None) - self.assertEqual(to_save, 3) - restored = slider_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, 3) - - self.assertIsInstance(slider_input.generate_sample(), int) - slider_input = gr.inputs.Slider( - default=15, minimum=10, maximum=20, step=1, label="Slide Your Input" - ) - self.assertEqual( - slider_input.get_template_context(), - { - "minimum": 10, - "maximum": 20, - "step": 1, - "default_value": 15, - "name": "slider", - "show_label": True, - "label": "Slide Your Input", - "css": {}, - "interactive": None, - }, - ) - - def test_in_interface(self): - iface = gr.Interface(lambda x: x**2, "slider", "textbox") - self.assertEqual(iface.process([2]), ["4"]) - iface = gr.Interface( - lambda x: x**2, "slider", "number", interpretation="default" - ) - scores = iface.interpret([2])[0]["interpretation"] - self.assertEqual( - scores, - [ - -4.0, - 200.08163265306123, - 812.3265306122449, - 1832.7346938775513, - 3261.3061224489797, - 5098.040816326531, - 7342.938775510205, - 9996.0, - ], - ) - - -class TestCheckbox(unittest.TestCase): - def test_as_component(self): - bool_input = gr.inputs.Checkbox() - self.assertEqual(bool_input.preprocess(True), True) - self.assertEqual(bool_input.preprocess_example(True), True) - self.assertEqual(bool_input.serialize(True, True), True) - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = bool_input.save_flagged(tmpdirname, "bool_input", True, None) - self.assertEqual(to_save, True) - restored = bool_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, True) - self.assertIsInstance(bool_input.generate_sample(), bool) - bool_input = gr.inputs.Checkbox(default=True, label="Check Your Input") - self.assertEqual( - bool_input.get_template_context(), - { - "default_value": True, - "name": "checkbox", - "show_label": True, - "label": "Check Your Input", - "css": {}, - "interactive": None, - }, - ) - - def test_in_interface(self): - iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "number") - self.assertEqual(iface.process([True]), [1]) - iface = gr.Interface( - lambda x: 1 if x else 0, "checkbox", "number", interpretation="default" - ) - scores = iface.interpret([False])[0]["interpretation"] - self.assertEqual(scores, (None, 1.0)) - scores = iface.interpret([True])[0]["interpretation"] - self.assertEqual(scores, (-1.0, None)) - - -class TestCheckboxGroup(unittest.TestCase): - def test_as_component(self): - checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"]) - self.assertEqual(checkboxes_input.preprocess(["a", "c"]), ["a", "c"]) - self.assertEqual(checkboxes_input.preprocess_example(["a", "c"]), ["a", "c"]) - self.assertEqual(checkboxes_input.serialize(["a", "c"], True), ["a", "c"]) - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = checkboxes_input.save_flagged( - tmpdirname, "checkboxes_input", ["a", "c"], None - ) - self.assertEqual(to_save, '["a", "c"]') - restored = checkboxes_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, ["a", "c"]) - self.assertIsInstance(checkboxes_input.generate_sample(), list) - checkboxes_input = gr.inputs.CheckboxGroup( - default=["a", "c"], choices=["a", "b", "c"], label="Check Your Inputs" - ) - self.assertEqual( - checkboxes_input.get_template_context(), - { - "choices": ["a", "b", "c"], - "default_value": ["a", "c"], - "name": "checkboxgroup", - "show_label": True, - "label": "Check Your Inputs", - "css": {}, - "interactive": None, - }, - ) - with self.assertRaises(ValueError): - wrong_type = gr.inputs.CheckboxGroup(["a"], type="unknown") - wrong_type.preprocess(0) - - def test_in_interface(self): - checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"]) - iface = gr.Interface(lambda x: "|".join(x), checkboxes_input, "textbox") - self.assertEqual(iface.process([["a", "c"]]), ["a|c"]) - self.assertEqual(iface.process([[]]), [""]) - checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"], type="index") - - -class TestRadio(unittest.TestCase): - def test_as_component(self): - radio_input = gr.inputs.Radio(["a", "b", "c"]) - self.assertEqual(radio_input.preprocess("c"), "c") - self.assertEqual(radio_input.preprocess_example("a"), "a") - self.assertEqual(radio_input.serialize("a", True), "a") - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = radio_input.save_flagged(tmpdirname, "radio_input", "a", None) - self.assertEqual(to_save, "a") - restored = radio_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, "a") - self.assertIsInstance(radio_input.generate_sample(), str) - radio_input = gr.inputs.Radio( - choices=["a", "b", "c"], default="a", label="Pick Your One Input" - ) - self.assertEqual( - radio_input.get_template_context(), - { - "choices": ["a", "b", "c"], - "default_value": "a", - "name": "radio", - "show_label": True, - "label": "Pick Your One Input", - "css": {}, - "interactive": None, - }, - ) - with self.assertRaises(ValueError): - wrong_type = gr.inputs.Radio(["a", "b"], type="unknown") - wrong_type.preprocess(0) - - def test_in_interface(self): - radio_input = gr.inputs.Radio(["a", "b", "c"]) - iface = gr.Interface(lambda x: 2 * x, radio_input, "textbox") - self.assertEqual(iface.process(["c"]), ["cc"]) - radio_input = gr.inputs.Radio(["a", "b", "c"], type="index") - iface = gr.Interface( - lambda x: 2 * x, radio_input, "number", interpretation="default" - ) - self.assertEqual(iface.process(["c"]), [4]) - scores = iface.interpret(["b"])[0]["interpretation"] - self.assertEqual(scores, [-2.0, None, 2.0]) - - -class TestDropdown(unittest.TestCase): - def test_as_component(self): - dropdown_input = gr.inputs.Dropdown(["a", "b", "c"]) - self.assertEqual(dropdown_input.preprocess("c"), "c") - self.assertEqual(dropdown_input.preprocess_example("a"), "a") - self.assertEqual(dropdown_input.serialize("a", True), "a") - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = dropdown_input.save_flagged( - tmpdirname, "dropdown_input", "a", None - ) - self.assertEqual(to_save, "a") - restored = dropdown_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, "a") - self.assertIsInstance(dropdown_input.generate_sample(), str) - dropdown_input = gr.inputs.Dropdown( - choices=["a", "b", "c"], default="a", label="Drop Your Input" - ) - self.assertEqual( - dropdown_input.get_template_context(), - { - "choices": ["a", "b", "c"], - "default_value": "a", - "name": "dropdown", - "show_label": True, - "label": "Drop Your Input", - "css": {}, - "interactive": None, - }, - ) - with self.assertRaises(ValueError): - wrong_type = gr.inputs.Dropdown(["a"], type="unknown") - wrong_type.preprocess(0) - - def test_in_interface(self): - dropdown_input = gr.inputs.Dropdown(["a", "b", "c"]) - iface = gr.Interface(lambda x: 2 * x, dropdown_input, "textbox") - self.assertEqual(iface.process(["c"]), ["cc"]) - dropdown = gr.inputs.Dropdown(["a", "b", "c"], type="index") - iface = gr.Interface( - lambda x: 2 * x, dropdown, "number", interpretation="default" - ) - self.assertEqual(iface.process(["c"]), [4]) - scores = iface.interpret(["b"])[0]["interpretation"] - self.assertEqual(scores, [-2.0, None, 2.0]) - - -class TestImage(unittest.TestCase): - def test_as_component(self): - img = media_data.BASE64_IMAGE - image_input = gr.inputs.Image() - self.assertEqual(image_input.preprocess(img).shape, (68, 61, 3)) - image_input = gr.inputs.Image(shape=(25, 25), image_mode="L") - self.assertEqual(image_input.preprocess(img).shape, (25, 25)) - image_input = gr.inputs.Image(shape=(30, 10), type="pil") - self.assertEqual(image_input.preprocess(img).size, (30, 10)) - self.assertEqual(image_input.preprocess_example("test/test_files/bus.png"), img) - self.assertEqual(image_input.serialize("test/test_files/bus.png", True), img) - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = image_input.save_flagged(tmpdirname, "image_input", img, None) - self.assertEqual("image_input/0.png", to_save) - to_save = image_input.save_flagged(tmpdirname, "image_input", img, None) - self.assertEqual("image_input/1.png", to_save) - restored = image_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, os.path.join(tmpdirname, "image_input/1.png")) - - self.assertIsInstance(image_input.generate_sample(), str) - image_input = gr.inputs.Image( - source="upload", tool="editor", type="pil", label="Upload Your Image" - ) - self.assertEqual( - image_input.get_template_context(), - { - "image_mode": "RGB", - "shape": None, - "source": "upload", - "tool": "editor", - "name": "image", - "show_label": True, - "label": "Upload Your Image", - "css": {}, - "default_value": None, - "interactive": None, - }, - ) - self.assertIsNone(image_input.preprocess(None)) - image_input = gr.inputs.Image(invert_colors=True) - self.assertIsNotNone(image_input.preprocess(img)) - image_input.preprocess(img) - with self.assertWarns(DeprecationWarning): - file_image = gr.inputs.Image(type="file") - file_image.preprocess(media_data.BASE64_IMAGE) - file_image = gr.inputs.Image(type="filepath") - self.assertIsInstance(file_image.preprocess(img), str) - with self.assertRaises(ValueError): - wrong_type = gr.inputs.Image(type="unknown") - wrong_type.preprocess(img) - with self.assertRaises(ValueError): - wrong_type = gr.inputs.Image(type="unknown") - wrong_type.serialize("test/test_files/bus.png", False) - img_pil = PIL.Image.open("test/test_files/bus.png") - image_input = gr.inputs.Image(type="numpy") - self.assertIsInstance(image_input.serialize(img_pil, False), str) - image_input = gr.inputs.Image(type="pil") - self.assertIsInstance(image_input.serialize(img_pil, False), str) - image_input = gr.inputs.Image(type="file") - with open("test/test_files/bus.png") as f: - self.assertEqual(image_input.serialize(f, False), img) - image_input.shape = (30, 10) - self.assertIsNotNone(image_input._segment_by_slic(img)) - - def test_in_interface(self): - img = media_data.BASE64_IMAGE - image_input = gr.inputs.Image() - iface = gr.Interface( - lambda x: PIL.Image.open(x).rotate(90, expand=True), - gr.inputs.Image(shape=(30, 10), type="file"), - "image", - ) - output = iface.process([img])[0] - self.assertEqual( - gr.processing_utils.decode_base64_to_image(output).size, (10, 30) - ) - iface = gr.Interface( - lambda x: np.sum(x), image_input, "number", interpretation="default" - ) - scores = iface.interpret([img])[0]["interpretation"] - self.assertEqual(scores, media_data.SUM_PIXELS_INTERPRETATION["scores"][0]) - iface = gr.Interface( - lambda x: np.sum(x), image_input, "label", interpretation="shap" - ) - scores = iface.interpret([img])[0]["interpretation"] - self.assertEqual( - len(scores[0]), - len(media_data.SUM_PIXELS_SHAP_INTERPRETATION["scores"][0][0]), - ) - image_input = gr.inputs.Image(shape=(30, 10)) - iface = gr.Interface( - lambda x: np.sum(x), image_input, "number", interpretation="default" - ) - self.assertIsNotNone(iface.interpret([img])) - - -class TestAudio(unittest.TestCase): - def test_as_component(self): - x_wav = copy.deepcopy(media_data.BASE64_AUDIO) - audio_input = gr.inputs.Audio() - output = audio_input.preprocess(x_wav) - self.assertEqual(output[0], 8000) - self.assertEqual(output[1].shape, (8046,)) - self.assertEqual( - audio_input.serialize("test/test_files/audio_sample.wav", True)["data"], - x_wav["data"], - ) - - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = audio_input.save_flagged(tmpdirname, "audio_input", x_wav, None) - self.assertEqual("audio_input/0.wav", to_save) - to_save = audio_input.save_flagged(tmpdirname, "audio_input", x_wav, None) - self.assertEqual("audio_input/1.wav", to_save) - restored = audio_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, "audio_input/1.wav") - - self.assertIsInstance(audio_input.generate_sample(), dict) - audio_input = gr.inputs.Audio(label="Upload Your Audio") - self.assertEqual( - audio_input.get_template_context(), - { - "source": "upload", - "name": "audio", - "show_label": True, - "label": "Upload Your Audio", - "css": {}, - "default_value": None, - "interactive": None, - }, - ) - self.assertIsNone(audio_input.preprocess(None)) - x_wav["is_example"] = True - x_wav["crop_min"], x_wav["crop_max"] = 1, 4 - self.assertIsNotNone(audio_input.preprocess(x_wav)) - with self.assertWarns(DeprecationWarning): - audio_input = gr.inputs.Audio(type="file") - audio_input.preprocess(x_wav) - with open("test/test_files/audio_sample.wav") as f: - audio_input.serialize(f, False) - audio_input = gr.inputs.Audio(type="filepath") - self.assertIsInstance(audio_input.preprocess(x_wav), str) - with self.assertRaises(ValueError): - audio_input = gr.inputs.Audio(type="unknown") - audio_input.preprocess(x_wav) - audio_input.serialize(x_wav, False) - audio_input = gr.inputs.Audio(type="numpy") - x_wav = gr.processing_utils.audio_from_file("test/test_files/audio_sample.wav") - self.assertIsInstance(audio_input.serialize(x_wav, False), dict) - - def test_tokenize(self): - x_wav = media_data.BASE64_AUDIO - audio_input = gr.inputs.Audio() - tokens, _, _ = audio_input.tokenize(x_wav) - self.assertEquals(len(tokens), audio_input.interpretation_segments) - x_new = audio_input.get_masked_inputs(tokens, [[1] * len(tokens)])[0] - similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio() - self.assertGreater(similarity, 0.9) - - -class TestFile(unittest.TestCase): - def test_as_component(self): - x_file = media_data.BASE64_FILE - file_input = gr.inputs.File() - output = file_input.preprocess(x_file) - self.assertIsInstance(output, tempfile._TemporaryFileWrapper) - self.assertEqual( - file_input.serialize("test/test_files/sample_file.pdf", True), - "test/test_files/sample_file.pdf", - ) - - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = file_input.save_flagged(tmpdirname, "file_input", [x_file], None) - self.assertEqual("file_input/0", to_save) - to_save = file_input.save_flagged(tmpdirname, "file_input", [x_file], None) - self.assertEqual("file_input/1", to_save) - restored = file_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, "file_input/1") - - self.assertIsInstance(file_input.generate_sample(), dict) - file_input = gr.inputs.File(label="Upload Your File") - self.assertEqual( - file_input.get_template_context(), - { - "file_count": "single", - "name": "file", - "show_label": True, - "label": "Upload Your File", - "css": {}, - "default_value": None, - "interactive": None, - }, - ) - self.assertIsNone(file_input.preprocess(None)) - x_file["is_example"] = True - self.assertIsNotNone(file_input.preprocess(x_file)) - - def test_in_interface(self): - x_file = media_data.BASE64_FILE - - def get_size_of_file(file_obj): - return os.path.getsize(file_obj.name) - - iface = gr.Interface(get_size_of_file, "file", "number") - self.assertEqual(iface.process([[x_file]]), [10558]) - - -class TestDataframe(unittest.TestCase): - def test_as_component(self): - x_data = [["Tim", 12, False], ["Jan", 24, True]] - dataframe_input = gr.inputs.Dataframe(headers=["Name", "Age", "Member"]) - output = dataframe_input.preprocess(x_data) - self.assertEqual(output["Age"][1], 24) - self.assertEqual(output["Member"][0], False) - self.assertEqual(dataframe_input.preprocess_example(x_data), x_data) - self.assertEqual(dataframe_input.serialize(x_data, True), x_data) - - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = dataframe_input.save_flagged( - tmpdirname, "dataframe_input", x_data, None - ) - self.assertEqual(json.dumps(x_data), to_save) - restored = dataframe_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(x_data, restored) - - self.assertIsInstance(dataframe_input.generate_sample(), list) - dataframe_input = gr.inputs.Dataframe( - headers=["Name", "Age", "Member"], label="Dataframe Input" - ) - self.assertEqual( - dataframe_input.get_template_context(), - { - "headers": ["Name", "Age", "Member"], - "datatype": "str", - "row_count": 3, - "col_count": 3, - "col_width": None, - "default_value": [ - ["", "", ""], - ["", "", ""], - ["", "", ""], - ], - "name": "dataframe", - "show_label": True, - "label": "Dataframe Input", - "max_rows": 20, - "max_cols": None, - "overflow_row_behaviour": "paginate", - "css": {}, - "interactive": None, - }, - ) - dataframe_input = gr.inputs.Dataframe() - output = dataframe_input.preprocess(x_data) - self.assertEqual(output[1][1], 24) - with self.assertRaises(ValueError): - wrong_type = gr.inputs.Dataframe(type="unknown") - wrong_type.preprocess(x_data) - - def test_in_interface(self): - x_data = [[1, 2, 3], [4, 5, 6]] - iface = gr.Interface(np.max, "numpy", "number") - self.assertEqual(iface.process([x_data]), [6]) - x_data = [["Tim"], ["Jon"], ["Sal"]] - - def get_last(my_list): - return my_list[-1] - - iface = gr.Interface(get_last, "list", "text") - self.assertEqual(iface.process([x_data]), ["Sal"]) - - -class TestVideo(unittest.TestCase): - def test_as_component(self): - x_video = media_data.BASE64_VIDEO - video_input = gr.inputs.Video() - output = video_input.preprocess(x_video) - self.assertIsInstance(output, str) - - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = video_input.save_flagged(tmpdirname, "video_input", x_video, None) - self.assertEqual("video_input/0.mp4", to_save) - to_save = video_input.save_flagged(tmpdirname, "video_input", x_video, None) - self.assertEqual("video_input/1.mp4", to_save) - restored = video_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored, "video_input/1.mp4") - - self.assertIsInstance(video_input.generate_sample(), dict) - video_input = gr.inputs.Video(label="Upload Your Video") - self.assertEqual( - video_input.get_template_context(), - { - "source": "upload", - "name": "video", - "show_label": True, - "label": "Upload Your Video", - "css": {}, - "default_value": None, - "interactive": None, - }, - ) - self.assertIsNone(video_input.preprocess(None)) - x_video["is_example"] = True - self.assertIsNotNone(video_input.preprocess(x_video)) - video_input = gr.inputs.Video(type="avi") - # self.assertEqual(video_input.preprocess(x_video)[-3:], "avi") - with self.assertRaises(NotImplementedError): - video_input.serialize(x_video, True) - - def test_in_interface(self): - x_video = media_data.BASE64_VIDEO - iface = gr.Interface(lambda x: x, "video", "playable_video") - self.assertEqual(iface.process([x_video])[0]["data"], x_video["data"]) - - -class TestTimeseries(unittest.TestCase): - def test_as_component(self): - timeseries_input = gr.inputs.Timeseries(x="time", y=["retail", "food", "other"]) - x_timeseries = { - "data": [[1] + [2] * len(timeseries_input.y)] * 4, - "headers": [timeseries_input.x] + timeseries_input.y, - } - output = timeseries_input.preprocess(x_timeseries) - self.assertIsInstance(output, pandas.core.frame.DataFrame) - - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = timeseries_input.save_flagged( - tmpdirname, "video_input", x_timeseries, None - ) - self.assertEqual(json.dumps(x_timeseries), to_save) - restored = timeseries_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(x_timeseries, restored) - - self.assertIsInstance(timeseries_input.generate_sample(), dict) - timeseries_input = gr.inputs.Timeseries( - x="time", y="retail", label="Upload Your Timeseries" - ) - self.assertEqual( - timeseries_input.get_template_context(), - { - "x": "time", - "y": ["retail"], - "name": "timeseries", - "show_label": True, - "label": "Upload Your Timeseries", - "css": {}, - "default_value": None, - "interactive": None, - }, - ) - self.assertIsNone(timeseries_input.preprocess(None)) - x_timeseries["range"] = (0, 1) - self.assertIsNotNone(timeseries_input.preprocess(x_timeseries)) - - def test_in_interface(self): - timeseries_input = gr.inputs.Timeseries(x="time", y=["retail", "food", "other"]) - x_timeseries = { - "data": [[1] + [2] * len(timeseries_input.y)] * 4, - "headers": [timeseries_input.x] + timeseries_input.y, - } - iface = gr.Interface(lambda x: x, timeseries_input, "dataframe") - self.assertEqual( - iface.process([x_timeseries]), - [ - { - "headers": ["time", "retail", "food", "other"], - "data": [[1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2]], - } - ], - ) - - -class TestImage3D(unittest.TestCase): - def test_as_component(self): - Image3D = media_data.BASE64_MODEL3D - Image3D_input = gr.inputs.Image3D() - output = Image3D_input.preprocess(Image3D) - self.assertIsInstance(output, str) - - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = Image3D_input.save_flagged( - tmpdirname, "Image3D_input", Image3D, None - ) - self.assertEqual("Image3D_input/0.gltf", to_save) - to_save = Image3D_input.save_flagged( - tmpdirname, "Image3D_input", Image3D, None - ) - self.assertEqual("Image3D_input/1.gltf", to_save) - restored = Image3D_input.restore_flagged(tmpdirname, to_save, None) - self.assertEqual(restored["name"], "Image3D_input/1.gltf") - - self.assertIsInstance(Image3D_input.generate_sample(), dict) - Image3D_input = gr.inputs.Image3D(label="Upload Your 3D Image Model") - self.assertEqual( - Image3D_input.get_template_context(), - { - "clearColor": None, - "name": "image3d", - "css": {}, - "interactive": None, - "show_label": True, - "label": "Upload Your 3D Image Model", - }, - ) - - self.assertIsNone(Image3D_input.preprocess(None)) - Image3D["is_example"] = True - self.assertIsNotNone(Image3D_input.preprocess(Image3D)) - Image3D_input = gr.inputs.Image3D() - with self.assertRaises(NotImplementedError): - Image3D_input.serialize(Image3D, True) - - def test_in_interface(self): - Image3D = media_data.BASE64_MODEL3D - iface = gr.Interface(lambda x: x, "model3d", "model3d") - self.assertEqual( - iface.process([Image3D])[0]["data"], - Image3D["data"].replace("@file/gltf", ""), - ) - - -class TestNames(unittest.TestCase): - # this ensures that `components.get_component_instance()` works correctly when instantiating from components - def test_no_duplicate_uncased_names(self): - subclasses = gr.components.Component.__subclasses__() - unique_subclasses_uncased = set([s.__name__.lower() for s in subclasses]) - self.assertEqual(len(subclasses), len(unique_subclasses_uncased)) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_interfaces.py b/test/test_interfaces.py index 9ada65ecab0ee..b5ff6e703cccb 100644 --- a/test/test_interfaces.py +++ b/test/test_interfaces.py @@ -168,7 +168,10 @@ def test_integration_wandb(self): wandb.log = mock.MagicMock() wandb.Html = mock.MagicMock() interface = Interface(lambda x: x, "textbox", "label") + interface.width = 500 + interface.height = 500 interface.integrate(wandb=wandb) + self.assertEqual( out.getvalue().strip(), "The WandB integration requires you to `launch(share=True)` first.", @@ -209,5 +212,11 @@ def test_tabbed_interface_config_matches_manual_tab(self): ) +class TestDeprecatedInterface(unittest.TestCase): + def test_deprecation_notice(self): + with self.assertWarns(Warning): + _ = Interface(lambda x: x, "textbox", "textbox", verbose=True) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_outputs.py b/test/test_outputs.py deleted file mode 100644 index c2a5fe1cd59ef..0000000000000 --- a/test/test_outputs.py +++ /dev/null @@ -1,575 +0,0 @@ -import json -import os -import tempfile -import unittest - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -import gradio as gr -from gradio import media_data - -os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" - - -# TODO: Delete this file after confirming backwards compatibility works well. - - -class TestTextbox(unittest.TestCase): - def test_in_interface(self): - iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox()) - self.assertEqual(iface.process(["Hello"]), ["o"]) - iface = gr.Interface(lambda x: x / 2, "number", gr.outputs.Textbox()) - self.assertEqual(iface.process([10]), ["5.0"]) - - -class TestLabel(unittest.TestCase): - def test_as_component(self): - y = "happy" - label_output = gr.outputs.Label() - label = label_output.postprocess(y) - self.assertDictEqual(label, {"label": "happy"}) - self.assertEqual(label_output.deserialize(y), y) - self.assertEqual(label_output.deserialize(label), y) - with tempfile.TemporaryDirectory() as tmpdir: - to_save = label_output.save_flagged(tmpdir, "label_output", label, None) - self.assertEqual(to_save, y) - y = {3: 0.7, 1: 0.2, 0: 0.1} - label_output = gr.outputs.Label() - label = label_output.postprocess(y) - self.assertDictEqual( - label, - { - "label": 3, - "confidences": [ - {"label": 3, "confidence": 0.7}, - {"label": 1, "confidence": 0.2}, - {"label": 0, "confidence": 0.1}, - ], - }, - ) - label_output = gr.outputs.Label(num_top_classes=2) - label = label_output.postprocess(y) - self.assertDictEqual( - label, - { - "label": 3, - "confidences": [ - {"label": 3, "confidence": 0.7}, - {"label": 1, "confidence": 0.2}, - ], - }, - ) - with self.assertRaises(ValueError): - label_output.postprocess([1, 2, 3]) - - with tempfile.TemporaryDirectory() as tmpdir: - to_save = label_output.save_flagged(tmpdir, "label_output", label, None) - self.assertEqual(to_save, '{"3": 0.7, "1": 0.2}') - self.assertEqual( - label_output.restore_flagged(tmpdir, to_save, None), - { - "label": "3", - "confidences": [ - {"label": "3", "confidence": 0.7}, - {"label": "1", "confidence": 0.2}, - ], - }, - ) - - def test_in_interface(self): - x_img = media_data.BASE64_IMAGE - - def rgb_distribution(img): - rgb_dist = np.mean(img, axis=(0, 1)) - rgb_dist /= np.sum(rgb_dist) - rgb_dist = np.round(rgb_dist, decimals=2) - return { - "red": rgb_dist[0], - "green": rgb_dist[1], - "blue": rgb_dist[2], - } - - iface = gr.Interface(rgb_distribution, "image", "label") - output = iface.process([x_img])[0] - self.assertDictEqual( - output, - { - "label": "red", - "confidences": [ - {"label": "red", "confidence": 0.44}, - {"label": "green", "confidence": 0.28}, - {"label": "blue", "confidence": 0.28}, - ], - }, - ) - - -class TestImage(unittest.TestCase): - def test_as_component(self): - y_img = gr.processing_utils.decode_base64_to_image(media_data.BASE64_IMAGE) - image_output = gr.outputs.Image() - self.assertTrue( - image_output.postprocess(y_img).startswith( - "" - ) - ) - self.assertTrue( - image_output.postprocess(np.array(y_img)).startswith( - "" - ) - ) - with self.assertWarns(DeprecationWarning): - plot_output = gr.outputs.Image(plot=True) - - xpoints = np.array([0, 6]) - ypoints = np.array([0, 250]) - fig = plt.figure() - plt.plot(xpoints, ypoints) - self.assertTrue( - plot_output.postprocess(fig).startswith("data:image/png;base64,") - ) - with self.assertRaises(ValueError): - image_output.postprocess([1, 2, 3]) - image_output = gr.outputs.Image(type="numpy") - self.assertTrue( - image_output.postprocess(y_img).startswith("data:image/png;base64,") - ) - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = image_output.save_flagged( - tmpdirname, "image_output", media_data.BASE64_IMAGE, None - ) - self.assertEqual("image_output/0.png", to_save) - to_save = image_output.save_flagged( - tmpdirname, "image_output", media_data.BASE64_IMAGE, None - ) - self.assertEqual("image_output/1.png", to_save) - - def test_in_interface(self): - def generate_noise(width, height): - return np.random.randint(0, 256, (width, height, 3)) - - iface = gr.Interface(generate_noise, ["slider", "slider"], "image") - self.assertTrue(iface.process([10, 20])[0].startswith("data:image/png;base64")) - - -class TestVideo(unittest.TestCase): - def test_as_component(self): - y_vid = "test/test_files/video_sample.mp4" - video_output = gr.outputs.Video() - self.assertTrue( - video_output.postprocess(y_vid)["data"].startswith("data:video/mp4;base64,") - ) - self.assertTrue( - video_output.deserialize(media_data.BASE64_VIDEO["data"]).endswith(".mp4") - ) - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = video_output.save_flagged( - tmpdirname, "video_output", media_data.BASE64_VIDEO, None - ) - self.assertEqual("video_output/0.mp4", to_save) - to_save = video_output.save_flagged( - tmpdirname, "video_output", media_data.BASE64_VIDEO, None - ) - self.assertEqual("video_output/1.mp4", to_save) - - -class TestHighlightedText(unittest.TestCase): - def test_as_component(self): - ht_output = gr.outputs.HighlightedText(color_map={"pos": "green", "neg": "red"}) - self.assertEqual( - ht_output.get_template_context(), - { - "color_map": {"pos": "green", "neg": "red"}, - "name": "highlightedtext", - "show_label": True, - "label": None, - "show_legend": False, - "css": {}, - "default_value": "", - "interactive": None, - }, - ) - ht = {"pos": "Hello ", "neg": "World"} - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = ht_output.save_flagged(tmpdirname, "ht_output", ht, None) - self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}') - self.assertEqual( - ht_output.restore_flagged(tmpdirname, to_save, None), - {"pos": "Hello ", "neg": "World"}, - ) - - def test_in_interface(self): - def highlight_vowels(sentence): - phrases, cur_phrase = [], "" - vowels, mode = "aeiou", None - for letter in sentence: - letter_mode = "vowel" if letter in vowels else "non" - if mode is None: - mode = letter_mode - elif mode != letter_mode: - phrases.append((cur_phrase, mode)) - cur_phrase = "" - mode = letter_mode - cur_phrase += letter - phrases.append((cur_phrase, mode)) - return phrases - - iface = gr.Interface(highlight_vowels, "text", "highlight") - self.assertListEqual( - iface.process(["Helloooo"])[0], - [("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")], - ) - - -class TestAudio(unittest.TestCase): - def test_as_component(self): - y_audio = gr.processing_utils.decode_base64_to_file( - media_data.BASE64_AUDIO["data"] - ) - audio_output = gr.outputs.Audio(type="file") - self.assertTrue( - audio_output.postprocess(y_audio.name).startswith( - "data:audio/wav;base64,UklGRuI/AABXQVZFZm10IBAAA" - ) - ) - self.assertEqual( - audio_output.get_template_context(), - { - "name": "audio", - "show_label": True, - "label": None, - "source": "upload", - "css": {}, - "default_value": None, - "interactive": None, - }, - ) - self.assertTrue( - audio_output.deserialize(media_data.BASE64_AUDIO["data"]).endswith(".wav") - ) - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = audio_output.save_flagged( - tmpdirname, "audio_output", media_data.BASE64_AUDIO, None - ) - self.assertEqual("audio_output/0.wav", to_save) - to_save = audio_output.save_flagged( - tmpdirname, "audio_output", media_data.BASE64_AUDIO, None - ) - self.assertEqual("audio_output/1.wav", to_save) - - def test_in_interface(self): - def generate_noise(duration): - return 48000, np.random.randint(-256, 256, (duration, 3)).astype(np.int32) - - iface = gr.Interface(generate_noise, "slider", "audio") - self.assertTrue(iface.process([100])[0].startswith("data:audio/wav;base64")) - - -class TestJSON(unittest.TestCase): - def test_as_component(self): - js_output = gr.outputs.JSON() - self.assertTrue( - js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"' - ) - js = {"pos": "Hello ", "neg": "World"} - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = js_output.save_flagged(tmpdirname, "js_output", js, None) - self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}') - self.assertEqual( - js_output.restore_flagged(tmpdirname, to_save, None), - {"pos": "Hello ", "neg": "World"}, - ) - - def test_in_interface(self): - def get_avg_age_per_gender(data): - return { - "M": int(data[data["gender"] == "M"].mean()), - "F": int(data[data["gender"] == "F"].mean()), - "O": int(data[data["gender"] == "O"].mean()), - } - - iface = gr.Interface( - get_avg_age_per_gender, - gr.inputs.Dataframe(headers=["gender", "age"]), - "json", - ) - y_data = [ - ["M", 30], - ["F", 20], - ["M", 40], - ["O", 20], - ["F", 30], - ] - self.assertDictEqual(iface.process([y_data])[0], {"M": 35, "F": 25, "O": 20}) - - -class TestHTML(unittest.TestCase): - def test_in_interface(self): - def bold_text(text): - return "" + text + "" - - iface = gr.Interface(bold_text, "text", "html") - self.assertEqual(iface.process(["test"])[0], "test") - - -class TestFile(unittest.TestCase): - def test_as_component(self): - def write_file(content): - with open("test.txt", "w") as f: - f.write(content) - return "test.txt" - - iface = gr.Interface(write_file, "text", "file") - self.assertDictEqual( - iface.process(["hello world"])[0], - { - "name": "test.txt", - "size": 11, - "data": "data:text/plain;base64,aGVsbG8gd29ybGQ=", - }, - ) - file_output = gr.outputs.File() - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = file_output.save_flagged( - tmpdirname, "file_output", [media_data.BASE64_FILE], None - ) - self.assertEqual("file_output/0", to_save) - to_save = file_output.save_flagged( - tmpdirname, "file_output", [media_data.BASE64_FILE], None - ) - self.assertEqual("file_output/1", to_save) - - -class TestDataframe(unittest.TestCase): - def test_as_component(self): - dataframe_output = gr.outputs.Dataframe() - output = dataframe_output.postprocess(np.zeros((2, 2))) - self.assertDictEqual(output, {"data": [[0, 0], [0, 0]]}) - output = dataframe_output.postprocess([[1, 3, 5]]) - self.assertDictEqual(output, {"data": [[1, 3, 5]]}) - output = dataframe_output.postprocess( - pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"]) - ) - self.assertDictEqual( - output, - {"headers": ["num", "prime"], "data": [[2, True], [3, True], [4, False]]}, - ) - self.assertEqual( - dataframe_output.get_template_context(), - { - "headers": None, - "max_rows": 20, - "max_cols": None, - "overflow_row_behaviour": "paginate", - "name": "dataframe", - "show_label": True, - "label": None, - "css": {}, - "datatype": "str", - "row_count": 3, - "col_count": 3, - "col_width": None, - "default_value": [ - ["", "", ""], - ["", "", ""], - ["", "", ""], - ], - "name": "dataframe", - "interactive": None, - }, - ) - with self.assertRaises(ValueError): - wrong_type = gr.outputs.Dataframe(type="unknown") - wrong_type.postprocess(0) - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = dataframe_output.save_flagged( - tmpdirname, "dataframe_output", output, None - ) - self.assertEqual( - to_save, - json.dumps( - { - "headers": ["num", "prime"], - "data": [[2, True], [3, True], [4, False]], - } - ), - ) - self.assertEqual( - dataframe_output.restore_flagged(tmpdirname, to_save, None), - { - "headers": ["num", "prime"], - "data": [[2, True], [3, True], [4, False]], - }, - ) - - def test_in_interface(self): - def check_odd(array): - return array % 2 == 0 - - iface = gr.Interface(check_odd, "numpy", "numpy") - self.assertEqual(iface.process([[2, 3, 4]])[0], {"data": [[True, False, True]]}) - - -class TestCarousel(unittest.TestCase): - def test_as_component(self): - carousel_output = gr.outputs.Carousel(["text", "image"], label="Disease") - - output = carousel_output.postprocess( - [ - ["Hello World", "test/test_files/bus.png"], - ["Bye World", "test/test_files/bus.png"], - ] - ) - self.assertEqual( - output, - [ - ["Hello World", media_data.BASE64_IMAGE], - ["Bye World", media_data.BASE64_IMAGE], - ], - ) - - carousel_output = gr.outputs.Carousel("text", label="Disease") - output = carousel_output.postprocess([["Hello World"], ["Bye World"]]) - self.assertEqual(output, [["Hello World"], ["Bye World"]]) - self.assertEqual( - carousel_output.get_template_context(), - { - "components": [ - { - "name": "textbox", - "show_label": True, - "label": None, - "default_value": "", - "lines": 1, - "max_lines": 20, - "css": {}, - "placeholder": None, - "interactive": None, - } - ], - "name": "carousel", - "show_label": True, - "label": "Disease", - "css": {}, - "interactive": None, - }, - ) - output = carousel_output.postprocess(["Hello World", "Bye World"]) - self.assertEqual(output, [["Hello World"], ["Bye World"]]) - with self.assertRaises(ValueError): - carousel_output.postprocess("Hello World!") - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = carousel_output.save_flagged( - tmpdirname, "carousel_output", output, None - ) - self.assertEqual(to_save, '[["Hello World"], ["Bye World"]]') - - def test_in_interface(self): - carousel_output = gr.outputs.Carousel(["text", "image"], label="Disease") - - def report(img): - results = [] - for i, mode in enumerate(["Red", "Green", "Blue"]): - color_filter = np.array([0, 0, 0]) - color_filter[i] = 1 - results.append([mode, img * color_filter]) - return results - - iface = gr.Interface(report, gr.inputs.Image(type="numpy"), carousel_output) - self.assertEqual( - iface.process([media_data.BASE64_IMAGE]), - [ - [ - [ - "Red", - "", - ], - [ - "Green", - "", - ], - [ - "Blue", - "", - ], - ] - ], - ) - - -class TestTimeseries(unittest.TestCase): - def test_as_component(self): - timeseries_output = gr.outputs.Timeseries(label="Disease") - self.assertEqual( - timeseries_output.get_template_context(), - { - "x": None, - "y": None, - "name": "timeseries", - "show_label": True, - "label": "Disease", - "css": {}, - "default_value": None, - "interactive": None, - }, - ) - data = {"Name": ["Tom", "nick", "krish", "jack"], "Age": [20, 21, 19, 18]} - df = pd.DataFrame(data) - self.assertEqual( - timeseries_output.postprocess(df), - { - "headers": ["Name", "Age"], - "data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]], - }, - ) - - timeseries_output = gr.outputs.Timeseries(y="Age", label="Disease") - output = timeseries_output.postprocess(df) - self.assertEqual( - output, - { - "headers": ["Name", "Age"], - "data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]], - }, - ) - - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = timeseries_output.save_flagged( - tmpdirname, "timeseries_output", output, None - ) - self.assertEqual( - to_save, - '{"headers": ["Name", "Age"], "data": [["Tom", 20], ["nick", 21], ["krish", 19], ' - '["jack", 18]]}', - ) - self.assertEqual( - timeseries_output.restore_flagged(tmpdirname, to_save, None), - { - "headers": ["Name", "Age"], - "data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]], - }, - ) - - -class TestImage3D(unittest.TestCase): - def test_as_component(self): - Image3D = "test/test_files/Fox.gltf" - Image3D_output = gr.outputs.Image3D() - self.assertTrue( - Image3D_output.postprocess(Image3D)["data"].startswith("data:;base64,") - ) - with tempfile.TemporaryDirectory() as tmpdirname: - to_save = Image3D_output.save_flagged( - tmpdirname, "Image3D_output", media_data.BASE64_MODEL3D, None - ) - self.assertEqual("Image3D_output/0.gltf", to_save) - to_save = Image3D_output.save_flagged( - tmpdirname, "Image3D_output", media_data.BASE64_MODEL3D, None - ) - self.assertEqual("Image3D_output/1.gltf", to_save) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_routes.py b/test/test_routes.py index c92609fe14ff9..5ee7d6c7c94b8 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -6,7 +6,7 @@ from fastapi.testclient import TestClient -from gradio import Interface, queueing, reset_all +from gradio import Interface, close_all, queueing os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -82,7 +82,7 @@ def test_queue_push_route_2(self): def tearDown(self) -> None: self.io.close() - reset_all() + close_all() class TestAuthenticatedRoutes(unittest.TestCase): @@ -105,7 +105,7 @@ def test_post_login(self): def tearDown(self) -> None: self.io.close() - reset_all() + close_all() if __name__ == "__main__":