Skip to content

Commit

Permalink
Merge pull request #1 from abidlabs/address-reviews
Browse files Browse the repository at this point in the history
Handle `None` -> `null` correct as well as `gr.update()`
  • Loading branch information
abidlabs authored Feb 28, 2025
2 parents 70bdfe9 + 4b5b42d commit 0ca532c
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 33 deletions.
97 changes: 64 additions & 33 deletions groovy/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,23 @@ def transform_len(self, node: ast.Call) -> str:
return f"{arg_code}.length"

# === Function Calls ===
def _handle_gradio_component_updates(self, node: ast.Call):
"""Handle Gradio component calls and return JSON representation."""
kwargs = {}
for kw in node.keywords:
if isinstance(kw.value, ast.Constant) and kw.value.value is None:
# None values should remain None in the kwargs dictionary
# so that they are converted to null, not "null" in json.dumps().
kwargs[kw.arg] = None
continue
value = self.visit(kw.value)
try:
kwargs[kw.arg] = ast.literal_eval(value)
except Exception:
kwargs[kw.arg] = value
kwargs["__type__"] = "update"
return json.dumps(kwargs)

def visit_Call(self, node: ast.Call): # noqa: N802
try:
import gradio
Expand All @@ -302,19 +319,15 @@ def visit_Call(self, node: ast.Call): # noqa: N802
# Try to resolve if this is a Gradio component.
if has_gradio:
try:
# Handle direct update() call
if node.func.id == "update":
return self._handle_gradio_component_updates(node)

component_class = getattr(gradio, node.func.id, None)
if component_class and issubclass(
component_class, gradio.blocks.Block
):
kwargs = {}
for kw in node.keywords:
value = self.visit(kw.value)
try:
kwargs[kw.arg] = ast.literal_eval(value)
except Exception:
kwargs[kw.arg] = value
kwargs["__type__"] = "update"
return json.dumps(kwargs)
return self._handle_gradio_component_updates(node)
except Exception:
pass

Expand All @@ -333,19 +346,15 @@ def visit_Call(self, node: ast.Call): # noqa: N802
"gradio",
"gr",
}:
# Handle gr.update() call
if node.func.attr == "update":
return self._handle_gradio_component_updates(node)

component_class = getattr(gradio, node.func.attr, None)
if component_class and issubclass(
component_class, gradio.blocks.Block
):
kwargs = {}
for kw in node.keywords:
value = self.visit(kw.value)
try:
kwargs[kw.arg] = ast.literal_eval(value)
except Exception:
kwargs[kw.arg] = value
kwargs["__type__"] = "update"
return json.dumps(kwargs)
return self._handle_gradio_component_updates(node)
except Exception:
pass

Expand Down Expand Up @@ -514,7 +523,7 @@ def transpile(fn: Callable, validate: bool = False) -> str:
Parameters:
fn: The Python function to transpile.
validate: If True, the function will be validated to ensure it takes no arguments & only returns gradio component property updates. This is used when Groovy is used inside Gradio.
validate: If True, the function will be validated to ensure it takes no arguments & only returns gradio component property updates. This is used when Groovy is used inside Gradio and `gradio` must be installed to use this.
Returns:
The JavaScript code as a string.
Expand Down Expand Up @@ -622,21 +631,43 @@ def _is_valid_gradio_return(node: ast.AST) -> bool:
node.func.value, ast.Name
):
if node.func.value.id in {"gr", "gradio"}:
if node.args:
return False
try:
import gradio

for kw in node.keywords:
if kw.arg == "value":
return False
return True
elif isinstance(node.func, ast.Name):
if node.args:
if node.func.attr == "update":
return True

component_class = getattr(gradio, node.func.attr, None)
if component_class and issubclass(
component_class, gradio.blocks.Block
):
if node.args:
return False
for kw in node.keywords:
if kw.arg == "value":
return False
return True
except (ImportError, AttributeError):
pass
return False
elif isinstance(node.func, ast.Name):
try:
import gradio

if node.func.id == "update":
return True

for kw in node.keywords:
if kw.arg == "value":
return False
return True
component_class = getattr(gradio, node.func.id, None)
if component_class and issubclass(component_class, gradio.blocks.Block):
if node.args:
return False
for kw in node.keywords:
if kw.arg == "value":
return False
return True
except (ImportError, AttributeError):
pass
return False

elif isinstance(node, (ast.Tuple, ast.List)):
if not node.elts:
Expand All @@ -652,7 +683,7 @@ def _is_valid_gradio_return(node: ast.AST) -> bool:
import gradio as gr

def filter_rows_by_term():
return gr.Tabs(selected=2)
return gr.update(selected=2, visible=True, info=None)

js_code = transpile(filter_rows_by_term)
js_code = transpile(filter_rows_by_term, validate=True)
print(js_code)
58 changes: 58 additions & 0 deletions tests/test_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,61 @@ def mixed_components():
transpile(mixed_components, validate=True)

assert "Function must only return Gradio component updates" in str(e.value)


def test_gradio_component_with_none_values():
def component_with_none():
return gradio.Textbox(visible=True, info=None)

expected = """function component_with_none() {
return {"visible": true, "info": null, "__type__": "update"};
}"""
assert transpile(component_with_none).strip() == expected.strip()


def test_gradio_update_function():
def update_component():
return gradio.update(visible=False, interactive=True)

expected = """function update_component() {
return {"visible": false, "interactive": true, "__type__": "update"};
}"""
assert transpile(update_component).strip() == expected.strip()


def test_update_with_none_values():
def update_with_none():
return gradio.update(info=None, label="Updated")

expected = """function update_with_none() {
return {"info": null, "label": "Updated", "__type__": "update"};
}"""
assert transpile(update_with_none).strip() == expected.strip()


def test_mixed_update_and_components():
def mixed_updates():
return gradio.update(visible=True), gradio.Textbox(placeholder="Test")

expected = """function mixed_updates() {
return [{"visible": true, "__type__": "update"}, {"placeholder": "Test", "__type__": "update"}];
}"""
assert transpile(mixed_updates).strip() == expected.strip()


def test_conditional_update():
def conditional_update(x: int):
if x > 10:
return gradio.update(visible=True)
else:
return gradio.update(visible=False)

expected = """function conditional_update(x) {
if ((x > 10)) {
return {"visible": true, "__type__": "update"};
}
else {
return {"visible": false, "__type__": "update"};
}
}"""
assert transpile(conditional_update).strip() == expected.strip()

0 comments on commit 0ca532c

Please sign in to comment.