-
Notifications
You must be signed in to change notification settings - Fork 2
/
onnx.html
165 lines (131 loc) · 7.35 KB
/
onnx.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width">
<title>Lyra v2 (SoundStream) - ONNX Runtime Web</title>
<script src="./enable-threads.js"></script>
</head>
<body>
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@1.13.1/dist/ort.js"></script>
<div>
<!-- Each option carries an explicit value: the script passes backendSelectEl.value straight to
     ONNX Runtime as the execution provider name, and without a value attribute the webgl option
     would yield its whole label text instead of "webgl". -->
backend: <select id="backendSelectEl">
<option value="wasm">wasm</option>
<option value="webgl">webgl (not currently working due to op support)</option>
</select>
</div>
<div>
<span>Model version:</span>
<select id="modelVersionSelectEl">
<option value="1.3.0">1.3.0</option>
<option value="1.2.0">1.2.0</option>
</select>
</div>
<br>
<div>
<button onclick="testEncoder();">Test Encoder</button>
<button onclick="testQuantizerEncoder();">Test Quantizer Encoder</button>
<button onclick="testQuantizerDecoder();">Test Quantizer Decoder</button>
<button onclick="testDecoder();">Test Decoder</button>
</div>
<p>This is a WIP - not currently working. See README for blockers.</p>
<p><a href="https://github.com/josephrocca/lyra-v2-soundstream-web">github repo</a> - <a href="https://huggingface.co/rocca/lyra-v2-soundstream">huggingface repo</a></p>
<script>
// Wasm threads are only usable on a cross-origin-isolated page, which means this html file
// has to be served with the two headers described at https://web.dev/coop-coep/
if (self.crossOriginIsolated) {
  ort.env.wasm.numThreads = navigator.hardwareConcurrency;
}

// Inference sessions, assigned by initModels(). A model is undefined until it has been loaded.
let encoderModel;
let quantizerEncoderModel;
let quantizerDecoderModel;
let decoderModel;
/**
 * Creates ONNX Runtime inference sessions for the requested models and stores them in the
 * `encoderModel`, `quantizerEncoderModel`, `quantizerDecoderModel` and `decoderModel` variables.
 * Any model whose flag is falsy is assigned `undefined`. The model version and execution
 * provider are read from the `modelVersionSelectEl` and `backendSelectEl` elements.
 *
 * @param {{encoder?: boolean, quantizerEncoder?: boolean, quantizerDecoder?: boolean, decoder?: boolean}} [opts]
 *   Flags selecting which models to load.
 * @returns {Promise<void>} Resolves once every requested session has been created.
 */
async function initModels(opts={}) {
  console.log("Loading models...");

  // Starts loading one model file, or yields undefined when that model was not requested.
  const startLoad = (wanted, fileName) => {
    if (!wanted) {
      return undefined;
    }
    const modelUrl = `https://huggingface.co/rocca/lyra-v2-soundstream/resolve/main/onnx/${modelVersionSelectEl.value}/${fileName}`;
    return ort.InferenceSession.create(modelUrl, { executionProviders: [backendSelectEl.value] });
  };

  // All requested downloads are kicked off before awaiting, so they run in parallel.
  const pending = [
    startLoad(opts.encoder, "soundstream_encoder.onnx"),
    startLoad(opts.quantizerEncoder, "quantizer_encoder.onnx"),
    startLoad(opts.quantizerDecoder, "quantizer_decoder.onnx"),
    startLoad(opts.decoder, "lyragan.onnx"),
  ];
  const sessions = await Promise.all(pending);

  encoderModel = sessions[0];
  quantizerEncoderModel = sessions[1];
  quantizerDecoderModel = sessions[2];
  decoderModel = sessions[3];

  console.log("Models finished loading.");
}
/**
 * Runs the encoder model on one chunk of audio.
 *
 * @param {ort.Tensor} audioSamples Input tensor; should be float32[1,320].
 * @returns {Promise<Float32Array>} Data of the "StatefulPartitionedCall:0" output (float32[1,1,64]).
 * @throws {Error} If the encoder model has not been initialised.
 */
async function encode(audioSamples) {
  if (!encoderModel) {
    throw new Error("You need to init the encoder model first.");
  }
  const outputs = await encoderModel.run({ "serving_default_input_audio:0": audioSamples });
  const { data } = outputs["StatefulPartitionedCall:0"];
  return data;
}
/**
 * Runs the quantizer encoder model on an embedding.
 * Background on the quantizer count: https://github.com/google/lyra/issues/92#issuecomment-1265883697
 *
 * @param {ort.Tensor} embedding Should be a tensor of type float32[1,1,64].
 * @param {ort.Tensor} numQuantizers An int32 scalar; can be either 16, 32, or 46.
 * @returns {Promise<Int32Array>} Data of the "StatefulPartitionedCall_1:0" output (int32[46,1,1]).
 * @throws {Error} If the quantizer encoder model has not been initialised.
 */
async function quantize(embedding, numQuantizers) {
  if (!quantizerEncoderModel) {
    throw new Error("You need to init the quantizer model first.");
  }
  const feeds = {
    "encode_input_frames:0": embedding,
    "encode_num_quantizers:0": numQuantizers,
  };
  const results = await quantizerEncoderModel.run(feeds);
  return results["StatefulPartitionedCall_1:0"].data;
}
/**
 * Runs the quantizer decoder model to turn quantized indices back into an embedding.
 *
 * @param {ort.Tensor} bits Should be a tensor of type int32[46,1,1].
 * @returns {Promise<Float32Array>} Data of the "StatefulPartitionedCall:0" output (float32[1,1,64]).
 * @throws {Error} If the quantizer decoder model has not been initialised.
 */
async function dequantize(bits) {
  if (!quantizerDecoderModel) {
    throw new Error("You need to init the quantizer model first.");
  }
  const results = await quantizerDecoderModel.run({ "decode_encoding_indices:0": bits });
  const { data } = results["StatefulPartitionedCall:0"];
  return data;
}
/**
 * Runs the decoder model on an embedding.
 *
 * @param {ort.Tensor} embedding Should be a tensor of type float32[1,1,64].
 * @returns {Promise<Float32Array>} Data of the "StatefulPartitionedCall:0" output (float32[1,320]).
 * @throws {Error} If the decoder model has not been initialised.
 */
async function decode(embedding) {
  if (!decoderModel) {
    throw new Error("You need to init the decoder model first.");
  }
  const feeds = { "serving_default_input_audio:0": embedding };
  const results = await decoderModel.run(feeds);
  return results["StatefulPartitionedCall:0"].data;
}
// Decoded channel-0 sample data (Float32Array), keyed by the URL it was loaded from.
let audioCache = {};

/**
 * Returns one 320-sample chunk (20ms at 16khz) of the audio at `url`.
 * The first call for a URL downloads and decodes the whole file at a 16000 Hz sample rate
 * and caches channel 0; later calls for the same URL slice the cached samples.
 *
 * @param {string} url Location of the audio file.
 * @param {number} chunkI Zero-based chunk index; the chunk covers samples
 *   [chunkI*320, (chunkI+1)*320). A chunk that runs past the end of the audio is shorter
 *   than 320 samples (possibly empty).
 * @returns {Promise<Float32Array>} A copy of the requested samples.
 * @throws {Error} If the audio metadata cannot be loaded, the duration is not finite,
 *   or the download responds with a non-OK HTTP status.
 */
async function getAudioSamples(url, chunkI) {
  if (!audioCache[url]) {
    // Need audio length to set max length of OfflineAudioContext - (hacky)
    const audioDuration = await new Promise((resolve, reject) => {
      const audio = new Audio();
      audio.onloadedmetadata = () => resolve(audio.duration);
      // Without an error handler a failed load would leave this promise pending forever.
      audio.onerror = () => reject(new Error(`Failed to load audio metadata from ${url}`));
      audio.src = url;
    });
    if (!Number.isFinite(audioDuration)) {
      throw new Error(`Could not determine a finite duration for the audio at ${url}`);
    }

    const response = await fetch(url);
    if (!response.ok) {
      // Fail here rather than handing an error page to the audio decoder.
      throw new Error(`Failed to fetch audio from ${url} (HTTP ${response.status})`);
    }
    const audioData = await response.arrayBuffer();

    const audioCtx = new OfflineAudioContext({
      sampleRate: 16000,
      // The context length is a whole number of sample frames; 1000 frames of headroom are kept.
      length: Math.ceil(16000 * audioDuration) + 1000,
      numberOfChannels: 1,
    });
    const decodedData = await audioCtx.decodeAudioData(audioData); // resampled to the AudioContext's sampling rate
    audioCache[url] = decodedData.getChannelData(0); // Float32Array for channel 0
  }
  return audioCache[url].slice(chunkI * 320, (chunkI + 1) * 320); // 320 samples at 16khz = 20ms
}
/**
 * Manual smoke test for the encoder: loads the encoder model, feeds it the first chunk
 * of the hosted test.wav, and logs the input tensor and the resulting output.
 *
 * @returns {Promise<void>}
 */
async function testEncoder() {
  console.log("Testing encoder:");
  await initModels({encoder:true});

  const testWavUrl = "https://huggingface.co/rocca/lyra-v2-soundstream/resolve/main/test.wav";
  const firstChunk = await getAudioSamples(testWavUrl, 0);
  const inputTensor = new ort.Tensor('float32', firstChunk, [1,320]);
  console.log("INPUT:", inputTensor);

  const encoded = await encode(inputTensor);
  console.log("OUTPUT:", encoded); // float32[1,1,64]
}
/**
 * Manual smoke test for the decoder: loads the decoder model, feeds it a float32[1,1,64]
 * embedding filled with ones, and logs the input tensor and the decoded output.
 *
 * @returns {Promise<void>}
 */
async function testDecoder() {
  console.log("Testing decoder:");
  await initModels({decoder:true});

  const ones = new Float32Array(64).fill(1);
  const inputEmbedding = new ort.Tensor('float32', ones, [1,1,64]);
  console.log("INPUT:", inputEmbedding);

  const decodedSamples = await decode(inputEmbedding);
  console.log("OUTPUT:", decodedSamples); // a 320-element Float32Array
}
/**
 * Manual smoke test for the quantizer encoder: loads the model, quantizes a float32[1,1,64]
 * embedding filled with ones using 46 quantizers, and logs the inputs and the result.
 *
 * @returns {Promise<void>}
 */
async function testQuantizerEncoder() {
  console.log("Testing quantizer encoder:");
  await initModels({quantizerEncoder:true});

  const ones = new Float32Array(64).fill(1);
  const inputEmbedding = new ort.Tensor('float32', ones, [1,1,64]);
  // Scalar tensor (empty dims) holding the number of quantizers to use.
  const numQuantizers = new ort.Tensor('int32', new Int32Array([46]), []);
  console.log("INPUT:", inputEmbedding, numQuantizers);

  const quantized = await quantize(inputEmbedding, numQuantizers);
  console.log("OUTPUT:", quantized); // int32[46,1,1]
}
/**
 * Manual smoke test for the quantizer decoder: loads the model, dequantizes an int32[46,1,1]
 * tensor filled with ones, and logs the input tensor and the resulting embedding.
 *
 * @returns {Promise<void>}
 */
async function testQuantizerDecoder() {
  console.log("Testing quantizer decoder:");
  await initModels({quantizerDecoder:true});

  const indices = new Int32Array(46).fill(1);
  const bits = new ort.Tensor('int32', indices, [46,1,1]);
  console.log("INPUT:", bits);

  const recovered = await dequantize(bits);
  console.log("OUTPUT:", recovered); // float32[1,1,64]
}
</script>
</body>
</html>