Skip to content

Add App Check support to Database #1260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 63 additions & 26 deletions app_check/integration_test/src/integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class FirebaseAppCheckTest : public FirebaseTest {

firebase::database::DatabaseReference CreateWorkingPath(
bool suppress_cleanup = false);
void CleanupDatabase(int expected_error);

firebase::firestore::CollectionReference GetFirestoreCollection();
firebase::firestore::DocumentReference CreateFirestoreDoc();
Expand Down Expand Up @@ -312,20 +313,7 @@ void FirebaseAppCheckTest::TerminateDatabase() {
if (!initialized_) return;

if (database_) {
if (!database_cleanup_.empty() && database_ && app_) {
LogDebug("Cleaning up...");
std::vector<firebase::Future<void>> cleanups;
cleanups.reserve(database_cleanup_.size());
for (int i = 0; i < database_cleanup_.size(); ++i) {
cleanups.push_back(database_cleanup_[i].RemoveValue());
}
for (int i = 0; i < cleanups.size(); ++i) {
std::string cleanup_name =
"Cleanup (" + database_cleanup_[i].url() + ")";
WaitForCompletion(cleanups[i], cleanup_name.c_str());
}
database_cleanup_.clear();
}
CleanupDatabase(0);

LogDebug("Shutdown the Database library.");
delete database_;
Expand All @@ -336,6 +324,22 @@ void FirebaseAppCheckTest::TerminateDatabase() {
ProcessEvents(100);
}

void FirebaseAppCheckTest::CleanupDatabase(int expected_error) {
if (!database_cleanup_.empty()) {
LogDebug("Cleaning up Database...");
std::vector<firebase::Future<void>> cleanups;
cleanups.reserve(database_cleanup_.size());
for (int i = 0; i < database_cleanup_.size(); ++i) {
cleanups.push_back(database_cleanup_[i].RemoveValue());
}
for (int i = 0; i < cleanups.size(); ++i) {
std::string cleanup_name = "Cleanup (" + database_cleanup_[i].url() + ")";
WaitForCompletion(cleanups[i], cleanup_name.c_str(), expected_error);
}
database_cleanup_.clear();
}
}

void FirebaseAppCheckTest::InitializeAppAuthDatabase() {
InitializeApp();
InitializeAuth();
Expand Down Expand Up @@ -743,18 +747,19 @@ TEST_F(FirebaseAppCheckTest, TestPlayIntegrityProvider) {
#endif
}

// Disabling the database tests for now, since they are crashing or hanging.
TEST_F(FirebaseAppCheckTest, DISABLED_TestDatabaseFailure) {
TEST_F(FirebaseAppCheckTest, TestDatabaseFailure) {
firebase::SetLogLevel(firebase::kLogLevelVerbose);
// Don't initialize App Check this time. Database should fail.
InitializeAppAuthDatabase();
firebase::database::DatabaseReference ref = CreateWorkingPath();
const char* test_name = test_info_->name();
firebase::Future<void> f = ref.Child(test_name).SetValue("test");
// It is unclear if this should fail, or hang, so disabled for now.
WaitForCompletion(f, "SetString");
WaitForCompletion(f, "SetString", firebase::database::kErrorDisconnected);

CleanupDatabase(firebase::database::kErrorOperationFailed);
}

TEST_F(FirebaseAppCheckTest, DISABLED_TestDatabaseCreateWorkingPath) {
TEST_F(FirebaseAppCheckTest, TestDatabaseCreateWorkingPath) {
InitializeAppCheckWithDebug();
InitializeAppAuthDatabase();
firebase::database::DatabaseReference working_path = CreateWorkingPath();
Expand All @@ -769,7 +774,7 @@ TEST_F(FirebaseAppCheckTest, DISABLED_TestDatabaseCreateWorkingPath) {

static const char kSimpleString[] = "Some simple string";

TEST_F(FirebaseAppCheckTest, DISABLED_TestDatabaseSetAndGet) {
TEST_F(FirebaseAppCheckTest, TestDatabaseSetAndGet) {
InitializeAppCheckWithDebug();
InitializeAppAuthDatabase();

Expand All @@ -795,7 +800,7 @@ TEST_F(FirebaseAppCheckTest, DISABLED_TestDatabaseSetAndGet) {
}
}

TEST_F(FirebaseAppCheckTest, DISABLED_TestRunTransaction) {
TEST_F(FirebaseAppCheckTest, TestDatabaseRunTransaction) {
InitializeAppCheckWithDebug();
InitializeAppAuthDatabase();

Expand Down Expand Up @@ -849,6 +854,42 @@ TEST_F(FirebaseAppCheckTest, DISABLED_TestRunTransaction) {
}
}

TEST_F(FirebaseAppCheckTest, TestDatabaseUpdateToken) {
// Test that after forcing an App Check token update, the database connection
// still works.
InitializeAppCheckWithDebug();
InitializeAppAuthDatabase();

const char* test_name = test_info_->name();
firebase::database::DatabaseReference ref = CreateWorkingPath();

{
LogDebug("Setting value.");
firebase::Future<void> f1 =
ref.Child(test_name).Child("String").SetValue(kSimpleString);
WaitForCompletion(f1, "SetSimpleString");
}

// Force App Check to update its token.
::firebase::app_check::AppCheck* app_check =
::firebase::app_check::AppCheck::GetInstance(app_);
ASSERT_NE(app_check, nullptr);
firebase::Future<::firebase::app_check::AppCheckToken> future =
app_check->GetAppCheckToken(true);
EXPECT_TRUE(WaitForCompletion(future, "GetAppCheckToken"));

// Get the values that we just set, and confirm that they match what we
// set them to.
{
LogDebug("Getting value.");
firebase::Future<firebase::database::DataSnapshot> f1 =
ref.Child(test_name).Child("String").GetValue();
WaitForCompletion(f1, "GetSimpleString");

EXPECT_EQ(f1.result()->value().AsString(), kSimpleString);
}
}

TEST_F(FirebaseAppCheckTest, TestStorageReadFile) {
InitializeAppCheckWithDebug();
InitializeAppAuthStorage();
Expand Down Expand Up @@ -1030,11 +1071,7 @@ TEST_F(FirebaseAppCheckTest, TestFirestoreListenerFailure) {
const firebase::firestore::DocumentSnapshot& result,
firebase::firestore::Error error_code,
const std::string& error_message) {
if (error_code == firebase::firestore::kErrorNone) {
// If we receive a success, it should only be for the cache.
EXPECT_TRUE(result.metadata().has_pending_writes());
EXPECT_TRUE(result.metadata().is_from_cache());
} else {
if (error_code != firebase::firestore::kErrorNone) {
// We expect one call with a Permission Denied error, from the
// server.
std::lock_guard<std::mutex> lock(mutex);
Expand Down
2 changes: 1 addition & 1 deletion app_check/src/desktop/app_check_desktop.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace internal {

// The callback type for psuedo-AppCheckListeners added via the
// function registry.
using FunctionRegistryCallback = void (*)(std::string, void*);
using FunctionRegistryCallback = void (*)(const std::string&, void*);

class FunctionRegistryAppCheckListener : public AppCheckListener {
public:
Expand Down
6 changes: 6 additions & 0 deletions app_check/src/desktop/debug_provider_desktop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ void DebugAppCheckProvider::GetToken(
// options.
const char* debug_token = std::getenv("APP_CHECK_DEBUG_TOKEN");

if (!debug_token) {
completion_callback({}, kAppCheckErrorInvalidConfiguration,
"Missing debug token");
return;
}

// Exchange debug token with the backend to get a proper attestation token.
auto request = MakeShared<DebugTokenRequest>(app_);
request->SetDebugToken(debug_token);
Expand Down
20 changes: 17 additions & 3 deletions database/src/desktop/connection/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

#include <cstdlib>
#include <cstring>
#include <string>

#include "app/src/assert.h"
#include "app/src/log.h"
#include "app/src/variant_util.h"
#include "database/src/desktop/connection/util_connection.h"
#include "database/src/desktop/connection/web_socket_client_impl.h"

namespace firebase {
namespace database {
Expand Down Expand Up @@ -52,7 +54,8 @@ compat::Atomic<uint32_t> Connection::next_log_id_(0);

Connection::Connection(scheduler::Scheduler* scheduler, const HostInfo& info,
const char* opt_last_session_id,
ConnectionEventHandler* event_handler, Logger* logger)
ConnectionEventHandler* event_handler, Logger* logger,
const std::string& app_check_token)
: safe_this_(this),
event_handler_(event_handler),
scheduler_(scheduler),
Expand All @@ -72,7 +75,7 @@ Connection::Connection(scheduler::Scheduler* scheduler, const HostInfo& info,

// Create web socket client regardless of its implementation
client_ = CreateWebSocketClient(host_info_, this, opt_last_session_id, logger,
scheduler);
scheduler, app_check_token);
}

Connection::~Connection() {
Expand Down Expand Up @@ -404,7 +407,10 @@ void Connection::OnConnectionShutdown(const std::string& reason) {

event_handler_->OnKill(reason);

Close(kDisconnectReasonShutdownMessage);
// OnKill can result in the client being torn down, so check for that.
if (client_) {
Close(kDisconnectReasonShutdownMessage);
}
}

void Connection::OnHandshake(const Variant& handshake) {
Expand Down Expand Up @@ -462,6 +468,14 @@ void Connection::OnReset(const std::string& host) {
Close(kDisconnectReasonServerReset);
}

void Connection::RefreshAppCheckToken(const std::string& token) {
WebSocketClientImpl* client_impl =
dynamic_cast<WebSocketClientImpl*>(client_.get());
if (client_impl) {
client_impl->RefreshAppCheckToken(token);
}
}

} // namespace connection
} // namespace internal
} // namespace database
Expand Down
9 changes: 8 additions & 1 deletion database/src/desktop/connection/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ class Connection : public WebSocketClientEventHandler {

explicit Connection(scheduler::Scheduler* scheduler, const HostInfo& info,
const char* opt_last_session_id,
ConnectionEventHandler* event_handler, Logger* logger);
ConnectionEventHandler* event_handler, Logger* logger,
const std::string& app_check_token = "");
~Connection() override;

// Connection is neither copyable nor movable.
Expand Down Expand Up @@ -104,6 +105,12 @@ class Connection : public WebSocketClientEventHandler {
void OnError(const WebSocketClientErrorData& error_data) override;
// END WebSocketClientEventHandler

// Refresh the stored App Check token being used by the connection.
// This doesn't change the connection itself, just the data used for
// establishing new connections.
// Expect to be called from scheduler thread.
void RefreshAppCheckToken(const std::string& token);

private:
// State of the connection
enum State {
Expand Down
Loading