Skip to content

Commit 5b5a8cf

Browse files
committed
sqlite,test,doc: allow Buffer and URL as database location
1 parent 18fe76b commit 5b5a8cf

File tree

3 files changed

+194
-51
lines changed

3 files changed

+194
-51
lines changed

src/node_sqlite.cc

Lines changed: 102 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "node_sqlite.h"
22
#include <path.h>
3-
#include "ada.h"
43
#include "base_object-inl.h"
54
#include "debug_utils-inl.h"
65
#include "env-inl.h"
@@ -184,7 +183,7 @@ class BackupJob : public ThreadPoolWork {
184183
HandleScope handle_scope(isolate);
185184
backup_status_ = sqlite3_open_v2(destination_name_.c_str(),
186185
&dest_,
187-
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
186+
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
188187
nullptr);
189188
Local<Promise::Resolver> resolver =
190189
Local<Promise::Resolver>::New(env()->isolate(), resolver_);
@@ -504,11 +503,14 @@ bool DatabaseSync::Open() {
504503
}
505504

506505
// TODO(cjihrig): Support additional flags.
506+
int default_flags = SQLITE_OPEN_URI;
507507
int flags = open_config_.get_read_only()
508508
? SQLITE_OPEN_READONLY
509509
: SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE;
510-
int r = sqlite3_open_v2(
511-
open_config_.location().c_str(), &connection_, flags, nullptr);
510+
int r = sqlite3_open_v2(open_config_.location().c_str(),
511+
&connection_,
512+
flags | default_flags,
513+
nullptr);
512514
CHECK_ERROR_OR_THROW(env()->isolate(), this, r, SQLITE_OK, false);
513515

514516
r = sqlite3_db_config(connection_,
@@ -586,8 +588,84 @@ bool DatabaseSync::ShouldIgnoreSQLiteError() {
586588
return ignore_next_sqlite_error_;
587589
}
588590

589-
bool IsURL(Local<Value> value) {
590-
return false;
591+
bool IsURL(Environment* env, Local<Value> path) {
592+
Local<Object> url;
593+
if (!path->ToObject(env->context()).ToLocal(&url)) {
594+
return false;
595+
}
596+
597+
Local<Value> href;
598+
if (!url->Get(env->context(), FIXED_ONE_BYTE_STRING(env->isolate(), "href"))
599+
.ToLocal(&href)) {
600+
return false;
601+
}
602+
603+
if (!href->IsString()) {
604+
return false;
605+
}
606+
607+
return true;
608+
}
609+
610+
Local<String> BufferToString(Environment* env, Local<Uint8Array> buffer) {
611+
size_t byteOffset = buffer->ByteOffset();
612+
size_t byteLength = buffer->ByteLength();
613+
if (byteLength == 0) {
614+
THROW_ERR_INVALID_ARG_TYPE(env->isolate(),
615+
"The \"path\" argument must not be empty.");
616+
return Local<String>();
617+
}
618+
619+
auto data =
620+
static_cast<const uint8_t*>(buffer->Buffer()->Data()) + byteOffset;
621+
if (std::find(data, data + byteLength, 0) != data + byteLength) {
622+
THROW_ERR_INVALID_ARG_TYPE(env->isolate(),
623+
"The \"path\" argument must not contain null "
624+
"bytes.");
625+
return Local<String>();
626+
}
627+
628+
auto path = std::string(reinterpret_cast<const char*>(data), byteLength);
629+
return String::NewFromUtf8(
630+
env->isolate(), path.c_str(), NewStringType::kNormal)
631+
.ToLocalChecked();
632+
}
633+
634+
Local<String> ToPathIfURL(Environment* env, Local<Value> path) {
635+
if (!IsURL(env, path)) {
636+
if (path->IsString()) {
637+
return path.As<String>();
638+
}
639+
640+
return BufferToString(env, path.As<Uint8Array>());
641+
}
642+
643+
Local<Object> url = path.As<Object>();
644+
Local<Value> href;
645+
Local<Value> protocol;
646+
if (!url->Get(env->context(), FIXED_ONE_BYTE_STRING(env->isolate(), "href"))
647+
.ToLocal(&href)) {
648+
return Local<String>();
649+
}
650+
651+
if (!url->Get(env->context(),
652+
FIXED_ONE_BYTE_STRING(env->isolate(), "protocol"))
653+
.ToLocal(&protocol)) {
654+
return Local<String>();
655+
}
656+
657+
if (!href->IsString() || !protocol->IsString()) {
658+
return Local<String>();
659+
}
660+
661+
std::string protocol_v =
662+
Utf8Value(env->isolate(), protocol.As<String>()).ToString();
663+
if (protocol_v != "file:") {
664+
THROW_ERR_INVALID_URL_SCHEME(env->isolate());
665+
return Local<String>();
666+
}
667+
668+
return href.As<String>();
591669
}
592670

593671
void DatabaseSync::New(const FunctionCallbackInfo<Value>& args) {
@@ -598,47 +676,20 @@ void DatabaseSync::New(const FunctionCallbackInfo<Value>& args) {
598676
return;
599677
}
600678

601-
Local<Value> path = args[0]; // if object, it's a URL, so path will be the
602-
// "href" property then i can check the scheme
603-
// and if it's not file, throw an error
604-
if (!path->IsString() && !path->IsUint8Array() && !IsURL(path)) {
679+
Local<Value> path = args[0];
680+
if (!path->IsString() && !path->IsUint8Array() && !IsURL(env, path)) {
605681
THROW_ERR_INVALID_ARG_TYPE(env->isolate(),
606-
"The \"path\" argument must be a string, "
682+
"The \"location\" argument must be a string, "
607683
"Uint8Array, or URL without null bytes.");
608684
return;
609685
}
610686

611-
std::string location;
612-
if (path->IsUint8Array()) {
613-
Local<Uint8Array> buffer = path.As<Uint8Array>();
614-
size_t byteOffset = buffer->ByteOffset();
615-
size_t byteLength = buffer->ByteLength();
616-
if (byteLength == 0) {
617-
THROW_ERR_INVALID_ARG_TYPE(env->isolate(),
618-
"The \"path\" argument must not be empty.");
619-
return;
620-
}
621-
622-
auto data =
623-
static_cast<const uint8_t*>(buffer->Buffer()->Data()) + byteOffset;
624-
if (std::find(data, data + byteLength, 0) != data + byteLength) {
625-
THROW_ERR_INVALID_ARG_TYPE(env->isolate(),
626-
"The \"path\" argument must not contain null "
627-
"bytes.");
628-
return;
629-
}
630-
631-
location = std::string(reinterpret_cast<const char*>(data), byteLength);
632-
} else {
633-
location = Utf8Value(env->isolate(), args[0].As<String>()).ToString();
687+
Local<String> path_str = ToPathIfURL(env, path);
688+
if (path_str.IsEmpty()) {
689+
return;
634690
}
635691

636-
// TODO: uncomment this we still need to handle URLs
637-
/* auto parsed_url = ada::parse<ada::url_aggregator>(location, nullptr); */
638-
/* if (parsed_url && parsed_url->type != ada::scheme::FILE) { */
639-
/* THROW_ERR_INVALID_URL_SCHEME(env->isolate()); */
640-
/* } */
641-
692+
std::string location = Utf8Value(env->isolate(), path_str).ToString();
642693
DatabaseOpenConfiguration open_config(std::move(location));
643694
bool open = true;
644695
bool allow_load_extension = false;
@@ -1020,17 +1071,23 @@ void Backup(const FunctionCallbackInfo<Value>& args) {
10201071
DatabaseSync* db;
10211072
ASSIGN_OR_RETURN_UNWRAP(&db, args[0].As<Object>());
10221073
THROW_AND_RETURN_ON_BAD_STATE(env, !db->IsOpen(), "database is not open");
1023-
if (!args[1]->IsString()) {
1024-
THROW_ERR_INVALID_ARG_TYPE(
1025-
env->isolate(), "The \"destination\" argument must be a string.");
1074+
Local<Value> path = args[1];
1075+
if (!path->IsString() && !path->IsUint8Array() && !IsURL(env, path)) {
1076+
THROW_ERR_INVALID_ARG_TYPE(env->isolate(),
1077+
"The \"destination\" argument must be a string, "
1078+
"Uint8Array, or URL without null bytes.");
1079+
return;
1080+
}
1081+
1082+
Local<String> path_str = ToPathIfURL(env, path);
1083+
if (path_str.IsEmpty()) {
10261084
return;
10271085
}
10281086

10291087
int rate = 100;
10301088
std::string source_db = "main";
10311089
std::string dest_db = "main";
1032-
1033-
Utf8Value dest_path(env->isolate(), args[1].As<String>());
1090+
Utf8Value dest_path(env->isolate(), path_str);
10341091
Local<Function> progressFunc = Local<Function>();
10351092

10361093
if (args.Length() > 2) {

test/parallel/test-sqlite-backup.mjs

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { join } from 'node:path';
44
import { backup, DatabaseSync } from 'node:sqlite';
55
import { describe, test } from 'node:test';
66
import { writeFileSync } from 'node:fs';
7+
import { pathToFileURL } from 'node:url';
78

89
let cnt = 0;
910

@@ -42,23 +43,21 @@ describe('backup()', () => {
4243
});
4344
});
4445

45-
test('throws if path is not a string', (t) => {
46-
// const database = makeSourceDb();
47-
// TODO: have a separate test handling buffer
48-
const database = makeSourceDb(Buffer.from(':memory:'));
46+
test('throws if path is not a string, URL, or Buffer', (t) => {
47+
const database = makeSourceDb();
4948

5049
t.assert.throws(() => {
5150
backup(database);
5251
}, {
5352
code: 'ERR_INVALID_ARG_TYPE',
54-
message: 'The "destination" argument must be a string.'
53+
message: 'The "destination" argument must be a string, Uint8Array, or URL without null bytes.'
5554
});
5655

5756
t.assert.throws(() => {
5857
backup(database, {});
5958
}, {
6059
code: 'ERR_INVALID_ARG_TYPE',
61-
message: 'The "destination" argument must be a string.'
60+
message: 'The "destination" argument must be a string, Uint8Array, or URL without null bytes.'
6261
});
6362
});
6463

@@ -143,6 +142,64 @@ test('database backup', async (t) => {
143142
});
144143
});
145144

145+
test('backup database using location as URL', async (t) => {
146+
const progressFn = t.mock.fn();
147+
const database = makeSourceDb();
148+
const destDb = pathToFileURL(nextDb());
149+
150+
await backup(database, destDb, {
151+
rate: 1,
152+
progress: progressFn,
153+
});
154+
155+
const backupDb = new DatabaseSync(destDb);
156+
const rows = backupDb.prepare('SELECT * FROM data').all();
157+
158+
// The source database has two pages - using the default page size -,
159+
// so the progress function should be called once (the last call is not made since
160+
// the promise resolves)
161+
t.assert.strictEqual(progressFn.mock.calls.length, 1);
162+
t.assert.deepStrictEqual(progressFn.mock.calls[0].arguments, [{ totalPages: 2, remainingPages: 1 }]);
163+
t.assert.deepStrictEqual(rows, [
164+
{ __proto__: null, key: 1, value: 'value-1' },
165+
{ __proto__: null, key: 2, value: 'value-2' },
166+
]);
167+
168+
t.after(() => {
169+
database.close();
170+
backupDb.close();
171+
});
172+
});
173+
174+
test('backup database using location as Buffer', async (t) => {
175+
const progressFn = t.mock.fn();
176+
const database = makeSourceDb();
177+
const destDb = Buffer.from(nextDb());
178+
179+
await backup(database, destDb, {
180+
rate: 1,
181+
progress: progressFn,
182+
});
183+
184+
const backupDb = new DatabaseSync(destDb);
185+
const rows = backupDb.prepare('SELECT * FROM data').all();
186+
187+
// The source database has two pages - using the default page size -,
188+
// so the progress function should be called once (the last call is not made since
189+
// the promise resolves)
190+
t.assert.strictEqual(progressFn.mock.calls.length, 1);
191+
t.assert.deepStrictEqual(progressFn.mock.calls[0].arguments, [{ totalPages: 2, remainingPages: 1 }]);
192+
t.assert.deepStrictEqual(rows, [
193+
{ __proto__: null, key: 1, value: 'value-1' },
194+
{ __proto__: null, key: 2, value: 'value-2' },
195+
]);
196+
197+
t.after(() => {
198+
database.close();
199+
backupDb.close();
200+
});
201+
});
202+
146203
test('database backup in a single call', async (t) => {
147204
const progressFn = t.mock.fn();
148205
const database = makeSourceDb();

test/parallel/test-sqlite.js

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ const tmpdir = require('../common/tmpdir');
44
const { join } = require('node:path');
55
const { DatabaseSync, constants } = require('node:sqlite');
66
const { suite, test } = require('node:test');
7+
const { pathToFileURL } = require('node:url');
8+
79
let cnt = 0;
810

911
tmpdir.refresh();
@@ -111,3 +113,30 @@ test('math functions are enabled', (t) => {
111113
{ __proto__: null, pi: 3.141592653589793 },
112114
);
113115
});
116+
117+
test('Buffer is supported as the database location', (t) => {
118+
const db = new DatabaseSync(Buffer.from(nextDb()));
119+
db.exec(`
120+
CREATE TABLE data(key INTEGER PRIMARY KEY);
121+
INSERT INTO data (key) VALUES (1);
122+
`);
123+
124+
t.assert.deepStrictEqual(
125+
db.prepare('SELECT * FROM data').all(),
126+
[{ __proto__: null, key: 1 }]
127+
);
128+
});
129+
130+
test('URL is supported as the database location', (t) => {
131+
const url = pathToFileURL(nextDb());
132+
const db = new DatabaseSync(url);
133+
db.exec(`
134+
CREATE TABLE data(key INTEGER PRIMARY KEY);
135+
INSERT INTO data (key) VALUES (1);
136+
`);
137+
138+
t.assert.deepStrictEqual(
139+
db.prepare('SELECT * FROM data').all(),
140+
[{ __proto__: null, key: 1 }]
141+
);
142+
});

0 commit comments

Comments
 (0)