Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix sample_rate to frame_rate where appropriate #40

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions encodec/quantization/vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,43 +66,46 @@ def __init__(
threshold_ema_dead_code=self.threshold_ema_dead_code,
)

def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
def forward(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor.
sample_rate (int): Sample rate of the input tensor.
frame_rate (int): Sample rate of the input tensor.
bandwidth (float): Target bandwidth.
Returns:
QuantizedResult:
The quantized (or approximately quantized) representation with
the associated bandwidth and any penalty term for the loss.
"""
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
bw = torch.tensor(n_q * bw_per_q).to(x)
return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))

def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
def get_num_quantizers_for_bandwidth(self, frame_rate: int, bandwidth: tp.Optional[float] = None) -> int:
"""Return n_q based on specified target bandwidth.
"""
bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
n_q = self.n_q
if bandwidth and bandwidth > 0.:
n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
# bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
# bandwidth == 6.0
n_q = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
return n_q

def get_bandwidth_per_quantizer(self, sample_rate: int):
"""Return bandwidth per quantizer for a given input sample rate.
def get_bandwidth_per_quantizer(self, frame_rate: int):
"""Return bandwidth per quantizer for a given input frame rate.
Each quantizer encodes a frame with lg(bins) bits.
"""
return math.log2(self.bins) * sample_rate / 1000
return math.log2(self.bins) * frame_rate

def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
def encode(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
"""Encode a given input tensor with the specified frame rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizers to use
and returns indices for each quantizer.
"""
n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
codes = self.vq.encode(x, n_q=n_q)
return codes

Expand Down