Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/whispercpp/api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ class Params:
def on_progress(
self, callback: t.Callable[[Context, int, T], None], userdata: T
) -> None: ...
# The bound implementation invokes the callback as
# callback(context, n_tokens, logits, userdata).
# NOTE(review): the pybind11 binding registers this keyword as "user_data"
# with a default of None — confirm whether the stub should match.
def on_new_logits(
self, callback: t.Callable[[Context, int, NDArray[float], T], None], userdata: T
) -> None: ...

T = t.TypeVar("T")

Expand Down
20 changes: 14 additions & 6 deletions src/whispercpp/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "pybind11/stl.h"
#include "whisper.h"
#endif
#include <pybind11/numpy.h>
#include <algorithm>
#include <cassert>
#include <cstddef>
Expand Down Expand Up @@ -141,13 +142,17 @@ struct Params {
public:
// Callback signatures stored on Params. Each receives the Context the
// parameters are bound to, followed by the callback-specific arguments.
using NewSegmentCallback = std::function<void(Context &, int)>;
using ProgressCallback = std::function<void(Context &, int)>;
// The int forwarded by the C handler is the decoder's token count; the array
// carries the logits.
using LogitsFilterCallback =
    std::function<void(Context &, int /*n_tokens*/,
                       py::array_t<float> /*logits*/)>;

private:
std::shared_ptr<whisper_full_params> fp;
std::string language;

CallbackAndContext<NewSegmentCallback> new_segment_callback;
CallbackAndContext<ProgressCallback> progress_callback;
CallbackAndContext<LogitsFilterCallback> logits_filter_callback;

friend struct Context;

Expand All @@ -165,9 +170,11 @@ struct Params {

Params(std::shared_ptr<whisper_full_params> &&fp,
CallbackAndContext<NewSegmentCallback> new_segment_callback,
CallbackAndContext<ProgressCallback> progress_callback)
CallbackAndContext<ProgressCallback> progress_callback,
CallbackAndContext<LogitsFilterCallback> logits_filter_callback)
: fp(fp), new_segment_callback(new_segment_callback),
progress_callback(progress_callback){};
progress_callback(progress_callback),
logits_filter_callback(logits_filter_callback){};

Params(Params const &);
Params &operator=(Params const &);
Expand Down Expand Up @@ -434,17 +441,18 @@ struct Params {
// Defaults to None.
void set_progress_callback(ProgressCallback callback);

// Set the callback for each decoder to filter obtained logits.
// Do not use this function unless you know what you are doing.
// Defaults to None.
void set_logits_filter_callback(LogitsFilterCallback callback);

// Set the callback for starting the encoder.
// Do not use this function unless you know what you are doing.
// Defaults to None.
void set_encoder_begin_callback(whisper_encoder_begin_callback callback);
// Set the user data to be passed to the encoder begin callback.
// Defaults to None. See set_encoder_begin_callback.
void set_encoder_begin_callback_user_data(void *user_data);
// Set the callback for each decoder to filter obtained logits.
// Do not use this function unless you know what you are doing.
// Defaults to None.
void set_logits_filter_callback(whisper_logits_filter_callback callback);
// Set the user data to be passed to the logits filter callback.
// Defaults to None. See set_logits_filter_callback.
void set_logits_filter_callback_user_data(void *user_data);
Expand Down
85 changes: 66 additions & 19 deletions src/whispercpp/params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,29 @@ void progress_callback_handler(whisper_context *ctx,
}
}

// C trampoline installed as whisper's logits-filter callback. It recovers the
// CallbackAndContext container passed through `user_data` and forwards the
// decoder's logits to the registered Params::LogitsFilterCallback, if any.
//
//   ctx       - whisper context; queried for the vocabulary size.
//   state     - unused.
//   tokens    - unused; only the count is forwarded.
//   n_tokens  - forwarded to the callback unchanged.
//   logits    - raw logits buffer handed to the callback as a NumPy array.
//   user_data - pointer to a
//               CallbackAndContext<Params::LogitsFilterCallback>::Container.
//
// NOTE(review): this constructs Python objects and calls into Python; confirm
// the GIL is held on whichever thread whisper invokes this from.
void logits_filter_callback_handler(whisper_context *ctx,
                                    whisper_state *state,
                                    const whisper_token_data *tokens,
                                    int n_tokens,
                                    float *logits,
                                    void *user_data) {
  auto *container = static_cast<
      CallbackAndContext<Params::LogitsFilterCallback>::Container *>(user_data);
  // Nothing to dispatch to if no container was registered.
  if (container == nullptr) {
    return;
  }

  // Copy the shared_ptr so the callable stays alive for the whole call.
  auto callback = container->callback;
  if (callback == nullptr) {
    return;
  }

  // Assumes `logits` holds whisper_n_vocab(ctx) floats — TODO confirm against
  // whisper.h.
  const int n_vocab = whisper_n_vocab(ctx);

  // Supplying a base object makes pybind11 wrap the existing buffer instead of
  // copying it. The context is cast with an explicit `reference` policy: the
  // default policy for raw pointers is take_ownership, which would let Python
  // delete a Context it does not own when no wrapper for it exists yet.
  py::array_t<float> logits_buf(
      n_vocab, logits,
      py::cast(container->context, py::return_value_policy::reference));

  (*callback)(*container->context, n_tokens, logits_buf);
}


Params Params::from_enum(whisper_sampling_strategy *enum_) {
SamplingStrategies ss = SamplingStrategies::from_enum(enum_);
return Params::from_sampling_strategy(&ss);
Expand All @@ -74,6 +97,10 @@ Params Params::from_sampling_strategy(SamplingStrategies *ss) {
fp.progress_callback = progress_callback_handler;
fp.progress_callback_user_data = progress_callback.data.get();

CallbackAndContext<LogitsFilterCallback> logits_filter_callback;
fp.logits_filter_callback = logits_filter_callback_handler;
fp.logits_filter_callback_user_data = logits_filter_callback.data.get();

switch (strategy->to_enum()) {
case WHISPER_SAMPLING_GREEDY:
fp.greedy.best_of = ((SamplingGreedy *)strategy)->best_of;
Expand All @@ -86,23 +113,29 @@ Params Params::from_sampling_strategy(SamplingStrategies *ss) {
throw std::runtime_error("Unknown sampling strategy");
}
return Params(std::make_shared<whisper_full_params>(fp),
new_segment_callback, progress_callback);
new_segment_callback, progress_callback,
logits_filter_callback);
};

// Default constructor: installs the C trampolines on the underlying
// whisper_full_params and points each one's user data at the matching
// callback container owned by this object.
Params::Params() {
  whisper_full_params &raw = *fp;

  raw.new_segment_callback = new_segment_callback_handler;
  raw.new_segment_callback_user_data = new_segment_callback.data.get();

  raw.progress_callback = progress_callback_handler;
  raw.progress_callback_user_data = progress_callback.data.get();

  raw.logits_filter_callback = logits_filter_callback_handler;
  raw.logits_filter_callback_user_data = logits_filter_callback.data.get();
}

Params::Params(Params const &other)
: fp(other.fp), new_segment_callback(other.new_segment_callback),
progress_callback(other.progress_callback) {
progress_callback(other.progress_callback),
logits_filter_callback(other.logits_filter_callback) {
fp->new_segment_callback = new_segment_callback_handler;
fp->new_segment_callback_user_data = new_segment_callback.data.get();
fp->progress_callback = progress_callback_handler;
fp->progress_callback_user_data = progress_callback.data.get();
fp->logits_filter_callback = logits_filter_callback_handler;
fp->logits_filter_callback_user_data = logits_filter_callback.data.get();
}

Params &Params::operator=(Params const &other) {
Expand All @@ -113,6 +146,9 @@ Params &Params::operator=(Params const &other) {
progress_callback = other.progress_callback;
fp->progress_callback = progress_callback_handler;
fp->progress_callback_user_data = progress_callback.data.get();
logits_filter_callback = other.logits_filter_callback;
fp->logits_filter_callback = logits_filter_callback_handler;
fp->logits_filter_callback_user_data = logits_filter_callback.data.get();
return *this;
}

Expand All @@ -124,6 +160,9 @@ Params Params::copy_for_full(Context &context) {
if (params.progress_callback.data) {
params.progress_callback.data->context = &context;
}
if (params.logits_filter_callback.data) {
params.logits_filter_callback.data->context = &context;
}
return params;
}

Expand All @@ -145,13 +184,20 @@ void Params::set_new_segment_callback(NewSegmentCallback callback) {
std::make_shared<NewSegmentCallback>(callback);
}

// Called for progresss updates
// Called for progress updates
// Defaults to None.
void Params::set_progress_callback(ProgressCallback callback) {
  // Keep a shared copy of the callable in the container that is registered as
  // the progress callback's user data.
  auto stored = std::make_shared<ProgressCallback>(callback);
  progress_callback.data->callback = stored;
}

// Called for every decoding pass when new logits are available.
// Defaults to None.
void Params::set_logits_filter_callback(LogitsFilterCallback callback) {
  // Keep a shared copy of the callable in the container that is registered as
  // the logits-filter callback's user data.
  auto stored = std::make_shared<LogitsFilterCallback>(callback);
  logits_filter_callback.data->callback = stored;
}

// Set the callback for starting the encoder.
// Do not use this function unless you know what you are
// doing. Defaults to None.
Expand All @@ -166,20 +212,6 @@ void Params::set_encoder_begin_callback_user_data(void *user_data) {
fp->encoder_begin_callback_user_data = user_data;
}

// Set the callback for each decoder to filter obtained
// logits. Do not use this function unless you know what you
// are doing. Defaults to None.
void Params::set_logits_filter_callback(
whisper_logits_filter_callback callback) {
fp->logits_filter_callback = callback;
}
// Set the user data to be passed to the logits filter
// callback. Defaults to None. See
// set_logits_filter_callback.
void Params::set_logits_filter_callback_user_data(void *user_data) {
fp->logits_filter_callback_user_data = user_data;
};

inline std::ostream &operator<<(std::ostream &os, const Params &params) {
os << params.to_string();
return os;
Expand Down Expand Up @@ -242,6 +274,7 @@ std::string Params::to_string() const {

// Python-facing callback signatures: the Params callback arguments followed by
// a trailing user-data object.
using NewSegmentCallback = std::function<void(Context &, int, py::object &)>;
using ProgressCallback = std::function<void(Context &, int, py::object &)>;
using LogitsFilterCallback =
    std::function<void(Context &, int, py::array_t<float>, py::object &)>;

#define WITH_DEPRECATION(depr) \
PyErr_WarnEx(PyExc_DeprecationWarning, \
Expand Down Expand Up @@ -722,6 +755,20 @@ void ExportParamsApi(py::module &m) {
std::move(callback), std::move(user_data), _1, _2));
},
"callback"_a, "user_data"_a = py::none(), py::keep_alive<1, 2>(),
py::keep_alive<1, 3>());
// TODO: encoder_begin_callback and logits_filter_callback are still missing
py::keep_alive<1, 3>())
.def("on_new_logits",
[](Params &self, LogitsFilterCallback &callback,
py::object &user_data) {
using namespace std::placeholders;
self.set_logits_filter_callback(std::bind(
[](LogitsFilterCallback &callback, py::object &user_data,
Context &ctx, int n_tokens, py::array_t<float> logits) mutable {
// TODO pass float and tokens and stuff
(callback)(ctx, n_tokens, logits, user_data);
},
std::move(callback), std::move(user_data), _1, _2, _3));
},
"callback"_a, "user_data"_a = py::none(), py::keep_alive<1, 2>(),
py::keep_alive<1, 3>());
// TODO: encoder_begin_callback is still missing
}
18 changes: 18 additions & 0 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,21 @@ def handleProgress(context: w.api.Context, progress: int, progresses: list[int])

m.transcribe(preprocess(ROOT / "samples" / "jfk.wav"))
assert len(progresses) > 0


def test_logits_callback():
    """Check that on_new_logits delivers logits to a Python callback."""

    # The binding calls the callback with four arguments: the context, the
    # token count, the logits array and the registered user data.
    def handleLogits(
        context: w.api.Context, n_tokens: int, logits: NDArray, store: list
    ):
        # list.append takes a single item, so record the pair as one tuple.
        store.append((n_tokens, logits))

    m = w.Whisper.from_pretrained("tiny.en")

    logits_data = []
    m.params.on_new_logits(handleLogits, logits_data)

    m.transcribe(preprocess(ROOT / "samples" / "jfk.wav"))
    # The comparison below reads the first two entries, so require both.
    assert len(logits_data) > 1

    # make sure logits are passed by reference, so all logits stored
    # should be equal to one another as none of them were copied in the
    # callback and copies don't happen by default.
    assert np.all(logits_data[0][1] == logits_data[1][1])