Skip to content

Commit

Permalink
Merge pull request gradio-app#37 from gradio-app/aliabd/flagging
Browse files Browse the repository at this point in the history
Flagging
  • Loading branch information
Abubakar Abid authored Jul 30, 2020
2 parents 82c2b5a + 96543e3 commit 480591a
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 34 deletions.
35 changes: 35 additions & 0 deletions gradio/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ def get_shortcut_implementations(cls):
"""
return {}

def rebuild(self, dir, data):
"""
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
"""
return data

class Textbox(AbstractInput):
"""
Expand Down Expand Up @@ -290,6 +295,16 @@ def process_example(self, example):
else:
return example

def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename


class Sketchpad(AbstractInput):
"""
Expand Down Expand Up @@ -341,6 +356,16 @@ def preprocess(self, inp):
def process_example(self, example):
return preprocessing_utils.convert_file_to_base64(example)

def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename


class Webcam(AbstractInput):
"""
Expand Down Expand Up @@ -378,6 +403,16 @@ def preprocess(self, inp):
im, (self.image_width, self.image_height))
return np.array(im)

def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename


class Microphone(AbstractInput):
"""
Expand Down
23 changes: 19 additions & 4 deletions gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import sys
import weakref
import analytics

import os

PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
Expand All @@ -30,7 +30,6 @@
except requests.ConnectionError:
ip_address = "No internet connection"


class Interface:
"""
Interfaces are created with Gradio using the `gradio.Interface()` function.
Expand All @@ -41,7 +40,8 @@ def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=N
live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True):
allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged"):
"""
Parameters:
fn (Callable): the function to wrap an interface around.
Expand Down Expand Up @@ -101,6 +101,8 @@ def get_output_instance(iface):
self.server_port = server_port
self.simple_server = None
self.allow_screenshot = allow_screenshot
self.allow_flagging = allow_flagging
self.flagging_dir = flagging_dir
Interface.instances.add(self)

data = {'fn': fn,
Expand All @@ -120,6 +122,18 @@ def get_output_instance(iface):
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
pass

if self.allow_flagging:
if self.title is not None:
dir_name = "_".join(self.title.split(" "))
else:
dir_name = "_".join([fn.__name__ for fn in self.predict])
index = 1
while os.path.exists(self.flagging_dir + "/" + dir_name +
"_{}".format(index)):
index += 1
self.flagging_dir = self.flagging_dir + "/" + dir_name + \
"_{}".format(index)

try:
requests.post(analytics_url + 'gradio-initiated-analytics/',
data=data)
Expand All @@ -141,7 +155,8 @@ def get_config_file(self):
"title": self.title,
"description": self.description,
"thumbnail": self.thumbnail,
"allow_screenshot": self.allow_screenshot
"allow_screenshot": self.allow_screenshot,
"allow_flagging": self.allow_flagging
}
try:
param_names = inspect.getfullargspec(self.predict[0])[0]
Expand Down
31 changes: 8 additions & 23 deletions gradio/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import sys
import analytics


INITIAL_PORT_VALUE = int(os.getenv(
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
TRY_NUM_PORTS = int(os.getenv(
Expand All @@ -36,8 +35,6 @@
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"

FLAGGING_DIRECTORY = 'static/flagged/'
FLAGGING_FILENAME = 'data.txt'
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'

Expand Down Expand Up @@ -175,16 +172,6 @@ def do_POST(self):
if interface.saliency is not None:
saliency = interface.saliency(raw_input, prediction)
output['saliency'] = saliency.tolist()
# if interface.always_flag:
# msg = json.loads(data_string)
# flag_dir = os.path.join(FLAGGING_DIRECTORY, str(interface.hash))
# os.makedirs(flag_dir, exist_ok=True)
# output_flag = {'input': interface.input_interface.rebuild_flagged(flag_dir, msg['data']),
# 'output': interface.output_interface.rebuild_flagged(flag_dir, processed_output),
# }
# with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
# f.write(json.dumps(output_flag))
# f.write("\n")

self.wfile.write(json.dumps(output).encode())

Expand All @@ -197,20 +184,18 @@ def do_POST(self):
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
flag_dir = os.path.join(FLAGGING_DIRECTORY,
str(interface.flag_hash))
os.makedirs(flag_dir, exist_ok=True)
os.makedirs(interface.flagging_dir, exist_ok=True)
output = {'inputs': [interface.input_interfaces[
i].rebuild_flagged(
flag_dir, msg['data']['input_data']) for i
i].rebuild(
interface.flagging_dir, msg['data']['input_data']) for i
in range(len(interface.input_interfaces))],
'outputs': [interface.output_interfaces[
i].rebuild_flagged(
flag_dir, msg['data']['output_data']) for i
in range(len(interface.output_interfaces))],
'message': msg['data']['message']}
i].rebuild(
interface.flagging_dir, msg['data']['output_data']) for i
in range(len(interface.output_interfaces))]}

with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
with open("{}/log.txt".format(interface.flagging_dir),
'a+') as f:
f.write(json.dumps(output))
f.write("\n")

Expand Down
15 changes: 13 additions & 2 deletions gradio/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def get_shortcut_implementations(cls):
"""
return {}

def rebuild(self, dir, data):
"""
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
"""
return data


class Textbox(AbstractOutput):
'''
Expand Down Expand Up @@ -130,6 +136,11 @@ def get_shortcut_implementations(cls):
"label": {},
}

def rebuild(self, dir, data):
"""
Default rebuild method for label
"""
return json.loads(data)

class Image(AbstractOutput):
'''
Expand Down Expand Up @@ -169,11 +180,11 @@ def postprocess(self, prediction):
raise ValueError(
"The `Image` output interface (with plt=False) expects a numpy array.")

def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = 'output_{}.png'.format(timestamp.
strftime("%Y-%m-%d-%H-%M-%S"))
Expand Down
6 changes: 6 additions & 0 deletions gradio/static/css/gradio.css
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ input.submit {
input.submit:hover {
background-color: #f39c12;
}
.flag {
visibility: hidden;
}
.flagged {
background-color: pink !important;
}
/* label:hover {
background-color: lightgray;
} */
Expand Down
5 changes: 2 additions & 3 deletions gradio/static/js/all_io.js
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,11 @@ var io_master_template = {
this.target.find(".output_interfaces").css("opacity", 1);
}
},
flag: function(message) {
flag: function() {
var post_data = {
'data': {
'input_data' : toStringIfObject(this.last_input) ,
'output_data' : toStringIfObject(this.last_output),
'message' : message
'output_data' : toStringIfObject(this.last_output)
}
}
$.ajax({type: "POST",
Expand Down
19 changes: 17 additions & 2 deletions gradio/static/js/gradio.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function gradio(config, fn, target) {
<div class="screenshot_logo">
<img src="static/img/logo_inline.png">
</div>
</div>
<input class="flag panel_button" type="button" value="FLAG"/>
</div>
</div>`);
let io_master = Object.create(io_master_template);
Expand Down Expand Up @@ -117,6 +117,7 @@ function gradio(config, fn, target) {
output_interface.clear();
}
target.find(".flag").removeClass("flagged");
target.find(".flag").val("FLAG");
target.find(".flag_message").empty();
target.find(".loading").addClass("invisible");
target.find(".loading_time").text("");
Expand All @@ -127,6 +128,9 @@ function gradio(config, fn, target) {
if (config["allow_screenshot"]) {
target.find(".screenshot").css("visibility", "visible");
}
if(config["allow_flagging"]){
target.find(".flag").css("visibility", "visible");
}
target.find(".screenshot").click(function() {
$(".screenshot").hide();
$(".screenshot_logo").show();
Expand All @@ -146,11 +150,22 @@ function gradio(config, fn, target) {
target.find(".submit").click(function() {
io_master.gather();
target.find(".flag").removeClass("flagged");
target.find(".flag").val("FLAG");
})
}
if (!config.show_input) {
target.find(".input_panel").hide();
}
}

target.find(".flag").click(function() {
if (io_master.last_output) {
target.find(".flag").addClass("flagged");
target.find(".flag").val("FLAGGED");
io_master.flag();

// io_master.flag($(".flag_message").val());
}
})

return io_master;
}
Expand Down

0 comments on commit 480591a

Please sign in to comment.