Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chunked inference for codec #22

Merged
merged 25 commits into from
Jul 20, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Chunked vs unchunked inference.
  • Loading branch information
prem committed Jul 15, 2023
commit 350cb1b72048f61da88f596f38a9129ecab9eaf7
31 changes: 20 additions & 11 deletions dac/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class DACFile:
input_db: float
channels: int
sample_rate: int
padding: bool

def save(self, path):
artifacts = {
Expand All @@ -30,6 +31,7 @@ def save(self, path):
"sample_rate": self.sample_rate,
"chunk_length": self.chunk_length,
"channels": self.channels,
"padding": self.padding,
},
}
path = Path(path).with_suffix(".dac")
Expand Down Expand Up @@ -146,7 +148,6 @@ def compress(
self.eval()
original_padding = self.padding
original_device = audio_signal.device
self.padding = False

audio_signal = audio_signal.clone()
original_sr = audio_signal.sample_rate
Expand All @@ -170,16 +171,23 @@ def compress(
nb, nac, nt = audio_signal.audio_data.shape
audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)

# Zero-pad signal on either side by the delay
audio_signal.zero_pad(self.delay, self.delay)
n_samples = int(win_duration * self.sample_rate)
# Round n_samples to nearest hop length multiple
n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
if win_duration is None or audio_signal.signal_duration <= win_duration:
# Unchunked compression (used if signal length < win duration)
self.padding = True
n_samples = nt
hop = nt
else:
# Chunked inference
self.padding = False
# Zero-pad signal on either side by the delay
audio_signal.zero_pad(self.delay, self.delay)
n_samples = int(win_duration * self.sample_rate)
# Round n_samples to nearest hop length multiple
n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
hop = self.get_output_length(n_samples)

codes = []

range_fn = range if not verbose else tqdm.trange
hop = self.get_output_length(n_samples)

for i in range_fn(0, nt, hop):
x = audio_signal[..., i : i + n_samples]
Expand All @@ -200,6 +208,7 @@ def compress(
input_db=input_db,
channels=nac,
sample_rate=original_sr,
padding=self.padding,
)

if n_quantizers is not None:
Expand All @@ -215,12 +224,12 @@ def decompress(
verbose: bool = False,
):
self.eval()
original_padding = self.padding
self.padding = False

if isinstance(obj, (str, Path)):
obj = DACFile.load(obj)

original_padding = self.padding
self.padding = obj.padding

range_fn = range if not verbose else tqdm.trange
codes = obj.codes
original_device = codes.device
Expand Down