Skip to content

Commit

Permalink
Add c-api, python api, jni
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Jun 21, 2024
1 parent 12e1d4c commit abecc1c
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 0 deletions.
4 changes: 4 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);
// NOTE(review): the default here is `true`. If SHERPA_ONNX_OR falls back to
// its second argument whenever the first one is falsy (as its use for the ""
// and 1.5 defaults above suggests), a caller-supplied `false` would be
// replaced by `true`, so tokenization could never be disabled through the
// C API -- confirm against the macro definition.
recognizer_config.tokenize_hotwords =
SHERPA_ONNX_OR(config->tokenize_hotwords, true);

recognizer_config.ctc_fst_decoder_config.graph =
SHERPA_ONNX_OR(config->ctc_fst_decoder_config.graph, "");
Expand Down Expand Up @@ -390,6 +392,8 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);
// NOTE(review): the default here is `true`. If SHERPA_ONNX_OR falls back to
// its second argument whenever the first one is falsy (as its use for the ""
// and 1.5 defaults above suggests), a caller-supplied `false` would be
// replaced by `true`, so tokenization could never be disabled through the
// C API -- confirm against the macro definition.
recognizer_config.tokenize_hotwords =
SHERPA_ONNX_OR(config->tokenize_hotwords, true);

recognizer_config.rule_fsts = SHERPA_ONNX_OR(config->rule_fsts, "");
recognizer_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, "");
Expand Down
7 changes: 7 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
/// Bonus score for each token in hotwords.
float hotwords_score;

/// Whether the engine tokenizes the hotwords itself. When false, the caller
/// has to provide hotwords that are already tokenized.
/// NOTE(review): `bool` in a C-facing header needs <stdbool.h> when compiled
/// as C before C23 -- confirm it is included.
bool tokenize_hotwords;

SherpaOnnxOnlineCtcFstDecoderConfig ctc_fst_decoder_config;
const char *rule_fsts;
const char *rule_fars;
Expand Down Expand Up @@ -413,6 +416,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {

/// Bonus score for each token in hotwords.
float hotwords_score;

/// Whether the engine tokenizes the hotwords itself. When false, the caller
/// has to provide hotwords that are already tokenized.
/// NOTE(review): `bool` in a C-facing header needs <stdbool.h> when compiled
/// as C before C23 -- confirm it is included.
bool tokenize_hotwords;

const char *rule_fsts;
const char *rule_fars;
} SherpaOnnxOfflineRecognizerConfig;
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/jni/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);

// "Z" is the JNI type signature of a Java boolean, so the value has to be
// read with GetBooleanField; GetFloatField is only valid for "F" fields.
fid = env->GetFieldID(cls, "tokenizeHotwords", "Z");
ans.tokenize_hotwords = env->GetBooleanField(config, fid);

fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/jni/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);

// "Z" is the JNI type signature of a Java boolean, so the value has to be
// read with GetBooleanField; GetFloatField is only valid for "F" fields.
fid = env->GetFieldID(cls, "tokenizeHotwords", "Z");
ans.tokenize_hotwords = env->GetBooleanField(config, fid);

fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/kotlin-api/OfflineRecognizer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ data class OfflineRecognizerConfig(
var maxActivePaths: Int = 4,
var hotwordsFile: String = "",
var hotwordsScore: Float = 1.5f,
var tokenizeHotwords: Boolean = true,
var ruleFsts: String = "",
var ruleFars: String = "",
)
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/kotlin-api/OnlineRecognizer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ data class OnlineRecognizerConfig(
var maxActivePaths: Int = 4,
var hotwordsFile: String = "",
var hotwordsScore: Float = 1.5f,
var tokenizeHotwords: Boolean = true,
var ruleFsts: String = "",
var ruleFars: String = "",
)
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def from_transducer(
max_active_paths: int = 4,
hotwords_file: str = "",
hotwords_score: float = 1.5,
tokenize_hotwords: bool = True,
blank_penalty: float = 0.0,
modeling_unit: str = "cjkchar",
bpe_vocab: str = "",
Expand Down Expand Up @@ -96,6 +97,9 @@ def from_transducer(
hotwords_score:
The hotword score of each token for biasing word/phrase. Used only if
hotwords_file is given with modified_beam_search as decoding method.
tokenize_hotwords:
Whether the engine tokenizes the hotwords. If True, the hotwords are
tokenized in the engine; if False, you have to tokenize them yourself.
blank_penalty:
The penalty applied on blank symbol during decoding.
modeling_unit:
Expand Down Expand Up @@ -165,6 +169,7 @@ def from_transducer(
max_active_paths=max_active_paths,
hotwords_file=hotwords_file,
hotwords_score=hotwords_score,
tokenize_hotwords=tokenize_hotwords,
blank_penalty=blank_penalty,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/online_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def from_transducer(
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
hotwords_score: float = 1.5,
tokenize_hotwords: bool = True,
blank_penalty: float = 0.0,
hotwords_file: str = "",
provider: str = "cpu",
Expand Down Expand Up @@ -131,6 +132,9 @@ def from_transducer(
hotwords_score:
The hotword score of each token for biasing word/phrase. Used only if
hotwords_file is given with modified_beam_search as decoding method.
tokenize_hotwords:
Whether the engine tokenizes the hotwords. If True, the hotwords are
tokenized in the engine; if False, you have to tokenize them yourself.
temperature_scale:
Temperature scaling for output symbol confidence estimation.
It affects only confidence values, the decoding uses the original
Expand Down Expand Up @@ -222,6 +226,7 @@ def from_transducer(
decoding_method=decoding_method,
max_active_paths=max_active_paths,
hotwords_score=hotwords_score,
tokenize_hotwords=tokenize_hotwords,
hotwords_file=hotwords_file,
blank_penalty=blank_penalty,
temperature_scale=temperature_scale,
Expand Down

0 comments on commit abecc1c

Please sign in to comment.