|
7 | 7 | #include "node.h"
|
8 | 8 | #include "node_errors.h"
|
9 | 9 | #include "node_mem-inl.h"
|
| 10 | +#include "node_url.h" |
10 | 11 | #include "sqlite3.h"
|
11 | 12 | #include "threadpoolwork-inl.h"
|
12 | 13 | #include "util-inl.h"
|
@@ -181,10 +182,11 @@ class BackupJob : public ThreadPoolWork {
|
181 | 182 | void ScheduleBackup() {
|
182 | 183 | Isolate* isolate = env()->isolate();
|
183 | 184 | HandleScope handle_scope(isolate);
|
184 |
| - backup_status_ = sqlite3_open_v2(destination_name_.c_str(), |
185 |
| - &dest_, |
186 |
| - SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, |
187 |
| - nullptr); |
| 185 | + backup_status_ = sqlite3_open_v2( |
| 186 | + destination_name_.c_str(), |
| 187 | + &dest_, |
| 188 | + SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI, |
| 189 | + nullptr); |
188 | 190 | Local<Promise::Resolver> resolver =
|
189 | 191 | Local<Promise::Resolver>::New(env()->isolate(), resolver_);
|
190 | 192 | if (backup_status_ != SQLITE_OK) {
|
@@ -503,11 +505,14 @@ bool DatabaseSync::Open() {
|
503 | 505 | }
|
504 | 506 |
|
505 | 507 | // TODO(cjihrig): Support additional flags.
|
| 508 | + int default_flags = SQLITE_OPEN_URI; |
506 | 509 | int flags = open_config_.get_read_only()
|
507 | 510 | ? SQLITE_OPEN_READONLY
|
508 | 511 | : SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE;
|
509 |
| - int r = sqlite3_open_v2( |
510 |
| - open_config_.location().c_str(), &connection_, flags, nullptr); |
| 512 | + int r = sqlite3_open_v2(open_config_.location().c_str(), |
| 513 | + &connection_, |
| 514 | + flags | default_flags, |
| 515 | + nullptr); |
511 | 516 | CHECK_ERROR_OR_THROW(env()->isolate(), this, r, SQLITE_OK, false);
|
512 | 517 |
|
513 | 518 | r = sqlite3_db_config(connection_,
|
@@ -585,27 +590,85 @@ bool DatabaseSync::ShouldIgnoreSQLiteError() {
|
585 | 590 | return ignore_next_sqlite_error_;
|
586 | 591 | }
|
587 | 592 |
|
| 593 | +std::optional<std::string> ValidateDatabasePath(Environment* env, |
| 594 | + Local<Value> path, |
| 595 | + const std::string& field_name) { |
| 596 | + auto has_null_bytes = [](const std::string& str) { |
| 597 | + return str.find('\0') != std::string::npos; |
| 598 | + }; |
| 599 | + std::string location; |
| 600 | + if (path->IsString()) { |
| 601 | + location = Utf8Value(env->isolate(), path.As<String>()).ToString(); |
| 602 | + if (!has_null_bytes(location)) { |
| 603 | + return location; |
| 604 | + } |
| 605 | + } |
| 606 | + |
| 607 | + if (path->IsUint8Array()) { |
| 608 | + Local<Uint8Array> buffer = path.As<Uint8Array>(); |
| 609 | + size_t byteOffset = buffer->ByteOffset(); |
| 610 | + size_t byteLength = buffer->ByteLength(); |
| 611 | + auto data = |
| 612 | + static_cast<const uint8_t*>(buffer->Buffer()->Data()) + byteOffset; |
| 613 | + if (!(std::find(data, data + byteLength, 0) != data + byteLength)) { |
| 614 | + Local<Value> out; |
| 615 | + if (String::NewFromUtf8(env->isolate(), |
| 616 | + reinterpret_cast<const char*>(data), |
| 617 | + NewStringType::kNormal, |
| 618 | + static_cast<int>(byteLength)) |
| 619 | + .ToLocal(&out)) { |
| 620 | + return Utf8Value(env->isolate(), out.As<String>()).ToString(); |
| 621 | + } |
| 622 | + } |
| 623 | + } |
| 624 | + |
| 625 | + // When is URL |
| 626 | + if (path->IsObject()) { |
| 627 | + Local<Object> url = path.As<Object>(); |
| 628 | + Local<Value> href; |
| 629 | + Local<Value> protocol; |
| 630 | + if (url->Get(env->context(), env->href_string()).ToLocal(&href) && |
| 631 | + href->IsString() && |
| 632 | + url->Get(env->context(), env->protocol_string()).ToLocal(&protocol) && |
| 633 | + protocol->IsString()) { |
| 634 | + location = Utf8Value(env->isolate(), href.As<String>()).ToString(); |
| 635 | + if (!has_null_bytes(location)) { |
| 636 | + auto file_url = ada::parse(location); |
| 637 | + CHECK(file_url); |
| 638 | + if (file_url->type != ada::scheme::FILE) { |
| 639 | + THROW_ERR_INVALID_URL_SCHEME(env->isolate()); |
| 640 | + return std::nullopt; |
| 641 | + } |
| 642 | + |
| 643 | + return location; |
| 644 | + } |
| 645 | + } |
| 646 | + } |
| 647 | + |
| 648 | + THROW_ERR_INVALID_ARG_TYPE(env->isolate(), |
| 649 | + "The \"%s\" argument must be a string, " |
| 650 | + "Uint8Array, or URL without null bytes.", |
| 651 | + field_name.c_str()); |
| 652 | + |
| 653 | + return std::nullopt; |
| 654 | +} |
| 655 | + |
588 | 656 | void DatabaseSync::New(const FunctionCallbackInfo<Value>& args) {
|
589 | 657 | Environment* env = Environment::GetCurrent(args);
|
590 |
| - |
591 | 658 | if (!args.IsConstructCall()) {
|
592 | 659 | THROW_ERR_CONSTRUCT_CALL_REQUIRED(env);
|
593 | 660 | return;
|
594 | 661 | }
|
595 | 662 |
|
596 |
| - if (!args[0]->IsString()) { |
597 |
| - THROW_ERR_INVALID_ARG_TYPE(env->isolate(), |
598 |
| - "The \"path\" argument must be a string."); |
| 663 | + std::optional<std::string> location = |
| 664 | + ValidateDatabasePath(env, args[0], "path"); |
| 665 | + if (!location.has_value()) { |
599 | 666 | return;
|
600 | 667 | }
|
601 | 668 |
|
602 |
| - std::string location = |
603 |
| - Utf8Value(env->isolate(), args[0].As<String>()).ToString(); |
604 |
| - DatabaseOpenConfiguration open_config(std::move(location)); |
605 |
| - |
| 669 | + DatabaseOpenConfiguration open_config(std::move(location.value())); |
606 | 670 | bool open = true;
|
607 | 671 | bool allow_load_extension = false;
|
608 |
| - |
609 | 672 | if (args.Length() > 1) {
|
610 | 673 | if (!args[1]->IsObject()) {
|
611 | 674 | THROW_ERR_INVALID_ARG_TYPE(env->isolate(),
|
@@ -984,17 +1047,15 @@ void Backup(const FunctionCallbackInfo<Value>& args) {
|
984 | 1047 | DatabaseSync* db;
|
985 | 1048 | ASSIGN_OR_RETURN_UNWRAP(&db, args[0].As<Object>());
|
986 | 1049 | THROW_AND_RETURN_ON_BAD_STATE(env, !db->IsOpen(), "database is not open");
|
987 |
| - if (!args[1]->IsString()) { |
988 |
| - THROW_ERR_INVALID_ARG_TYPE( |
989 |
| - env->isolate(), "The \"destination\" argument must be a string."); |
| 1050 | + std::optional<std::string> dest_path = |
| 1051 | + ValidateDatabasePath(env, args[1], "path"); |
| 1052 | + if (!dest_path.has_value()) { |
990 | 1053 | return;
|
991 | 1054 | }
|
992 | 1055 |
|
993 | 1056 | int rate = 100;
|
994 | 1057 | std::string source_db = "main";
|
995 | 1058 | std::string dest_db = "main";
|
996 |
| - |
997 |
| - Utf8Value dest_path(env->isolate(), args[1].As<String>()); |
998 | 1059 | Local<Function> progressFunc = Local<Function>();
|
999 | 1060 |
|
1000 | 1061 | if (args.Length() > 2) {
|
@@ -1077,12 +1138,11 @@ void Backup(const FunctionCallbackInfo<Value>& args) {
|
1077 | 1138 | }
|
1078 | 1139 |
|
1079 | 1140 | args.GetReturnValue().Set(resolver->GetPromise());
|
1080 |
| - |
1081 | 1141 | BackupJob* job = new BackupJob(env,
|
1082 | 1142 | db,
|
1083 | 1143 | resolver,
|
1084 | 1144 | std::move(source_db),
|
1085 |
| - *dest_path, |
| 1145 | + dest_path.value(), |
1086 | 1146 | std::move(dest_db),
|
1087 | 1147 | rate,
|
1088 | 1148 | progressFunc);
|
|
0 commit comments