diff --git a/CMakeLists.txt b/CMakeLists.txt index 83395293..d5c4564c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ # CMakeLists in this exact order for cmake to work correctly cmake_minimum_required(VERSION 3.16) -set(PROJECT_VER "0.9.1") +set(PROJECT_VER "0.9.2") include($ENV{IDF_PATH}/tools/cmake/project.cmake) project(xiaozhi) diff --git a/main/application.cc b/main/application.cc index 5b76d64c..33a62c99 100644 --- a/main/application.cc +++ b/main/application.cc @@ -123,20 +123,52 @@ void Application::ToggleChatState() { Schedule([this]() { if (chat_state_ == kChatStateIdle) { SetChatState(kChatStateConnecting); - if (protocol_->OpenAudioChannel()) { - opus_encoder_.ResetState(); - SetChatState(kChatStateListening); - } else { + if (!protocol_->OpenAudioChannel()) { + ESP_LOGE(TAG, "Failed to open audio channel"); SetChatState(kChatStateIdle); + return; } + + keep_listening_ = true; + protocol_->SendStartListening(kListeningModeAutoStop); + SetChatState(kChatStateListening); } else if (chat_state_ == kChatStateSpeaking) { - AbortSpeaking(); + AbortSpeaking(kAbortReasonNone); } else if (chat_state_ == kChatStateListening) { protocol_->CloseAudioChannel(); } }); } +void Application::StartListening() { + Schedule([this]() { + keep_listening_ = false; + if (chat_state_ == kChatStateIdle) { + if (!protocol_->IsAudioChannelOpened()) { + SetChatState(kChatStateConnecting); + if (!protocol_->OpenAudioChannel()) { + SetChatState(kChatStateIdle); + ESP_LOGE(TAG, "Failed to open audio channel"); + return; + } + } + protocol_->SendStartListening(kListeningModeManualStop); + SetChatState(kChatStateListening); + } else if (chat_state_ == kChatStateSpeaking) { + AbortSpeaking(kAbortReasonNone); + protocol_->SendStartListening(kListeningModeManualStop); + SetChatState(kChatStateListening); + } + }); +} + +void Application::StopListening() { + Schedule([this]() { + protocol_->SendStopListening(); + SetChatState(kChatStateIdle); + }); +} + void Application::Start() { auto& board = Board::GetInstance(); board.Initialize(); @@ -248,26 +280,31 @@ void Application::Start() { }); }); - wake_word_detect_.OnWakeWordDetected([this]() { - Schedule([this]() { + wake_word_detect_.OnWakeWordDetected([this](const std::string& wake_word) { + Schedule([this, &wake_word]() { if (chat_state_ == kChatStateIdle) { SetChatState(kChatStateConnecting); wake_word_detect_.EncodeWakeWordData(); - if (protocol_->OpenAudioChannel()) { - std::string opus; - // Encode and send the wake word data to the server - while (wake_word_detect_.GetWakeWordOpus(opus)) { - protocol_->SendAudio(opus); - } - opus_encoder_.ResetState(); - // Send a ready message to indicate the server that the wake word data is sent - SetChatState(kChatStateWakeWordDetected); - } else { + if (!protocol_->OpenAudioChannel()) { + ESP_LOGE(TAG, "Failed to open audio channel"); SetChatState(kChatStateIdle); + wake_word_detect_.StartDetection(); + return; } + + std::string opus; + // Encode and send the wake word data to the server + while (wake_word_detect_.GetWakeWordOpus(opus)) { + protocol_->SendAudio(opus); + } + // Set the chat state to wake word detected + protocol_->SendWakeWordDetected(wake_word); + ESP_LOGI(TAG, "Wake word detected: %s", wake_word.c_str()); + keep_listening_ = true; + SetChatState(kChatStateListening); } else if (chat_state_ == kChatStateSpeaking) { - AbortSpeaking(); + AbortSpeaking(kAbortReasonWakeWordDetected); } // Resume detection @@ -313,15 +350,23 @@ void Application::Start() { auto state = cJSON_GetObjectItem(root, "state"); if (strcmp(state->valuestring, "start") == 0) { Schedule([this]() { - skip_to_end_ = false; - SetChatState(kChatStateSpeaking); + if (chat_state_ == kChatStateIdle || chat_state_ == kChatStateListening) { + skip_to_end_ = false; + opus_decoder_ctl(opus_decoder_, OPUS_RESET_STATE); + SetChatState(kChatStateSpeaking); + } }); } else if (strcmp(state->valuestring, "stop") == 0) { Schedule([this]() { auto codec = Board::GetInstance().GetAudioCodec(); codec->WaitForOutputDone(); if (chat_state_ == kChatStateSpeaking) { - SetChatState(kChatStateListening); + if (keep_listening_) { + protocol_->SendStartListening(kListeningModeAutoStop); + SetChatState(kChatStateListening); + } else { + SetChatState(kChatStateIdle); + } } }); } else if (strcmp(state->valuestring, "sentence_start") == 0) { @@ -375,9 +420,9 @@ void Application::MainLoop() { } } -void Application::AbortSpeaking() { +void Application::AbortSpeaking(AbortReason reason) { ESP_LOGI(TAG, "Abort speaking"); - protocol_->SendAbort(); + protocol_->SendAbortSpeaking(reason); skip_to_end_ = true; auto codec = Board::GetInstance().GetAudioCodec(); @@ -391,7 +436,6 @@ void Application::SetChatState(ChatState state) { "connecting", "listening", "speaking", - "wake_word_detected", "upgrading", "invalid_state" }; @@ -399,12 +443,10 @@ void Application::SetChatState(ChatState state) { // No need to update the state return; } - chat_state_ = state; - ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]); auto display = Board::GetInstance().GetDisplay(); auto builtin_led = Board::GetInstance().GetBuiltinLed(); - switch (chat_state_) { + switch (state) { case kChatStateUnknown: case kChatStateIdle: builtin_led->TurnOff(); @@ -424,6 +466,7 @@ void Application::SetChatState(ChatState state) { builtin_led->TurnOn(); display->SetStatus("聆听中..."); display->SetEmotion("neutral"); + opus_encoder_.ResetState(); #ifdef CONFIG_USE_AFE_SR audio_processor_.Start(); #endif @@ -436,17 +479,17 @@ void Application::SetChatState(ChatState state) { audio_processor_.Stop(); #endif break; - case kChatStateWakeWordDetected: - builtin_led->SetBlue(); - builtin_led->TurnOn(); - break; case kChatStateUpgrading: builtin_led->SetGreen(); builtin_led->StartContinuousBlink(100); break; + default: + ESP_LOGE(TAG, "Invalid chat state: %d", chat_state_); + return; } - protocol_->SendState(state_str[chat_state_]); + chat_state_ = state; + ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]); } void Application::AudioEncodeTask() { @@ -474,7 +517,7 @@ void Application::AudioEncodeTask() { audio_decode_queue_.pop_front(); lock.unlock(); - if (skip_to_end_) { + if (skip_to_end_ || chat_state_ != kChatStateSpeaking) { continue; } diff --git a/main/application.h b/main/application.h index fe774eca..fc91f242 100644 --- a/main/application.h +++ b/main/application.h @@ -29,7 +29,6 @@ enum ChatState { kChatStateConnecting, kChatStateListening, kChatStateSpeaking, - kChatStateWakeWordDetected, kChatStateUpgrading }; @@ -41,17 +40,19 @@ class Application { static Application instance; return instance; } + // 删除拷贝构造函数和赋值运算符 + Application(const Application&) = delete; + Application& operator=(const Application&) = delete; void Start(); ChatState GetChatState() const { return chat_state_; } void Schedule(std::function callback); void SetChatState(ChatState state); void Alert(const std::string&& title, const std::string&& message); - void AbortSpeaking(); + void AbortSpeaking(AbortReason reason); void ToggleChatState(); - // 删除拷贝构造函数和赋值运算符 - Application(const Application&) = delete; - Application& operator=(const Application&) = delete; + void StartListening(); + void StopListening(); private: Application(); @@ -68,6 +69,7 @@ class Application { Protocol* protocol_ = nullptr; EventGroupHandle_t event_group_; volatile ChatState chat_state_ = kChatStateUnknown; + bool keep_listening_ = false; bool skip_to_end_ = false; // Audio encode / decode diff --git a/main/protocols/mqtt_protocol.cc b/main/protocols/mqtt_protocol.cc index cbf4b580..ba6d9d06 100644 --- a/main/protocols/mqtt_protocol.cc +++ b/main/protocols/mqtt_protocol.cc @@ -72,9 +72,9 @@ bool MqttProtocol::StartMqttClient() { } else if (strcmp(type->valuestring, "goodbye") == 0) { auto session_id = cJSON_GetObjectItem(root, "session_id"); if (session_id == nullptr || session_id_ == session_id->valuestring) { - if (on_audio_channel_closed_ != nullptr) { - on_audio_channel_closed_(); - } + Application::GetInstance().Schedule([this]() { + CloseAudioChannel(); + }); } } else if (on_incoming_json_ != nullptr) { on_incoming_json_(root); @@ -129,23 +129,6 @@ void MqttProtocol::SendAudio(const std::string& data) { udp_->Send(encrypted); } -void MqttProtocol::SendState(const std::string& state) { - std::string message = "{"; - message += "\"session_id\":\"" + session_id_ + "\","; - message += "\"type\":\"state\","; - message += "\"state\":\"" + state + "\""; - message += "}"; - SendText(message); -} - -void MqttProtocol::SendAbort() { - std::string message = "{"; - message += "\"session_id\":\"" + session_id_ + "\","; - message += "\"type\":\"abort\""; - message += "}"; - SendText(message); -} - void MqttProtocol::CloseAudioChannel() { { std::lock_guard lock(channel_mutex_); diff --git a/main/protocols/mqtt_protocol.h b/main/protocols/mqtt_protocol.h index 6cc6be89..c6da3efe 100644 --- a/main/protocols/mqtt_protocol.h +++ b/main/protocols/mqtt_protocol.h @@ -26,9 +26,6 @@ class MqttProtocol : public Protocol { ~MqttProtocol(); void SendAudio(const std::string& data) override; - void SendText(const std::string& text) override; - void SendState(const std::string& state) override; - void SendAbort() override; bool OpenAudioChannel() override; void CloseAudioChannel() override; bool IsAudioChannelOpened() const override; @@ -52,11 +49,12 @@ class MqttProtocol : public Protocol { int udp_port_; uint32_t local_sequence_; uint32_t remote_sequence_; - std::string session_id_; bool StartMqttClient(); void ParseServerHello(const cJSON* root); std::string DecodeHexString(const std::string& hex_string); + + void SendText(const std::string& text) override; }; diff --git a/main/protocols/protocol.cc b/main/protocols/protocol.cc index d9906c2a..5c55141e 100644 --- a/main/protocols/protocol.cc +++ b/main/protocols/protocol.cc @@ -23,3 +23,37 @@ void Protocol::OnAudioChannelClosed(std::function callback) { void Protocol::OnNetworkError(std::function callback) { on_network_error_ = callback; } + +void Protocol::SendAbortSpeaking(AbortReason reason) { + std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"abort\""; + if (reason == kAbortReasonWakeWordDetected) { + message += ",\"reason\":\"wake_word_detected\""; + } + message += "}"; + SendText(message); +} + +void Protocol::SendWakeWordDetected(const std::string& wake_word) { + std::string json = "{\"session_id\":\"" + session_id_ + + "\",\"type\":\"listen\",\"state\":\"detect\",\"text\":\"" + wake_word + "\"}"; + SendText(json); +} + +void Protocol::SendStartListening(ListeningMode mode) { + std::string message = "{\"session_id\":\"" + session_id_ + "\""; + message += ",\"type\":\"listen\",\"state\":\"start\""; + if (mode == kListeningModeAlwaysOn) { + message += ",\"mode\":\"realtime\""; + } else if (mode == kListeningModeAutoStop) { + message += ",\"mode\":\"auto\""; + } else { + message += ",\"mode\":\"manual\""; + } + message += "}"; + SendText(message); +} + +void Protocol::SendStopListening() { + std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"listen\",\"state\":\"stop\"}"; + SendText(message); +} diff --git a/main/protocols/protocol.h b/main/protocols/protocol.h index 6261b9ca..5b6216ab 100644 --- a/main/protocols/protocol.h +++ b/main/protocols/protocol.h @@ -12,6 +12,16 @@ struct BinaryProtocol3 { uint8_t payload[]; } __attribute__((packed)); +enum AbortReason { + kAbortReasonNone, + kAbortReasonWakeWordDetected +}; + +enum ListeningMode { + kListeningModeAutoStop, + kListeningModeManualStop, + kListeningModeAlwaysOn // 需要 AEC 支持 +}; class Protocol { public: @@ -27,13 +37,14 @@ class Protocol { void OnAudioChannelClosed(std::function callback); void OnNetworkError(std::function callback); - virtual void SendAudio(const std::string& data) = 0; - virtual void SendText(const std::string& text) = 0; - virtual void SendState(const std::string& state) = 0; - virtual void SendAbort() = 0; virtual bool OpenAudioChannel() = 0; virtual void CloseAudioChannel() = 0; virtual bool IsAudioChannelOpened() const = 0; + virtual void SendAudio(const std::string& data) = 0; + virtual void SendWakeWordDetected(const std::string& wake_word); + virtual void SendStartListening(ListeningMode mode); + virtual void SendStopListening(); + virtual void SendAbortSpeaking(AbortReason reason); protected: std::function on_incoming_json_; @@ -43,6 +54,9 @@ class Protocol { std::function on_network_error_; int server_sample_rate_ = 16000; + std::string session_id_; + + virtual void SendText(const std::string& text) = 0; }; #endif // PROTOCOL_H diff --git a/main/protocols/websocket_protocol.cc b/main/protocols/websocket_protocol.cc index 078185ad..500615be 100644 --- a/main/protocols/websocket_protocol.cc +++ b/main/protocols/websocket_protocol.cc @@ -39,21 +39,6 @@ void WebsocketProtocol::SendText(const std::string& text) { websocket_->Send(text); } -void WebsocketProtocol::SendState(const std::string& state) { - std::string message = "{"; - message += "\"type\":\"state\","; - message += "\"state\":\"" + state + "\""; - message += "}"; - SendText(message); -} - -void WebsocketProtocol::SendAbort() { - std::string message = "{"; - message += "\"type\":\"abort\""; - message += "}"; - SendText(message); -} - bool WebsocketProtocol::IsAudioChannelOpened() const { return websocket_ != nullptr; } diff --git a/main/protocols/websocket_protocol.h b/main/protocols/websocket_protocol.h index b4bd7670..f62b04f0 100644 --- a/main/protocols/websocket_protocol.h +++ b/main/protocols/websocket_protocol.h @@ -16,9 +16,6 @@ class WebsocketProtocol : public Protocol { ~WebsocketProtocol(); void SendAudio(const std::string& data) override; - void SendText(const std::string& text) override; - void SendState(const std::string& state) override; - void SendAbort() override; bool OpenAudioChannel() override; void CloseAudioChannel() override; bool IsAudioChannelOpened() const override; @@ -28,6 +25,7 @@ class WebsocketProtocol : public Protocol { WebSocket* websocket_ = nullptr; void ParseServerHello(const cJSON* root); + void SendText(const std::string& text) override; }; #endif diff --git a/main/wake_word_detect.cc b/main/wake_word_detect.cc index 622447fc..68b05f32 100644 --- a/main/wake_word_detect.cc +++ b/main/wake_word_detect.cc @@ -4,9 +4,9 @@ #include #include #include +#include #define DETECTION_RUNNING_EVENT 1 -#define WAKE_WORD_ENCODED_EVENT 2 static const char* TAG = "WakeWordDetect"; @@ -40,6 +40,13 @@ void WakeWordDetect::Initialize(int channels, bool reference) { ESP_LOGI(TAG, "Model %d: %s", i, models->model_name[i]); if (strstr(models->model_name[i], ESP_WN_PREFIX) != NULL) { wakenet_model_ = models->model_name[i]; + auto words = esp_srmodel_get_wake_words(models, wakenet_model_); + // split by ";" to get all wake words + std::stringstream ss(words); + std::string word; + while (std::getline(ss, word, ';')) { + wake_words_.push_back(word); + } } } @@ -84,7 +91,7 @@ void WakeWordDetect::Initialize(int channels, bool reference) { }, "audio_detection", 4096 * 2, this, 1, nullptr); } -void WakeWordDetect::OnWakeWordDetected(std::function callback) { +void WakeWordDetect::OnWakeWordDetected(std::function callback) { wake_word_detected_callback_ = callback; } @@ -144,11 +151,11 @@ void WakeWordDetect::AudioDetectionTask() { } if (res->wakeup_state == WAKENET_DETECTED) { - ESP_LOGI(TAG, "Wake word detected"); StopDetection(); + last_detected_wake_word_ = wake_words_[res->wake_word_index - 1]; if (wake_word_detected_callback_) { - wake_word_detected_callback_(); + wake_word_detected_callback_(last_detected_wake_word_); } } } @@ -165,7 +172,6 @@ void WakeWordDetect::StoreWakeWordData(uint16_t* data, size_t samples) { } void WakeWordDetect::EncodeWakeWordData() { - xEventGroupClearBits(event_group_, WAKE_WORD_ENCODED_EVENT); wake_word_opus_.clear(); if (wake_word_encode_task_stack_ == nullptr) { wake_word_encode_task_stack_ = (StackType_t*)heap_caps_malloc(4096 * 8, MALLOC_CAP_SPIRAM); @@ -182,15 +188,18 @@ void WakeWordDetect::EncodeWakeWordData() { encoder->Encode(pcm, [this_](const uint8_t* opus, size_t opus_size) { std::lock_guard lock(this_->wake_word_mutex_); this_->wake_word_opus_.emplace_back(std::string(reinterpret_cast(opus), opus_size)); - this_->wake_word_cv_.notify_one(); + this_->wake_word_cv_.notify_all(); }); } this_->wake_word_pcm_.clear(); auto end_time = esp_timer_get_time(); ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000); - xEventGroupSetBits(this_->event_group_, WAKE_WORD_ENCODED_EVENT); - this_->wake_word_cv_.notify_one(); + { + std::lock_guard lock(this_->wake_word_mutex_); + this_->wake_word_opus_.push_back(""); + this_->wake_word_cv_.notify_all(); + } delete encoder; vTaskDelete(NULL); }, "encode_detect_packets", 4096 * 8, this, 1, wake_word_encode_task_stack_, &wake_word_encode_task_buffer_); @@ -199,12 +208,9 @@ void WakeWordDetect::EncodeWakeWordData() { bool WakeWordDetect::GetWakeWordOpus(std::string& opus) { std::unique_lock lock(wake_word_mutex_); wake_word_cv_.wait(lock, [this]() { - return !wake_word_opus_.empty() || (xEventGroupGetBits(event_group_) & WAKE_WORD_ENCODED_EVENT); + return !wake_word_opus_.empty(); }); - if (wake_word_opus_.empty()) { - return false; - } opus.swap(wake_word_opus_.front()); wake_word_opus_.pop_front(); - return true; + return !opus.empty(); } diff --git a/main/wake_word_detect.h b/main/wake_word_detect.h index 7a472be9..892ea56b 100644 --- a/main/wake_word_detect.h +++ b/main/wake_word_detect.h @@ -23,24 +23,27 @@ class WakeWordDetect { void Initialize(int channels, bool reference); void Feed(std::vector& data); - void OnWakeWordDetected(std::function callback); + void OnWakeWordDetected(std::function callback); void OnVadStateChange(std::function callback); void StartDetection(); void StopDetection(); bool IsDetectionRunning(); void EncodeWakeWordData(); bool GetWakeWordOpus(std::string& opus); + const std::string& GetLastDetectedWakeWord() const { return last_detected_wake_word_; } private: esp_afe_sr_data_t* afe_detection_data_ = nullptr; char* wakenet_model_ = NULL; + std::vector wake_words_; std::vector input_buffer_; EventGroupHandle_t event_group_; - std::function wake_word_detected_callback_; + std::function wake_word_detected_callback_; std::function vad_state_change_callback_; bool is_speaking_ = false; int channels_; bool reference_; + std::string last_detected_wake_word_; TaskHandle_t wake_word_encode_task_ = nullptr; StaticTask_t wake_word_encode_task_buffer_;