Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Batching to React UI MusicGen #281

Merged
merged 2 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ https://rsxdalv.github.io/bark-speaker-directory/
https://github.com/rsxdalv/tts-generation-webui/discussions/186#discussioncomment-7291274

## Changelog
Mar 5:
* Add Batching to React UI MusicGen (#281), thanks to https://github.com/Aamir3d for requesting this and providing feedback

Mar 3:
* Add MMS demo as a notebook
* Add MultiBandDiffusion high VRAM disclaimer
Expand Down
17 changes: 11 additions & 6 deletions react-ui/src/hooks/useLocalStorage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ export default function useLocalStorage<T>(

const setValue: Dispatch<SetStateAction<T>> = (value) => {
// Allow value to be a function so we have the same API as useState
const valueToStore = value instanceof Function ? value(storedValue) : value;

// update local storage
setLocalValue(valueToStore);
// Save state
setStoredValue(valueToStore);
// const valueToStore = value instanceof Function ? value(storedValue) : value;

// // update local storage
// setLocalValue(valueToStore);
// // Save state
// setStoredValue(valueToStore);
setStoredValue(x => {
const newValue = value instanceof Function ? value(x) : value;
setLocalValue(newValue);
return newValue;
});
};

// watch localStorage changes
Expand Down
215 changes: 190 additions & 25 deletions react-ui/src/pages/musicgen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ const ModelSelector = ({
);
};

const initialMusicgenHyperParams = {
iterations: 1,
splitByLines: false,
};

const initialHistory = []; // prevent infinite loop
const MusicgenPage = () => {
const [data, setData] = useLocalStorage<Result | null>(
Expand All @@ -187,23 +192,57 @@ const MusicgenPage = () => {
musicgenId,
initialMusicgenParams
);
// hyperparameters
const [musicgenHyperParams, setMusicgenHyperParams] = useLocalStorage<
typeof initialMusicgenHyperParams
>("musicgenHyperParams", initialMusicgenHyperParams);
const [showLast, setShowLast] = useLocalStorage<number>(
"musicgenShowLast",
10
);
const interrupted = React.useRef(false);

async function musicgen() {
const body = JSON.stringify({
...musicgenParams,
melody: musicgenParams.model.includes("melody")
? musicgenParams.melody
: null,
model: musicgenParams.model,
});
const response = await fetch("/api/gradio/musicgen", {
method: "POST",
body,
});
const [progress, setProgress] = React.useState(0);
const [progressMax, setProgressMax] = React.useState(0);

const result: Result = await response.json();
setData(result);
setHistoryData((x) => [result, ...x]);
async function musicgen() {
interrupted.current = false;
const texts = musicgenHyperParams.splitByLines
? musicgenParams.text.split("\n")
: [musicgenParams.text];

const incrementNonRandomSeed = (seed: number, iteration: number) => {
return seed === -1 ? -1 : seed + iteration;
};

const musicgenIteration = async (text, iteration: number) => {
const result = await musicgenGenerate({
...musicgenParams,
text,
seed: incrementNonRandomSeed(musicgenParams.seed, iteration),
});
setData(result);
setHistoryData((x) => [result, ...x]);
};

setProgress(0);
setProgressMax(texts.length * musicgenHyperParams.iterations);
for (
let iteration = 0;
iteration < musicgenHyperParams.iterations;
iteration++
) {
for (const text of texts) {
if (interrupted.current) {
return;
}
await musicgenIteration(text, iteration);
setProgress((x) => x + 1);
}
}
interrupted.current = false;
setProgress(0);
setProgressMax(0);
}

const handleChange = (
Expand Down Expand Up @@ -272,6 +311,8 @@ const MusicgenPage = () => {
useParameters,
};

const interrupt = () => (interrupted.current = true);

return (
<Template>
<Head>
Expand Down Expand Up @@ -416,6 +457,24 @@ const MusicgenPage = () => {
className="border border-gray-300 p-2 rounded"
/>
</div>

<HyperParameters
params={musicgenHyperParams}
setParams={setMusicgenHyperParams}
progress={progress}
progressMax={progressMax}
interrupted={interrupted}
interrupt={interrupt}
/>
<button
className="border border-gray-300 p-2 rounded"
onClick={() => {
setMusicgenParams(initialMusicgenParams);
setMusicgenHyperParams(initialMusicgenHyperParams);
}}
>
Reset Parameters
</button>
</div>
</div>
</div>
Expand All @@ -438,19 +497,32 @@ const MusicgenPage = () => {

<div className="flex flex-col gap-y-2 border border-gray-300 p-2 rounded">
<label className="text-sm">History:</label>
{/* Clear history */}
<button
className="border border-gray-300 p-2 rounded"
onClick={() => {
setHistoryData([]);
}}
>
Clear History
</button>
<div className="flex gap-x-2 items-center">
<button
className="border border-gray-300 p-2 px-40 rounded"
onClick={() => {
setHistoryData([]);
}}
>
Clear History
</button>
<div className="flex gap-x-2 items-center">
<label className="text-sm">Show Last X entries:</label>
<input
type="number"
value={showLast}
onChange={(event) => setShowLast(Number(event.target.value))}
className="border border-gray-300 p-2 rounded"
min="0"
max="100"
step="1"
/>
</div>
</div>
<div className="flex flex-col gap-y-2">
{historyData &&
historyData
.slice(1, 6)
.slice(1, showLast + 1)
.map((item, index) => (
<AudioOutput
key={index}
Expand All @@ -469,3 +541,96 @@ const MusicgenPage = () => {
};

export default MusicgenPage;

async function musicgenGenerate(musicgenParams: MusicgenParams) {
const body = JSON.stringify({
...musicgenParams,
melody: musicgenParams.model.includes("melody")
? musicgenParams.melody
: null,
model: musicgenParams.model,
});
const response = await fetch("/api/gradio/musicgen", {
method: "POST",
body,
});

return (await response.json()) as Result;
}

const HyperParameters = ({
params: musicgenHyperParams,
setParams: setMusicgenHyperParams,
progress,
progressMax,
interrupted,
interrupt,
}: {
params: typeof initialMusicgenHyperParams;
setParams: React.Dispatch<
React.SetStateAction<typeof initialMusicgenHyperParams>
>;
progress: number;
progressMax: number;
interrupted: React.MutableRefObject<boolean>;
interrupt: () => void;
}) => (
<div className="flex flex-col gap-y-2 border border-gray-300 p-2 rounded">
<label className="text-sm">Hyperparameters:</label>
<div className="flex gap-x-2 items-center">
<label className="text-sm">Iterations:</label>
<input
type="number"
name="iterations"
value={musicgenHyperParams.iterations}
onChange={(event) => {
setMusicgenHyperParams({
...musicgenHyperParams,
iterations: Number(event.target.value),
});
}}
className="border border-gray-300 p-2 rounded"
min="1"
max="10"
step="1"
/>
</div>
<div className="flex gap-x-2 items-center">
<div className="text-sm">Each line as a separate prompt:</div>
<input
type="checkbox"
name="splitByLines"
checked={musicgenHyperParams.splitByLines}
onChange={(event) => {
setMusicgenHyperParams({
...musicgenHyperParams,
splitByLines: event.target.checked,
});
}}
className="border border-gray-300 p-2 rounded"
/>
</div>
<Progress progress={progress} progressMax={progressMax} />
<button className="border border-gray-300 p-2 rounded" onClick={interrupt}>
{interrupted.current ? "Interrupted..." : "Interrupt"}
</button>
</div>
);

const Progress = ({
progress,
progressMax,
}: {
progress: number;
progressMax: number;
}) => (
<div className="flex gap-x-2 items-center">
<label className="text-sm">Progress:</label>
<progress
value={progress}
max={progressMax}
className="[&::-webkit-progress-bar]:rounded [&::-webkit-progress-value]:rounded [&::-webkit-progress-bar]:bg-slate-300 [&::-webkit-progress-value]:bg-orange-400 [&::-moz-progress-bar]:bg-orange-400 [&::-webkit-progress-value]:transition-all [&::-webkit-progress-value]:duration-200"
/>
{progress}/{progressMax}
</div>
);
2 changes: 1 addition & 1 deletion react-ui/src/tabs/MusicgenParams.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export const initialMusicgenParams: MusicgenParams = {
text: "lofi hip hop beats to relax/study to",
melody: undefined,
// melody: "https://www.mfiles.co.uk/mp3-downloads/gs-cd-track2.mp3",
model: "Small",
model: "facebook/musicgen-small",
duration: 1,
topk: 250,
topp: 0,
Expand Down