
merge in master
Ali Abid committed Apr 19, 2022
2 parents 499f106 + f97cbbd commit 4f1947d
Showing 39 changed files with 7,241 additions and 144 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -38,6 +38,7 @@ demo/files/*.mp4
*.bak
workspace.code-workspace
*.h5
.vscode/

# log files
.pnpm-debug.log
4 changes: 1 addition & 3 deletions codecov.yml
@@ -10,9 +10,7 @@ coverage:
- "gradio/"
target: 70%
threshold: 0.1
patch:
default:
target: 50% # new contributions should have a coverage at least equal to 50%
patch: off

comment: false
codecov:
9 changes: 6 additions & 3 deletions gradio.egg-info/PKG-INFO
@@ -1,11 +1,14 @@
Metadata-Version: 1.0
Metadata-Version: 2.1
Name: gradio
Version: 2.7.0b70
Version: 2.8.13
Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq, Pete Allen, Ömer Faruk Özdemir
Author-email: team@gradio.app
License: Apache License 2.0
Description: UNKNOWN
Keywords: machine learning,visualization,reproducibility
Platform: UNKNOWN
License-File: LICENSE

UNKNOWN

13 changes: 4 additions & 9 deletions gradio/blocks.py
@@ -10,7 +10,7 @@

from gradio import encryptor, networking, queueing, strings, utils
from gradio.context import Context
from gradio.process_examples import cache_interface_examples
from gradio.routes import PredictBody

if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from fastapi.applications import FastAPI
@@ -19,7 +19,8 @@


class Block:
def __init__(self, without_rendering=False):
def __init__(self, without_rendering=False, css=None):
self.css = css
if without_rendering:
return
self.render()
@@ -103,7 +104,7 @@ def __init__(self, visible: bool = True, css: Optional[Dict[str, str]] = None):
css: Css rules to apply to block.
"""
self.children = []
self.css = css if css is not None else {}
self.css = css
self.visible = visible
super().__init__()

@@ -378,7 +379,6 @@ def launch(
height: int = 500,
width: int = 900,
encrypt: bool = False,
cache_examples: bool = False,
favicon_path: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
@@ -403,7 +403,6 @@
width (int): The width in pixels of the iframe element containing the interface (used if inline=True)
height (int): The height in pixels of the iframe element containing the interface (used if inline=True)
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
cache_examples (bool): If True, examples outputs will be processed and cached in a folder, and will be used if a user uses an example input.
favicon_path (str): If a path to a file (.png, .gif, or .ico) is provided, it will be used as the favicon for the web page.
ssl_keyfile (str): If a path to a file is provided, will use this as the private key file to create a local server running on https.
ssl_certfile (str): If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
@@ -414,7 +413,6 @@
share_url (str): Publicly accessible link to the demo (if share=True, otherwise None)
"""
self.config = self.get_config_file()
self.cache_examples = cache_examples
if (
auth
and not callable(auth)
@@ -443,9 +441,6 @@ def launch(
config = self.get_config_file()
self.config = config

if self.cache_examples:
cache_interface_examples(self)

if self.is_running:
self.server_app.launchable = self
print(
15 changes: 14 additions & 1 deletion gradio/external.py
@@ -4,7 +4,7 @@

import requests

from gradio import components
from gradio import components, utils


def get_huggingface_interface(model_name, api_key, alias):
@@ -206,6 +206,13 @@ def encode_to_base64(r: requests.Response) -> str:
"preprocess": lambda x: {"inputs": x},
"postprocess": encode_to_base64,
},
"token-classification": {
# example model: hf.co/huggingface-course/bert-finetuned-ner
"inputs": components.Textbox(label="Input"),
"outputs": components.HighlightedText(label="Output"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r, # Handled as a special case in query_huggingface_api()
},
}

if p is None or not (p in pipelines):
@@ -228,6 +235,12 @@ def query_huggingface_api(*params):
response.status_code
)
)
if (
p == "token-classification"
): # Handle as a special case since HF API only returns the named entities and we need the input as well
ner_groups = response.json()
input_string = params[0]
response = utils.format_ner_list(input_string, ner_groups)
output = pipeline["postprocess"](response)
return output

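Note on the `token-classification` branch above: the Hugging Face Inference API returns only the recognized entity groups, so the original input string has to be re-attached before `HighlightedText` can render it. A minimal sketch of what a `format_ner_list`-style helper could do follows; the `entity_group`/`start`/`end` keys match the HF NER response shape, but this is an illustration and not necessarily how `gradio/utils.py` implements it.

```python
from typing import Any, Dict, List, Optional, Tuple


def format_ner_list(
    input_string: str, ner_groups: List[Dict[str, Any]]
) -> List[Tuple[str, Optional[str]]]:
    """Interleave unlabeled spans of the input with labeled entity spans so the
    result can feed a HighlightedText output (a list of (text, label) pairs).
    Assumes each group carries "entity_group", "start", and "end" keys, as the
    HF Inference API returns for NER pipelines."""
    output: List[Tuple[str, Optional[str]]] = []
    cursor = 0
    for group in ner_groups:
        start, end = group["start"], group["end"]
        if start > cursor:  # text between entities carries no label
            output.append((input_string[cursor:start], None))
        output.append((input_string[start:end], group["entity_group"]))
        cursor = end
    if cursor < len(input_string):  # trailing text after the last entity
        output.append((input_string[cursor:], None))
    return output


# Example: format_ner_list("My name is Sarah", [{"entity_group": "PER", "start": 11, "end": 16}])
# -> [("My name is ", None), ("Sarah", "PER")]
```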
21 changes: 10 additions & 11 deletions gradio/flagging.py
@@ -9,7 +9,7 @@
from typing import TYPE_CHECKING, Any, List, Optional

import gradio as gr
from gradio import encryptor
from gradio import encryptor, utils

if TYPE_CHECKING:
from gradio.components import Component
@@ -87,7 +87,7 @@ def flag(
)

with open(log_filepath, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
writer.writerow(csv_data)

with open(log_filepath, "r") as csvfile:
@@ -153,7 +153,7 @@ def replace_flag_at_index(file_content):
flag_col_index = header.index("flag")
content[flag_index][flag_col_index] = flag_option
output = io.StringIO()
writer = csv.writer(output)
writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
writer.writerows(content)
return output.getvalue()

@@ -169,7 +169,7 @@ def replace_flag_at_index(file_content):
if flag_index is not None:
file_content = replace_flag_at_index(file_content)
output.write(file_content)
writer = csv.writer(output)
writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
if flag_index is None:
if is_new:
writer.writerow(headers)
@@ -181,7 +181,9 @@ def replace_flag_at_index(file_content):
else:
if flag_index is None:
with open(log_filepath, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
writer = csv.writer(
csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'"
)
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
@@ -291,6 +293,7 @@ def flag(
headers = []

for component, sample in zip(self.components, flag_data):
headers.append(component.label)
headers.append(component.label)
infos["flagged"]["features"][component.label] = {
"dtype": "string",
@@ -316,12 +319,8 @@ def flag(
# Generate the row corresponding to the flagged sample
csv_data = []
for component, sample in zip(self.components, flag_data):
filepath = (
component.save_flagged(
self.dataset_dir, component.label, sample, None
)
if sample is not None
else ""
filepath = component.save_flagged(
self.dataset_dir, component.label, sample, None
)
csv_data.append(filepath)
if isinstance(component, tuple(file_preview_types)):
2 changes: 1 addition & 1 deletion gradio/process_examples.py
@@ -54,7 +54,7 @@ def cache_interface_examples(interface: Interface) -> None:
def load_from_cache(interface: Interface, example_id: int) -> List[Any]:
"""Loads a particular cached example for the interface."""
with open(CACHE_FILE) as cache:
examples = list(csv.reader(cache))
examples = list(csv.reader(cache, quotechar="'"))
example = examples[example_id + 1] # +1 to adjust for header
output = []
for component, cell in zip(interface.output_components, example):
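The flagging writers above now use `csv.QUOTE_NONNUMERIC` with `quotechar="'"`, and this reader in `process_examples.py` passes the same `quotechar`, so the two stay in agreement for fields that contain commas or JSON. A small round-trip sketch under those settings (a standalone illustration, not gradio's actual code path):

```python
import csv
import io

# Write a row the way the updated flagging callbacks do: non-numeric fields
# are wrapped in single quotes, so embedded commas and double quotes survive.
buffer = io.StringIO()
writer = csv.writer(buffer, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
writer.writerow(['{"label": "cat", "confidences": [0.9, 0.1]}', "a, b, c", 3])

# Read it back the way load_from_cache does, with the matching quotechar.
buffer.seek(0)
row = next(csv.reader(buffer, quotechar="'"))
print(row)  # ['{"label": "cat", "confidences": [0.9, 0.1]}', 'a, b, c', '3']
```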
7 changes: 5 additions & 2 deletions gradio/queueing.py
@@ -7,6 +7,8 @@

import requests

from gradio.routes import QueuePushBody

DB_FILE = "gradio_queue.db"


@@ -106,8 +108,9 @@ def pop() -> Tuple[int, str, Dict, str]:
return result[0], result[1], json.loads(result[2]), result[3]


def push(input_data: Dict, action: str) -> Tuple[str, int]:
input_data = json.dumps(input_data)
def push(body: QueuePushBody) -> Tuple[str, int]:
action = body.action
input_data = json.dumps({"data": body.data})
hash = generate_hash()
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
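`queueing.push` now receives the request body as a typed object instead of a raw dict plus a separate `action` argument; the `QueuePushBody` model it unpacks is added in `gradio/routes.py` (see the routes diff below). A minimal sketch of the new call shape, assuming that model definition:

```python
from typing import Any

from pydantic import BaseModel


class QueuePushBody(BaseModel):  # mirrors the model added in gradio/routes.py
    action: str
    data: Any


# FastAPI validates the incoming JSON into this model, so push() can read typed
# attributes (body.action, body.data) instead of indexing into an untrusted dict.
body = QueuePushBody(action="predict", data=["an input value"])
print(body.action, body.data)  # predict ['an input value']
```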
87 changes: 46 additions & 41 deletions gradio/routes.py
@@ -9,7 +9,8 @@
import secrets
import traceback
import urllib
from typing import Any, List, Optional, Type
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type

import orjson
import pkg_resources
@@ -21,6 +22,7 @@
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from jinja2.exceptions import TemplateNotFound
from pydantic import BaseModel
from starlette.responses import RedirectResponse

from gradio import encryptor, queueing, utils
@@ -49,6 +51,43 @@ def render(self, content: Any) -> bytes:
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)


###########
# Data Models
###########


class PredictBody(BaseModel):
session_hash: Optional[str]
example_id: Optional[int]
data: List[Any]
state: Optional[Any]
fn_index: Optional[int]


class FlagData(BaseModel):
input_data: List[Any]
output_data: List[Any]
flag_option: Optional[str]
flag_index: Optional[int]


class FlagBody(BaseModel):
data: FlagData


class InterpretBody(BaseModel):
data: List[Any]


class QueueStatusBody(BaseModel):
hash: str


class QueuePushBody(BaseModel):
action: str
data: Any


###########
# Auth
###########
@@ -166,7 +205,8 @@ def file(path):
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
)
else:
return FileResponse(safe_join(app.cwd, path))
if Path(app.cwd).resolve() in Path(path).resolve().parents:
return FileResponse(Path(path).resolve())

@app.get("/api", response_class=HTMLResponse) # Needed for Spaces
@app.get("/api/", response_class=HTMLResponse)
@@ -229,49 +269,14 @@ async def predict(request: Request, username: str = Depends(get_current_user)):
raise error
return output

@app.post("/api/flag/", dependencies=[Depends(login_check)])
async def flag(request: Request, username: str = Depends(get_current_user)):
if app.blocks.analytics_enabled:
await utils.log_feature_analytics(app.blocks.ip_address, "flag")
body = await request.json()
data = body["data"]
await run_in_threadpool(
app.blocks.flagging_callback.flag,
app.blocks,
data["input_data"],
data["output_data"],
flag_option=data.get("flag_option"),
flag_index=data.get("flag_index"),
username=username,
)
return {"success": True}

@app.post("/api/interpret/", dependencies=[Depends(login_check)])
async def interpret(request: Request):
if app.blocks.analytics_enabled:
await utils.log_feature_analytics(app.blocks.ip_address, "interpret")
body = await request.json()
raw_input = body["data"]
interpretation_scores, alternative_outputs = await run_in_threadpool(
app.blocks.interpret, raw_input
)
return {
"interpretation_scores": interpretation_scores,
"alternative_outputs": alternative_outputs,
}

@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
async def queue_push(request: Request):
body = await request.json()
action = body["action"]
job_hash, queue_position = queueing.push(body, action)
async def queue_push(body: QueuePushBody):
job_hash, queue_position = queueing.push(body)
return {"hash": job_hash, "queue_position": queue_position}

@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
async def queue_status(request: Request):
body = await request.json()
hash = body["hash"]
status, data = queueing.get_status(hash)
async def queue_status(body: QueueStatusBody):
status, data = queueing.get_status(body.hash)
return {"status": status, "data": data}

return app
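The `/file` route change above swaps `safe_join` for an explicit containment check: the requested path is resolved and served only if the app's working directory is among its parents, which rejects `../` escapes and absolute paths outside the tree. A standalone sketch of that check (the paths are made up for illustration):

```python
from pathlib import Path


def is_within(cwd: str, requested: str) -> bool:
    """Return True only when `requested` resolves to a location inside `cwd`.

    Resolving both paths collapses ".." segments (and symlinks, for paths that
    exist), so a request like "static/../../etc/passwd" fails the parents check.
    """
    return Path(cwd).resolve() in Path(requested).resolve().parents


print(is_within("/srv/app", "/srv/app/static/logo.png"))  # True
print(is_within("/srv/app", "/srv/app/../etc/passwd"))    # False
```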
