Skip to content

Commit

Permalink
PR review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
aliabd committed Jul 29, 2020
1 parent 05099e7 commit 96543e3
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 30 deletions.
16 changes: 8 additions & 8 deletions gradio/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def get_shortcut_implementations(cls):
"""
return {}

def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
"""
return msg
return data

class Textbox(AbstractInput):
"""
Expand Down Expand Up @@ -295,11 +295,11 @@ def process_example(self, example):
else:
return example

def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
Expand Down Expand Up @@ -356,11 +356,11 @@ def preprocess(self, inp):
def process_example(self, example):
return preprocessing_utils.convert_file_to_base64(example)

def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
Expand Down Expand Up @@ -403,11 +403,11 @@ def preprocess(self, inp):
im, (self.image_width, self.image_height))
return np.array(im)

def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
Expand Down
23 changes: 11 additions & 12 deletions gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@
except requests.ConnectionError:
ip_address = "No internet connection"

FLAGGING_DIRECTORY = 'flagged/'


class Interface:
"""
Interfaces are created with Gradio using the `gradio.Interface()` function.
Expand All @@ -43,7 +40,8 @@ def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=N
live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True):
allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged"):
"""
Parameters:
fn (Callable): the function to wrap an interface around.
Expand Down Expand Up @@ -104,6 +102,7 @@ def get_output_instance(iface):
self.simple_server = None
self.allow_screenshot = allow_screenshot
self.allow_flagging = allow_flagging
self.flagging_dir = flagging_dir
Interface.instances.add(self)

data = {'fn': fn,
Expand All @@ -125,15 +124,15 @@ def get_output_instance(iface):

if self.allow_flagging:
if self.title is not None:
dir_name = "_".join(self.title.split(" ")) + "_1"
dir_name = "_".join(self.title.split(" "))
else:
dir_name = "_".join([fn.__name__ for fn in self.predict]) + \
"_1"
i = 1
while os.path.exists(FLAGGING_DIRECTORY + dir_name):
i += 1
dir_name = dir_name[:-2] + "_" + str(i)
self.flagging_dir = FLAGGING_DIRECTORY + dir_name
dir_name = "_".join([fn.__name__ for fn in self.predict])
index = 1
while os.path.exists(self.flagging_dir + "/" + dir_name +
"_{}".format(index)):
index += 1
self.flagging_dir = self.flagging_dir + "/" + dir_name + \
"_{}".format(index)

try:
requests.post(analytics_url + 'gradio-initiated-analytics/',
Expand Down
8 changes: 4 additions & 4 deletions gradio/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"

FLAGGING_FILENAME = 'flagged.txt'
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'

Expand Down Expand Up @@ -187,15 +186,16 @@ def do_POST(self):
msg = json.loads(data_string)
os.makedirs(interface.flagging_dir, exist_ok=True)
output = {'inputs': [interface.input_interfaces[
i].rebuild_flagged(
i].rebuild(
interface.flagging_dir, msg['data']['input_data']) for i
in range(len(interface.input_interfaces))],
'outputs': [interface.output_interfaces[
i].rebuild_flagged(
i].rebuild(
interface.flagging_dir, msg['data']['output_data']) for i
in range(len(interface.output_interfaces))]}

with open(os.path.join(interface.flagging_dir, FLAGGING_FILENAME), 'a+') as f:
with open("{}/log.txt".format(interface.flagging_dir),
'a+') as f:
f.write(json.dumps(output))
f.write("\n")

Expand Down
12 changes: 6 additions & 6 deletions gradio/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ def get_shortcut_implementations(cls):
"""
return {}

def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
"""
return msg
return data


class Textbox(AbstractOutput):
Expand Down Expand Up @@ -136,11 +136,11 @@ def get_shortcut_implementations(cls):
"label": {},
}

def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method for label
"""
return json.loads(msg)
return json.loads(data)

class Image(AbstractOutput):
'''
Expand Down Expand Up @@ -180,11 +180,11 @@ def postprocess(self, prediction):
raise ValueError(
"The `Image` output interface (with plt=False) expects a numpy array.")

def rebuild_flagged(self, dir, msg):
def rebuild(self, dir, data):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
im = preprocessing_utils.decode_base64_to_image(data)
timestamp = datetime.datetime.now()
filename = 'output_{}.png'.format(timestamp.
strftime("%Y-%m-%d-%H-%M-%S"))
Expand Down

0 comments on commit 96543e3

Please sign in to comment.