Skip to content

Commit

Permalink
Live audio streaming output (#5077)
Browse files Browse the repository at this point in the history
* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
aliabid94 and gradio-pr-bot authored Aug 8, 2023
1 parent cd1353f commit 667875b
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 11 deletions.
40 changes: 40 additions & 0 deletions .changeset/famous-rice-taste.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
---
"@gradio/upload": patch
"gradio": patch
---

fix:Live audio streaming output

highlight:

#### Now supports loading streamed outputs

Allows users to use generators to stream audio out, yielding consecutive chunks of audio. Requires `streaming=True` to be set on the output audio.

```python
import gradio as gr
from pydub import AudioSegment

def stream_audio(audio_file):
    """Yield the MP3 at ``audio_file`` back as consecutive 3-second chunk files.

    Each chunk is exported as its own MP3 and its path is yielded, so an output
    ``gr.Audio(streaming=True)`` can play the audio while later chunks are still
    being produced.

    Parameters:
        audio_file: path to the MP3 file to split and stream.
    Yields:
        Path of the next exported MP3 chunk, in playback order.
    """
    import os
    import tempfile

    audio = AudioSegment.from_mp3(audio_file)
    chunk_size = 3000  # pydub measures length and slices in milliseconds

    # A fresh directory per call: a fixed "/tmp/{i}.mp3" does not exist on every
    # platform and lets two concurrent calls overwrite each other's chunks.
    out_dir = tempfile.mkdtemp(prefix="stream_audio_")

    for i, start in enumerate(range(0, len(audio), chunk_size), start=1):
        chunk = audio[start:start + chunk_size]
        if chunk:
            file = os.path.join(out_dir, f"{i}.mp3")
            chunk.export(file, format="mp3")
            yield file

# Wire the generator into an Interface: an uploaded audio file (passed to the
# function as a file path) goes in, and the chunks it yields come out.
demo = gr.Interface(
    fn=stream_audio,
    inputs=gr.Audio(type="filepath", label="Audio file to stream"),
    # `streaming=True` must be set on the output audio for chunks to be streamed.
    outputs=gr.Audio(autoplay=True, streaming=True),
)

demo.queue().launch()
```

From the backend, streamed outputs are served from the `/stream/` endpoint instead of the `/file/` endpoint. This endpoint is currently used only to serve streaming audio output. The output JSON will have `is_stream`: `true`, instead of `is_file`: `true`, in the file data object.
1 change: 1 addition & 0 deletions demo/stream_audio_out/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stream_audio_out"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from pydub import AudioSegment\n", "\n", "def stream_audio(audio_file):\n", " audio = AudioSegment.from_mp3(audio_file)\n", " i = 0\n", " chunk_size = 3000\n", " \n", " while chunk_size*i < len(audio):\n", " chunk = audio[chunk_size*i:chunk_size*(i+1)]\n", " i += 1\n", " if chunk:\n", " file = f\"/tmp/{i}.mp3\"\n", " chunk.export(file, format=\"mp3\") \n", " yield file\n", " \n", "demo = gr.Interface(\n", " fn=stream_audio,\n", " inputs=gr.Audio(type=\"filepath\", label=\"Audio file to stream\"),\n", " outputs=gr.Audio(autoplay=True, streaming=True),\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
24 changes: 24 additions & 0 deletions demo/stream_audio_out/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import gradio as gr
from pydub import AudioSegment

def stream_audio(audio_file):
    """Yield the MP3 at ``audio_file`` back as consecutive 3-second chunk files.

    Each chunk is exported as its own MP3 and its path is yielded, so an output
    ``gr.Audio(streaming=True)`` can play the audio while later chunks are still
    being produced.

    Parameters:
        audio_file: path to the MP3 file to split and stream.
    Yields:
        Path of the next exported MP3 chunk, in playback order.
    """
    import os
    import tempfile

    audio = AudioSegment.from_mp3(audio_file)
    chunk_size = 3000  # pydub measures length and slices in milliseconds

    # A fresh directory per call: a fixed "/tmp/{i}.mp3" does not exist on every
    # platform and lets two concurrent calls overwrite each other's chunks.
    out_dir = tempfile.mkdtemp(prefix="stream_audio_")

    for i, start in enumerate(range(0, len(audio), chunk_size), start=1):
        chunk = audio[start:start + chunk_size]
        if chunk:
            file = os.path.join(out_dir, f"{i}.mp3")
            chunk.export(file, format="mp3")
            yield file

# Wire the generator into an Interface: an uploaded audio file (passed to the
# function as a file path) goes in, and the chunks it yields come out.
demo = gr.Interface(
    fn=stream_audio,
    inputs=gr.Audio(type="filepath", label="Audio file to stream"),
    # `streaming=True` must be set on the output audio for chunks to be streamed.
    outputs=gr.Audio(autoplay=True, streaming=True),
)

# Only start the app when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    demo.queue().launch()
29 changes: 29 additions & 0 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import warnings
import webbrowser
from abc import abstractmethod
from collections import defaultdict
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Literal, cast
Expand Down Expand Up @@ -707,6 +708,7 @@ def __init__(
self.share = False
self.enable_queue = None
self.max_threads = 40
self.pending_streams = defaultdict(dict)
self.show_error = True
if css is not None and os.path.exists(css):
with open(css) as css_file:
Expand Down Expand Up @@ -1333,13 +1335,35 @@ def postprocess_data(

return output

def handle_streaming_outputs(
    self, fn_index: int, data: list, session_hash: str | None, run: int | None
) -> list:
    """Move the outputs of streaming components into the pending-stream queues.

    For every output component of dependency ``fn_index`` that is a
    ``StreamableOutput`` with ``streaming`` enabled, the corresponding value in
    ``data`` is converted with the component's ``stream_output`` and appended to
    ``self.pending_streams[session_hash][run][output_id]``. The value in ``data``
    is then replaced by a ``{"name": ..., "is_stream": True}`` reference whose
    name is ``"{session_hash}/{run}/{output_id}"``.

    Parameters:
        fn_index: index of the dependency whose outputs are being processed.
        data: postprocessed output values, one per output component.
        session_hash: session identifier; if None, ``data`` is returned untouched.
        run: identifier of the generator run; if None, ``data`` is returned untouched.
    Returns:
        The same ``data`` list, modified in place.
    """
    if session_hash is None or run is None:
        return data

    # Imported here rather than at module level, as in the original code.
    from gradio.events import StreamableOutput

    output_ids = self.dependencies[fn_index]["outputs"]
    for position, component_id in enumerate(output_ids):
        component = self.blocks[component_id]
        if not (isinstance(component, StreamableOutput) and component.streaming):
            continue

        chunk = component.stream_output(data[position])
        session_streams = self.pending_streams[session_hash]
        if run not in session_streams:
            session_streams[run] = defaultdict(list)
        session_streams[run][component_id].append(chunk)

        data[position] = {
            "name": f"{session_hash}/{run}/{component_id}",
            "is_stream": True,
        }
    return data

async def process_api(
self,
fn_index: int,
inputs: list[Any],
state: dict[int, Any],
request: routes.Request | list[routes.Request] | None = None,
iterators: dict[int, Any] | None = None,
session_hash: str | None = None,
event_id: str | None = None,
event_data: EventData | None = None,
) -> dict[str, Any]:
Expand Down Expand Up @@ -1391,10 +1415,15 @@ async def process_api(
else:
inputs = self.preprocess_data(fn_index, inputs, state)
iterator = iterators.get(fn_index, None) if iterators else None
was_generating = iterator is not None
result = await self.call_function(
fn_index, inputs, iterator, request, event_id, event_data
)
data = self.postprocess_data(fn_index, result["prediction"], state)
if result["is_generating"] or was_generating:
data = self.handle_streaming_outputs(
fn_index, data, session_hash, id(iterator)
)
is_generating, iterator = result["is_generating"], result["iterator"]

block_fn.total_runtime += result["duration"]
Expand Down
22 changes: 19 additions & 3 deletions gradio/components/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Callable, Literal

import numpy as np
import requests
from gradio_client import media_data
from gradio_client import utils as client_utils
from gradio_client.documentation import document, set_documentation_group
Expand All @@ -20,6 +21,7 @@
Playable,
Recordable,
Streamable,
StreamableOutput,
Uploadable,
)
from gradio.interpretation import TokenInterpretable
Expand All @@ -34,6 +36,7 @@ class Audio(
Playable,
Recordable,
Streamable,
StreamableOutput,
Uploadable,
IOComponent,
FileSerializable,
Expand All @@ -52,7 +55,7 @@ def __init__(
self,
value: str | Path | tuple[int, np.ndarray] | Callable | None = None,
*,
source: Literal["upload", "microphone"] = "upload",
source: Literal["upload", "microphone"] | None = None,
type: Literal["numpy", "filepath"] = "numpy",
label: str | None = None,
every: float | None = None,
Expand Down Expand Up @@ -84,7 +87,7 @@ def __init__(
min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.
interactive: if True, will allow users to upload and edit a audio file; if False, can only be used to play audio. If not provided, this is inferred based on whether the component is used as an input or output.
visible: If False, component will be hidden.
streaming: If set to True when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'microphone'.
streaming: If set to True when used in a `live` interface as an input, will automatically stream microphone audio. When set on an output component, takes audio chunks yielded from the backend and combines them into one streaming audio output.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
format: The file format to save audio files. Either 'wav' or 'mp3'. wav files are lossless but will tend to be larger files. mp3 files tend to be smaller. Default is wav. Applies both when this component is used as an input (when `type` is "format") and when this component is used as an output.
Expand All @@ -93,6 +96,7 @@ def __init__(
show_share_button: If True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise.
"""
valid_sources = ["upload", "microphone"]
source = source if source else ("microphone" if streaming else "upload")
if source not in valid_sources:
raise ValueError(
f"Invalid value for parameter `source`: {source}. Please choose from one of: {valid_sources}"
Expand All @@ -105,7 +109,7 @@ def __init__(
)
self.type = type
self.streaming = streaming
if streaming and source != "microphone":
if streaming and source == "upload":
raise ValueError(
"Audio streaming only available if source is 'microphone'."
)
Expand Down Expand Up @@ -340,6 +344,18 @@ def postprocess(
file_path = self.make_temp_copy_if_needed(y)
return {"name": file_path, "data": None, "is_file": True}

def stream_output(self, y):
    """Return the raw bytes of one postprocessed audio chunk.

    Parameters:
        y: file-data dict whose ``"name"`` is either an http(s) URL or a local
            file path, or None.
    Returns:
        The chunk's content as bytes, or None when ``y`` is None.
    Raises:
        requests.HTTPError: if ``"name"`` is a URL and the server answers with
            an error status.
    """
    if y is None:
        return None

    source = y["name"]
    if client_utils.is_http_url_like(source):
        # A timeout keeps an unresponsive server from blocking this call forever;
        # raise_for_status() keeps an error page from being streamed as audio.
        response = requests.get(source, timeout=30)
        response.raise_for_status()
        return response.content

    with open(source, "rb") as f:
        return f.read()

def check_streamable(self):
if self.source != "microphone":
raise ValueError(
Expand Down
8 changes: 8 additions & 0 deletions gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,14 @@ def check_streamable(self):
pass


class StreamableOutput(EventListener):
    """Base for components whose output can be delivered as a stream of chunks."""

    def __init__(self):
        # Annotation only: nothing is assigned here, so the concrete component
        # is expected to set `streaming` itself.
        self.streaming: bool

    def stream_output(self, y) -> bytes:
        """Convert one postprocessed output value ``y`` into the bytes to stream.

        Subclasses must override this; the base implementation always raises
        NotImplementedError.
        """
        raise NotImplementedError


@document("*start_recording", "*stop_recording", inherit=True)
class Recordable(EventListener):
def __init__(self):
Expand Down
61 changes: 54 additions & 7 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import posixpath
import secrets
import tempfile
import time
import traceback
from asyncio import TimeoutError as AsyncTimeOutError
from collections import defaultdict
Expand Down Expand Up @@ -386,6 +387,41 @@ async def file(path_or_url: str, request: fastapi.Request):
return response
return FileResponse(abs_path, headers={"Accept-Ranges": "bytes"})

@app.get(
    "/stream/{session_hash}/{run}/{component_id}",
    dependencies=[Depends(login_check)],
)
async def stream(
    session_hash: str, run: int, component_id: int, request: fastapi.Request
):
    """Serve the queued chunks of one streamed output as a streaming response.

    Parameters:
        session_hash: session the stream belongs to.
        run: identifier of the generator run that produces the chunks.
        component_id: id of the output component being streamed.
        request: the incoming request (unused here).
    Returns:
        A StreamingResponse that yields chunks as they are queued. It ends when
        a None entry is reached or when no chunk arrives within the wait limit.
    Raises:
        HTTPException: 404 if no stream is pending for the given identifiers.
    """
    # Look the session up with .get(): pending_streams is a defaultdict, so
    # indexing it with an unknown session_hash taken from the URL would insert
    # a new empty entry on every such request and grow without bound.
    chunks = (
        app.get_blocks()
        .pending_streams.get(session_hash, {})
        .get(run, {})
        .get(component_id, None)
    )
    if chunks is None:
        raise HTTPException(404, "Stream not found.")

    def stream_wrapper():
        check_stream_rate = 0.01  # seconds to sleep between polls of an empty queue
        max_wait_time = 120  # maximum wait between yields - assume generator thread has crashed otherwise.
        wait_time = 0
        while True:
            if not chunks:
                if wait_time > max_wait_time:
                    return
                wait_time += check_stream_rate
                time.sleep(check_stream_rate)
                continue
            wait_time = 0
            next_chunk = chunks.pop(0)
            if next_chunk is None:  # None marks the end of the stream
                return
            yield next_chunk

    return StreamingResponse(stream_wrapper())

@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
async def file_deprecated(path: str, request: fastapi.Request):
    """Serve ``/file/{path}`` by delegating to ``file``.

    NOTE(review): the name suggests this URL form is kept only for backward
    compatibility; confirm before removing.
    """
    return await file(path, request)
Expand All @@ -406,24 +442,25 @@ async def run_predict(
fn_index_inferred: int,
):
fn_index = body.fn_index
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
app.state_holder[body.session_hash] = {
session_hash = getattr(body, "session_hash", None)
if session_hash is not None:
if session_hash not in app.state_holder:
app.state_holder[session_hash] = {
_id: deepcopy(getattr(block, "value", None))
for _id, block in app.get_blocks().blocks.items()
if getattr(block, "stateful", False)
}
session_state = app.state_holder[body.session_hash]
session_state = app.state_holder[session_hash]
# The should_reset set keeps track of the fn_indices
# that have been cancelled. When a job is cancelled,
# the /reset route will mark the jobs as having been reset.
# That way if the cancel job finishes BEFORE the job being cancelled
# the job being cancelled will not overwrite the state of the iterator.
if fn_index in app.iterators_to_reset[body.session_hash]:
if fn_index in app.iterators_to_reset[session_hash]:
iterators = {}
app.iterators_to_reset[body.session_hash].remove(fn_index)
app.iterators_to_reset[session_hash].remove(fn_index)
else:
iterators = app.iterators[body.session_hash]
iterators = app.iterators[session_hash]
else:
session_state = {}
iterators = {}
Expand All @@ -448,6 +485,7 @@ async def run_predict(
request=request,
state=session_state,
iterators=iterators,
session_hash=session_hash,
event_id=event_id,
event_data=event_data,
)
Expand All @@ -457,6 +495,15 @@ async def run_predict(
if isinstance(output, Error):
raise output
except BaseException as error:
iterator = iterators.get(fn_index, None)
if iterator is not None: # close off any streams that are still open
run_id = id(iterator)
pending_streams: dict[int, list] = (
app.get_blocks().pending_streams[session_hash].get(run_id, {})
)
for stream in pending_streams.values():
stream.append(None)

show_error = app.get_blocks().show_error or isinstance(error, Error)
traceback.print_exc()
return JSONResponse(
Expand Down
6 changes: 5 additions & 1 deletion guides/02_building-interfaces/02_reactive-interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,8 @@ The difference between `gr.Audio(source='microphone')` and `gr.Audio(source='mic

Here is example code of streaming images from the webcam.

$code_stream_frames
$code_stream_frames

Streaming can also be done in an output component. A `gr.Audio(streaming=True)` output component can take a stream of audio data yielded piece-wise by a generator function and combine the pieces into a single audio file.

$code_stream_audio_out
1 change: 1 addition & 0 deletions js/upload/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export interface FileData {
data: string;
blob?: File;
is_file?: boolean;
is_stream?: boolean;
mime_type?: string;
alt_text?: string;
}
6 changes: 6 additions & 0 deletions js/upload/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ export function normalise_file(
} else {
file.data = "/proxy=" + root_url + "file=" + file.name;
}
} else if (file.is_stream) {
if (root_url == null) {
file.data = root + "/stream/" + file.name;
} else {
file.data = "/proxy=" + root_url + "stream/" + file.name;
}
}
return file;
}
Expand Down

1 comment on commit 667875b

@vercel
Copy link

@vercel vercel bot commented on 667875b Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.