Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali Abid committed Apr 19, 2022
1 parent 3a0d78f commit ba9eebf
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 16 deletions.
6 changes: 3 additions & 3 deletions test/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def test_as_component(self):
)
self.assertEqual("Image3D_input/1", to_save)
restored = Image3D_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, "Image3D_input/1")
self.assertEqual(restored["name"], "Image3D_input/1")

self.assertIsInstance(Image3D_input.generate_sample(), dict)
Image3D_input = gr.inputs.Image3D(label="Upload Your 3D Image Model")
Expand All @@ -778,9 +778,9 @@ def test_as_component(self):

def test_in_interface(self):
Image3D = media_data.BASE64_MODEL3D
iface = gr.Interface(lambda x: x, "Image3D", "Image3D")
iface = gr.Interface(lambda x: x, "Model3D", "Model3D")
self.assertEqual(
iface.process([Image3D])[0][0]["data"],
iface.process([Image3D])[0]["data"],
Image3D["data"].replace("@file/gltf", ""),
)

Expand Down
13 changes: 2 additions & 11 deletions test/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,23 +555,14 @@ def test_as_component(self):
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = Image3D_output.save_flagged(
tmpdirname, "Image3D_output", gr.test_data.BASE64_IMAGE3D, None
tmpdirname, "Image3D_output", media_data.BASE64_MODEL3D, None
)
self.assertEqual("Image3D_output/0.gltf", to_save)
to_save = Image3D_output.save_flagged(
tmpdirname, "Image3D_output", gr.test_data.BASE64_IMAGE3D, None
tmpdirname, "Image3D_output", media_data.BASE64_MODEL3D, None
)
self.assertEqual("Image3D_output/1.gltf", to_save)


class TestNames(unittest.TestCase):
def test_no_duplicate_uncased_names(
self,
): # this ensures that get_input_instance() works correctly when instantiating from components
subclasses = gr.outputs.OutputComponent.__subclasses__()
unique_subclasses_uncased = set([s.__name__.lower() for s in subclasses])
self.assertEqual(len(subclasses), len(unique_subclasses_uncased))


if __name__ == "__main__":
unittest.main()
8 changes: 6 additions & 2 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,15 @@ def predict(input, history=""):
io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": ["test", None]})
response = client.post(
"/api/predict/", json={"data": ["test", None], "fn_index": 0}
)
output = dict(response.json())
print("output", output)
self.assertEqual(output["data"], ["test", None])
response = client.post("/api/predict/", json={"data": ["test", None]})
response = client.post(
"/api/predict/", json={"data": ["test", None], "fn_index": 0}
)
output = dict(response.json())
self.assertEqual(output["data"], ["testtest", None])

Expand Down

0 comments on commit ba9eebf

Please sign in to comment.