From 2130e4d1a1f04abd681f6877a7253f061693272e Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 1 Mar 2024 22:11:50 +0100
Subject: [PATCH 001/201] moved from 'temporary_file_manager' branch

---
 src/include/duckdb/main/config.hpp            |  3 +
 src/include/duckdb/main/settings.hpp          | 10 +++
 .../duckdb/storage/temporary_file_manager.hpp | 33 ++++++++--
 src/main/config.cpp                           |  9 +++
 src/main/database.cpp                         |  3 +
 src/main/settings/settings.cpp                | 19 ++++++
 src/storage/temporary_file_manager.cpp        | 62 ++++++++++++++++---
 test/api/test_reset.cpp                       |  1 +
 8 files changed, 128 insertions(+), 12 deletions(-)

diff --git a/src/include/duckdb/main/config.hpp b/src/include/duckdb/main/config.hpp
index 468ec5bff298..b1b6a0fc0756 100644
--- a/src/include/duckdb/main/config.hpp
+++ b/src/include/duckdb/main/config.hpp
@@ -118,6 +118,8 @@ struct DBConfigOptions {
 	string autoinstall_extension_repo = "";
 	//! The maximum memory used by the database system (in bytes). Default: 80% of System available memory
 	idx_t maximum_memory = (idx_t)-1;
+	//! The maximum size of the 'temp_directory' folder when set (in bytes). Default 2x 'maximum_memory'
+	idx_t maximum_swap_space = (idx_t)-1;
 	//! The maximum amount of CPU threads used by the database system. Default: all available.
 	idx_t maximum_threads = (idx_t)-1;
 	//! The number of external threads that work on DuckDB tasks. Default: 1.
@@ -276,6 +278,7 @@ struct DBConfig {
 	DUCKDB_API IndexTypeSet &GetIndexTypes();
 	static idx_t GetSystemMaxThreads(FileSystem &fs);
 	void SetDefaultMaxMemory();
+	void SetDefaultMaxSwapSpace();
 
 	OrderType ResolveOrder(OrderType order_type) const;
 	OrderByNullType ResolveNullOrder(OrderType order_type, OrderByNullType null_type) const;
diff --git a/src/include/duckdb/main/settings.hpp b/src/include/duckdb/main/settings.hpp
index 31d8b1db8931..9ae16ea0fb91 100644
--- a/src/include/duckdb/main/settings.hpp
+++ b/src/include/duckdb/main/settings.hpp
@@ -418,6 +418,16 @@ struct MaximumMemorySetting {
 	static Value GetSetting(ClientContext &context);
 };
 
+struct MaximumTempDirectorySize {
+	static constexpr const char *Name = "max_temp_directory_size";
+	static constexpr const char *Description =
+	    "The maximum amount of data stored inside the 'temp_directory' (when set) (e.g. 1GB)";
+	static constexpr const LogicalTypeId InputType = LogicalTypeId::VARCHAR;
+	static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &parameter);
+	static void ResetGlobal(DatabaseInstance *db, DBConfig &config);
+	static Value GetSetting(ClientContext &context);
+};
+
 struct OldImplicitCasting {
 	static constexpr const char *Name = "old_implicit_casting";
 	static constexpr const char *Description = "Allow implicit casting to/from VARCHAR";
diff --git a/src/include/duckdb/storage/temporary_file_manager.hpp b/src/include/duckdb/storage/temporary_file_manager.hpp
index bfb22ee96139..7c118fb4cf76 100644
--- a/src/include/duckdb/storage/temporary_file_manager.hpp
+++ b/src/include/duckdb/storage/temporary_file_manager.hpp
@@ -23,9 +23,26 @@ namespace duckdb {
 // BlockIndexManager
 //===--------------------------------------------------------------------===//
 
+class TemporaryFileManager;
+
+struct FileSizeMonitor {
+public:
+	static constexpr idx_t TEMPFILE_BLOCK_SIZE = Storage::BLOCK_ALLOC_SIZE;
+
+public:
+	FileSizeMonitor(TemporaryFileManager &manager);
+
+public:
+	void Increase(idx_t blocks);
+	void Decrease(idx_t blocks);
+
+private:
+	TemporaryFileManager &manager;
+};
+
 struct BlockIndexManager {
 public:
-	BlockIndexManager();
+	BlockIndexManager(unique_ptr<FileSizeMonitor> file_size_monitor = nullptr);
 
 public:
 	//! Obtains a new block index from the index manager
@@ -43,6 +60,7 @@ struct BlockIndexManager {
 	idx_t max_index;
 	set<idx_t> free_indexes;
 	set<idx_t> indexes_in_use;
+	unique_ptr<FileSizeMonitor> file_size_monitor;
 };
 
 //===--------------------------------------------------------------------===//
@@ -69,7 +87,8 @@ class TemporaryFileHandle {
 	constexpr static idx_t MAX_ALLOWED_INDEX_BASE = 4000;
 
 public:
-	TemporaryFileHandle(idx_t temp_file_count, DatabaseInstance &db, const string &temp_directory, idx_t index);
+	TemporaryFileHandle(idx_t temp_file_count, DatabaseInstance &db, const string &temp_directory, idx_t index,
+	                    TemporaryFileManager &manager);
 
 public:
 	struct TemporaryFileLock {
@@ -103,8 +122,6 @@ class TemporaryFileHandle {
 	BlockIndexManager index_manager;
 };
 
-class TemporaryFileManager;
-
 //===--------------------------------------------------------------------===//
 // TemporaryDirectoryHandle
 //===--------------------------------------------------------------------===//
@@ -130,6 +147,7 @@ class TemporaryDirectoryHandle {
 class TemporaryFileManager {
 public:
 	TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p);
+	~TemporaryFileManager();
 
 public:
 	struct TemporaryManagerLock {
@@ -145,6 +163,11 @@ class TemporaryFileManager {
 	unique_ptr<FileBuffer> ReadTemporaryBuffer(block_id_t id, unique_ptr<FileBuffer> reusable_buffer);
 	void DeleteTemporaryBuffer(block_id_t id);
 	vector<TemporaryFileInformation> GetTemporaryFiles();
+	idx_t GetTotalUsedSpaceInBytes();
+	//! Register temporary file size growth
+	void IncreaseSizeOnDisk(idx_t amount);
+	//! Register temporary file size decrease
+	void DecreaseSizeOnDisk(idx_t amount);
 
 private:
 	void EraseUsedBlock(TemporaryManagerLock &lock, block_id_t id, TemporaryFileHandle *handle,
@@ -164,6 +187,8 @@ class TemporaryFileManager {
 	unordered_map<block_id_t, TemporaryFileIndex> used_blocks;
 	//! Manager of in-use temporary file indexes
 	BlockIndexManager index_manager;
+	//! The size in bytes of the temporary files that are currently alive
+	atomic<idx_t> size_on_disk;
 };
 
 } // namespace duckdb
diff --git a/src/main/config.cpp b/src/main/config.cpp
index 2dc67e40eac6..cbc030783264 100644
--- a/src/main/config.cpp
+++ b/src/main/config.cpp
@@ -1,5 +1,6 @@
 #include "duckdb/main/config.hpp"
 
+#include "duckdb/common/operator/multiply.hpp"
 #include "duckdb/common/operator/cast_operators.hpp"
 #include "duckdb/common/string_util.hpp"
 #include "duckdb/main/settings.hpp"
@@ -258,6 +259,14 @@ void DBConfig::SetDefaultMaxMemory() {
 	}
 }
 
+void DBConfig::SetDefaultMaxSwapSpace() {
+	auto memory_limit = options.maximum_memory;
+	if (!TryMultiplyOperator::Operation(memory_limit, 2, options.maximum_memory)) {
+		// Can't default to 2x memory: fall back to 5GB instead
+		options.maximum_memory = ParseMemoryLimit("5GB");
+	}
+}
+
 void DBConfig::CheckLock(const string &name) {
 	if (!options.lock_configuration) {
 		// not locked
diff --git a/src/main/database.cpp b/src/main/database.cpp
index 61655d4273ac..9744fc97d808 100644
--- a/src/main/database.cpp
+++ b/src/main/database.cpp
@@ -333,6 +333,9 @@ void DatabaseInstance::Configure(DBConfig &new_config) {
 	if (config.options.maximum_memory == (idx_t)-1) {
 		config.SetDefaultMaxMemory();
 	}
+	if (config.options.maximum_swap_space == (idx_t)-1) {
+		config.SetDefaultMaxSwapSpace();
+	}
 	if (new_config.options.maximum_threads == (idx_t)-1) {
 		config.options.maximum_threads = config.GetSystemMaxThreads(*config.file_system);
 	}
diff --git a/src/main/settings/settings.cpp b/src/main/settings/settings.cpp
index 2d801e988e3c..e16c2f795f7b 100644
--- a/src/main/settings/settings.cpp
+++ b/src/main/settings/settings.cpp
@@ -948,6 +948,25 @@ Value MaximumMemorySetting::GetSetting(ClientContext &context) {
 	return Value(StringUtil::BytesToHumanReadableString(config.options.maximum_memory));
 }
 
+//===--------------------------------------------------------------------===//
+// Maximum Temp Directory Size
+//===--------------------------------------------------------------------===//
+void MaximumTempDirectorySize::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
+	config.options.maximum_swap_space = DBConfig::ParseMemoryLimit(input.ToString());
+	if (db) {
+		BufferManager::GetBufferManager(*db).SetLimit(config.options.maximum_swap_space);
+	}
+}
+
+void MaximumTempDirectorySize::ResetGlobal(DatabaseInstance *db, DBConfig &config) {
+	config.SetDefaultMaxSwapSpace();
+}
+
+Value MaximumTempDirectorySize::GetSetting(ClientContext &context) {
+	auto &config = DBConfig::GetConfig(context);
+	return Value(StringUtil::BytesToHumanReadableString(config.options.maximum_swap_space));
+}
+
 //===--------------------------------------------------------------------===//
 // Old Implicit Casting
 //===--------------------------------------------------------------------===//
diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp
index c374829037ec..6a92048f5046 100644
--- a/src/storage/temporary_file_manager.cpp
+++ b/src/storage/temporary_file_manager.cpp
@@ -4,11 +4,29 @@
 
 namespace duckdb {
 
+//===--------------------------------------------------------------------===//
+// FileSizeMonitor
+//===--------------------------------------------------------------------===//
+
+FileSizeMonitor::FileSizeMonitor(TemporaryFileManager &manager) : manager(manager) {
+}
+
+void FileSizeMonitor::Increase(idx_t blocks) {
+	auto size_on_disk = blocks * TEMPFILE_BLOCK_SIZE;
+	manager.IncreaseSizeOnDisk(size_on_disk);
+}
+
+void FileSizeMonitor::Decrease(idx_t blocks) {
+	auto size_on_disk = blocks * TEMPFILE_BLOCK_SIZE;
+	manager.DecreaseSizeOnDisk(size_on_disk);
+}
+
 //===--------------------------------------------------------------------===//
 // BlockIndexManager
 //===--------------------------------------------------------------------===//
 
-BlockIndexManager::BlockIndexManager() : max_index(0) {
+BlockIndexManager::BlockIndexManager(unique_ptr<FileSizeMonitor> file_size_monitor)
+    : max_index(0), file_size_monitor(std::move(file_size_monitor)) {
 }
 
 idx_t BlockIndexManager::GetNewBlockIndex() {
@@ -27,12 +45,17 @@ bool BlockIndexManager::RemoveIndex(idx_t index) {
 	free_indexes.insert(index);
 	// check if we can truncate the file
 
+	auto old_max = max_index;
+
 	// get the max_index in use right now
-	auto max_index_in_use = indexes_in_use.empty() ? 0 : *indexes_in_use.rbegin();
+	auto max_index_in_use = indexes_in_use.empty() ? 0 : *indexes_in_use.rbegin() + 1;
 	if (max_index_in_use < max_index) {
 		// max index in use is lower than the max_index
 		// reduce the max_index
 		max_index = indexes_in_use.empty() ? 0 : max_index_in_use + 1;
+		if (file_size_monitor) {
+			file_size_monitor->Decrease(old_max - max_index);
+		}
 		// we can remove any free_indexes that are larger than the current max_index
 		while (!free_indexes.empty()) {
 			auto max_entry = *free_indexes.rbegin();
@@ -56,7 +79,12 @@ bool BlockIndexManager::HasFreeBlocks() {
 
 idx_t BlockIndexManager::GetNewBlockIndexInternal() {
 	if (free_indexes.empty()) {
-		return max_index++;
+		auto new_index = max_index;
+		max_index++;
+		if (file_size_monitor) {
+			file_size_monitor->Increase(1);
+		}
+		return new_index;
 	}
 	auto entry = free_indexes.begin();
 	auto index = *entry;
@@ -69,9 +97,10 @@ idx_t BlockIndexManager::GetNewBlockIndexInternal() {
 //===--------------------------------------------------------------------===//
 
 TemporaryFileHandle::TemporaryFileHandle(idx_t temp_file_count, DatabaseInstance &db, const string &temp_directory,
-                                         idx_t index)
+                                         idx_t index, TemporaryFileManager &manager)
     : max_allowed_index((1 << temp_file_count) * MAX_ALLOWED_INDEX_BASE), db(db), file_index(index),
-      path(FileSystem::GetFileSystem(db).JoinPath(temp_directory, "duckdb_temp_storage-" + to_string(index) + ".tmp")) {
+      path(FileSystem::GetFileSystem(db).JoinPath(temp_directory, "duckdb_temp_storage-" + to_string(index) + ".tmp")),
+      index_manager(make_uniq<FileSizeMonitor>(manager)) {
 }
 
 TemporaryFileHandle::TemporaryFileLock::TemporaryFileLock(mutex &mutex) : lock(mutex) {
@@ -135,7 +164,7 @@ void TemporaryFileHandle::CreateFileIfNotExists(TemporaryFileLock &) {
 		return;
 	}
 	auto &fs = FileSystem::GetFileSystem(db);
-	auto open_flags = FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE;
+	uint8_t open_flags = FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE;
 	handle = fs.OpenFile(path, open_flags);
 }
 
@@ -225,7 +254,12 @@ bool TemporaryFileIndex::IsValid() const {
 //===--------------------------------------------------------------------===//
 
 TemporaryFileManager::TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p)
-    : db(db), temp_directory(temp_directory_p) {
+    : db(db), temp_directory(temp_directory_p), size_on_disk(0) {
+}
+
+TemporaryFileManager::~TemporaryFileManager() {
+	files.clear();
+	D_ASSERT(size_on_disk.load() == 0);
 }
 
 TemporaryFileManager::TemporaryManagerLock::TemporaryManagerLock(mutex &mutex) : lock(mutex) {
@@ -250,7 +284,7 @@ void TemporaryFileManager::WriteTemporaryBuffer(block_id_t block_id, FileBuffer
 	if (!handle) {
 		// no existing handle to write to; we need to create & open a new file
 		auto new_file_index = index_manager.GetNewBlockIndex();
-		auto new_file = make_uniq<TemporaryFileHandle>(files.size(), db, temp_directory, new_file_index);
+		auto new_file = make_uniq<TemporaryFileHandle>(files.size(), db, temp_directory, new_file_index, *this);
 		handle = new_file.get();
 		files[new_file_index] = std::move(new_file);
 
@@ -269,6 +303,18 @@ bool TemporaryFileManager::HasTemporaryBuffer(block_id_t block_id) {
 	return used_blocks.find(block_id) != used_blocks.end();
 }
 
+idx_t TemporaryFileManager::GetTotalUsedSpaceInBytes() {
+	return size_on_disk.load();
+}
+
+void TemporaryFileManager::IncreaseSizeOnDisk(idx_t bytes) {
+	size_on_disk += bytes;
+}
+
+void TemporaryFileManager::DecreaseSizeOnDisk(idx_t bytes) {
+	size_on_disk -= bytes;
+}
+
 unique_ptr<FileBuffer> TemporaryFileManager::ReadTemporaryBuffer(block_id_t id,
                                                                  unique_ptr<FileBuffer> reusable_buffer) {
 	TemporaryFileIndex index;
diff --git a/test/api/test_reset.cpp b/test/api/test_reset.cpp
index 18f795d837c9..e06f965cd703 100644
--- a/test/api/test_reset.cpp
+++ b/test/api/test_reset.cpp
@@ -87,6 +87,7 @@ OptionValueSet &GetValueForOption(const string &name) {
 	    {"immediate_transaction_mode", {true}},
 	    {"max_expression_depth", {50}},
 	    {"max_memory", {"4.0 GiB"}},
+	    {"max_temp_directory_size", {"10.0 GiB"}},
 	    {"memory_limit", {"4.0 GiB"}},
 	    {"ordered_aggregate_threshold", {Value::UBIGINT(idx_t(1) << 12)}},
 	    {"null_order", {"nulls_first"}},
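The setting introduced by this first patch is driven entirely through the standard
configuration interface. A minimal usage sketch (the human-readable rendering of
the returned value may differ from what is shown in later test patches):

    SET max_temp_directory_size = '1GB';
    SELECT current_setting('max_temp_directory_size');
    RESET max_temp_directory_size;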
From 1e441a68212c9fcefce95d2b63d8622e6e9985a5 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 1 Mar 2024 23:11:12 +0100
Subject: [PATCH 002/201] create the exception, thrown whenever we try to
 increase the temp directory size beyond the max

---
 src/main/config.cpp                    |  5 +++--
 src/main/settings/settings.cpp         |  3 ---
 src/storage/temporary_file_manager.cpp | 15 ++++++++++++++-
 3 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/src/main/config.cpp b/src/main/config.cpp
index cbc030783264..78e5d36b1014 100644
--- a/src/main/config.cpp
+++ b/src/main/config.cpp
@@ -94,6 +94,7 @@ static ConfigurationOption internal_options[] = {DUCKDB_GLOBAL(AccessModeSetting
     DUCKDB_LOCAL(IntegerDivisionSetting),
     DUCKDB_LOCAL(MaximumExpressionDepthSetting),
     DUCKDB_GLOBAL(MaximumMemorySetting),
+    DUCKDB_GLOBAL(MaximumTempDirectorySize),
     DUCKDB_GLOBAL(OldImplicitCasting),
     DUCKDB_GLOBAL_ALIAS("memory_limit", MaximumMemorySetting),
     DUCKDB_GLOBAL_ALIAS("null_order", DefaultNullOrderSetting),
@@ -261,9 +262,9 @@ void DBConfig::SetDefaultMaxMemory() {
 
 void DBConfig::SetDefaultMaxSwapSpace() {
 	auto memory_limit = options.maximum_memory;
-	if (!TryMultiplyOperator::Operation(memory_limit, 2, options.maximum_memory)) {
+	if (!TryMultiplyOperator::Operation(memory_limit, static_cast<idx_t>(2), options.maximum_swap_space)) {
 		// Can't default to 2x memory: fall back to 5GB instead
-		options.maximum_memory = ParseMemoryLimit("5GB");
+		options.maximum_swap_space = ParseMemoryLimit("5GB");
 	}
 }
 
diff --git a/src/main/settings/settings.cpp b/src/main/settings/settings.cpp
index e16c2f795f7b..bdda84adf07f 100644
--- a/src/main/settings/settings.cpp
+++ b/src/main/settings/settings.cpp
@@ -953,9 +953,6 @@ Value MaximumMemorySetting::GetSetting(ClientContext &context) {
 //===--------------------------------------------------------------------===//
 void MaximumTempDirectorySize::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
 	config.options.maximum_swap_space = DBConfig::ParseMemoryLimit(input.ToString());
-	if (db) {
-		BufferManager::GetBufferManager(*db).SetLimit(config.options.maximum_swap_space);
-	}
 }
 
 void MaximumTempDirectorySize::ResetGlobal(DatabaseInstance *db, DBConfig &config) {
diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp
index 6a92048f5046..c97f73f197a1 100644
--- a/src/storage/temporary_file_manager.cpp
+++ b/src/storage/temporary_file_manager.cpp
@@ -259,7 +259,6 @@ TemporaryFileManager::TemporaryFileManager(DatabaseInstance &db, const string &t
 
 TemporaryFileManager::~TemporaryFileManager() {
 	files.clear();
-	D_ASSERT(size_on_disk.load() == 0);
 }
 
 TemporaryFileManager::TemporaryManagerLock::TemporaryManagerLock(mutex &mutex) : lock(mutex) {
@@ -308,7 +307,21 @@ idx_t TemporaryFileManager::GetTotalUsedSpaceInBytes() {
 }
 
 void TemporaryFileManager::IncreaseSizeOnDisk(idx_t bytes) {
+	auto &config = DBConfig::GetConfig(db);
+	auto max_swap_space = config.options.maximum_swap_space;
+
+	auto current_size_on_disk = size_on_disk.load();
 	size_on_disk += bytes;
+	if (size_on_disk.load() > max_swap_space) {
+		auto used = StringUtil::BytesToHumanReadableString(current_size_on_disk);
+		auto max = StringUtil::BytesToHumanReadableString(max_swap_space);
+		auto data_size = StringUtil::BytesToHumanReadableString(bytes);
+		throw OutOfMemoryException(R"(failed to offload data block of size %s (%s/%s used).
This limit was set by the 'max_temp_directory_size' setting.
This defaults to twice the size of 'max_memory'.
You can adjust this setting, by using (for example) PRAGMA max_temp_directory_size='10GiB')",
+		                           data_size, used, max);
+	}
 }
 
 void TemporaryFileManager::DecreaseSizeOnDisk(idx_t bytes) {
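With this patch, exceeding the cap while offloading raises an OutOfMemoryException
instead of silently growing the temporary files. A rough sketch of how to provoke it
(the path and sizes are illustrative; the exact message is pinned down by the tests
added later in this series):

    SET temp_directory = '/tmp/duckdb_swap';   -- any writable path, hypothetical here
    SET max_temp_directory_size = '0KiB';
    SET memory_limit = '2MB';
    CREATE TABLE t AS SELECT * FROM range(1000000);  -- fails as soon as a block must spill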
From b7d999792a3e527a8abc709fc52786a356396bc7 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Thu, 7 Mar 2024 11:45:30 +0100
Subject: [PATCH 003/201] increase to 5x memory limit

---
 src/include/duckdb/main/config.hpp | 2 +-
 src/main/config.cpp                | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/include/duckdb/main/config.hpp b/src/include/duckdb/main/config.hpp
index b1b6a0fc0756..397376b80333 100644
--- a/src/include/duckdb/main/config.hpp
+++ b/src/include/duckdb/main/config.hpp
@@ -118,7 +118,7 @@ struct DBConfigOptions {
 	string autoinstall_extension_repo = "";
 	//! The maximum memory used by the database system (in bytes). Default: 80% of System available memory
 	idx_t maximum_memory = (idx_t)-1;
-	//! The maximum size of the 'temp_directory' folder when set (in bytes). Default 2x 'maximum_memory'
+	//! The maximum size of the 'temp_directory' folder when set (in bytes). Default 5x 'maximum_memory'
 	idx_t maximum_swap_space = (idx_t)-1;
 	//! The maximum amount of CPU threads used by the database system. Default: all available.
 	idx_t maximum_threads = (idx_t)-1;
diff --git a/src/main/config.cpp b/src/main/config.cpp
index 78e5d36b1014..5159bf22864f 100644
--- a/src/main/config.cpp
+++ b/src/main/config.cpp
@@ -262,9 +262,9 @@ void DBConfig::SetDefaultMaxMemory() {
 
 void DBConfig::SetDefaultMaxSwapSpace() {
 	auto memory_limit = options.maximum_memory;
-	if (!TryMultiplyOperator::Operation(memory_limit, static_cast<idx_t>(2), options.maximum_swap_space)) {
-		// Can't default to 2x memory: fall back to 5GB instead
-		options.maximum_swap_space = ParseMemoryLimit("5GB");
+	if (!TryMultiplyOperator::Operation(memory_limit, static_cast<idx_t>(5), options.maximum_swap_space)) {
+		// Can't default to 5x memory: fall back to same limit as memory instead
+		options.maximum_swap_space = memory_limit;
 	}
 }
 
From 3d105684328589bc150f7d7c4abea037a81bbc86 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 8 Mar 2024 10:09:36 +0100
Subject: [PATCH 004/201] collect information about the disk when possible

---
 src/common/file_system.cpp                 |  31 ++++++
 src/include/duckdb/common/file_system.hpp  |   2 +
 src/include/duckdb/main/config.hpp         |  35 ++++++-
 src/main/config.cpp                        |  21 +++-
 src/main/database.cpp                      |   2 +-
 src/main/settings/settings.cpp             |  10 +-
 src/storage/temporary_file_manager.cpp     |   5 +
 test/sql/storage/max_swap_space.test       | 112 ++++++++++++++++++++++
 8 files changed, 208 insertions(+), 10 deletions(-)
 create mode 100644 test/sql/storage/max_swap_space.test

diff --git a/src/common/file_system.cpp b/src/common/file_system.cpp
index f2124165fdc7..f21e6c46dd79 100644
--- a/src/common/file_system.cpp
+++ b/src/common/file_system.cpp
@@ -12,6 +12,7 @@
 #include "duckdb/main/database.hpp"
 #include "duckdb/main/extension_helper.hpp"
 #include "duckdb/common/windows_util.hpp"
+#include "duckdb/common/operator/multiply.hpp"
 
 #include <cstdint>
 #include <cstdio>
@@ -21,6 +22,7 @@
 #include <fcntl.h>
 #include <string.h>
 #include <sys/stat.h>
+#include <sys/statvfs.h>
 #include <sys/types.h>
 #include <unistd.h>
 
@@ -98,6 +100,23 @@ idx_t FileSystem::GetAvailableMemory() {
 	return max_memory;
 }
 
+idx_t FileSystem::GetAvailableDiskSpace(const string &path) {
+	struct statvfs vfs;
+
+	if (statvfs(path.c_str(), &vfs) == -1) {
+		return DConstants::INVALID_INDEX;
+	}
+	auto block_size = vfs.f_frsize;
+	// These are the blocks available for creating new files or extending existing ones
+	auto available_blocks = vfs.f_bavail;
+	idx_t available_disk_space = DConstants::INVALID_INDEX;
+	if (!TryMultiplyOperator::Operation(static_cast<idx_t>(block_size), static_cast<idx_t>(available_blocks),
+	                                    available_disk_space)) {
+		return DConstants::INVALID_INDEX;
+	}
+	return available_disk_space;
+}
+
 string FileSystem::GetWorkingDirectory() {
 	auto buffer = make_unsafe_uniq_array<char>(PATH_MAX);
 	char *ret = getcwd(buffer.get(), PATH_MAX);
@@ -198,6 +217,18 @@ idx_t FileSystem::GetAvailableMemory() {
 	return DConstants::INVALID_INDEX;
 }
 
+idx_t FileSystem::GetAvailableDiskSpace(const string &path) {
+	ULARGE_INTEGER available_bytes, total_bytes, free_bytes;
+
+	auto unicode_path = WindowsUtil::UTF8ToUnicode(path.c_str());
+	if (!GetDiskFreeSpaceExW(unicode_path.c_str(), &available_bytes, &total_bytes, &free_bytes)) {
+		return DConstants::INVALID_INDEX;
+	}
+	(void)total_bytes;
+	(void)free_bytes;
+	return NumericCast<idx_t>(available_bytes.QuadPart);
+}
+
 string FileSystem::GetWorkingDirectory() {
 	idx_t count = GetCurrentDirectoryW(0, nullptr);
 	if (count == 0) {
diff --git a/src/include/duckdb/common/file_system.hpp b/src/include/duckdb/common/file_system.hpp
index 3dc81acfa2a3..a082e81db614 100644
--- a/src/include/duckdb/common/file_system.hpp
+++ b/src/include/duckdb/common/file_system.hpp
@@ -186,6 +186,8 @@ class FileSystem {
 	DUCKDB_API virtual string ExpandPath(const string &path);
 	//! Returns the system-available memory in bytes. Returns DConstants::INVALID_INDEX if the system function fails.
 	DUCKDB_API static idx_t GetAvailableMemory();
+	//! Returns the space available on the disk. Returns DConstants::INVALID_INDEX if the information was not available.
+	DUCKDB_API static idx_t GetAvailableDiskSpace(const string &path);
 	//! Path separator for path
 	DUCKDB_API virtual string PathSeparator(const string &path);
 	//! Checks if path is starts with separator (i.e., '/' on UNIX '\\' on Windows)
diff --git a/src/include/duckdb/main/config.hpp b/src/include/duckdb/main/config.hpp
index 397376b80333..ec2658afd2bb 100644
--- a/src/include/duckdb/main/config.hpp
+++ b/src/include/duckdb/main/config.hpp
@@ -61,6 +61,37 @@ typedef void (*reset_global_function_t)(DatabaseInstance *db, DBConfig &config);
 typedef void (*reset_local_function_t)(ClientContext &context);
 typedef Value (*get_setting_function_t)(ClientContext &context);
 
+struct NumericSetting {
+public:
+	NumericSetting() : value(0), set_by_user(false) {
+	}
+
+public:
+	NumericSetting &operator=(idx_t val) = delete;
+
+public:
+	operator idx_t() {
+		return value;
+	}
+
+public:
+	bool ExplicitlySet() const {
+		return set_by_user;
+	}
+	void SetDefault(idx_t val) {
+		value = val;
+		set_by_user = false;
+	}
+	void SetExplicit(idx_t val) {
+		value = val;
+		set_by_user = true;
+	}
+
+private:
+	idx_t value;
+	bool set_by_user;
+};
+
 struct ConfigurationOption {
 	const char *name;
 	const char *description;
@@ -150,7 +150,7 @@ struct DBConfigOptions {
 	//! The maximum memory used by the database system (in bytes). Default: 80% of System available memory
 	idx_t maximum_memory = (idx_t)-1;
 	//! The maximum size of the 'temp_directory' folder when set (in bytes). Default 5x 'maximum_memory'
-	idx_t maximum_swap_space = (idx_t)-1;
+	NumericSetting maximum_swap_space = NumericSetting();
 	//! The maximum amount of CPU threads used by the database system. Default: all available.
 	idx_t maximum_threads = (idx_t)-1;
 	//! The number of external threads that work on DuckDB tasks. Default: 1.
@@ -309,7 +309,7 @@ struct DBConfig {
 	DUCKDB_API IndexTypeSet &GetIndexTypes();
 	static idx_t GetSystemMaxThreads(FileSystem &fs);
 	void SetDefaultMaxMemory();
-	void SetDefaultMaxSwapSpace();
+	void SetDefaultMaxSwapSpace(optional_ptr<DatabaseInstance> db);
 
 	OrderType ResolveOrder(OrderType order_type) const;
 	OrderByNullType ResolveNullOrder(OrderType order_type, OrderByNullType null_type) const;
diff --git a/src/main/config.cpp b/src/main/config.cpp
index 5159bf22864f..9aa8c50712cb 100644
--- a/src/main/config.cpp
+++ b/src/main/config.cpp
@@ -260,12 +260,23 @@ void DBConfig::SetDefaultMaxMemory() {
 	}
 }
 
-void DBConfig::SetDefaultMaxSwapSpace() {
-	auto memory_limit = options.maximum_memory;
-	if (!TryMultiplyOperator::Operation(memory_limit, static_cast<idx_t>(5), options.maximum_swap_space)) {
-		// Can't default to 5x memory: fall back to same limit as memory instead
-		options.maximum_swap_space = memory_limit;
+void DBConfig::SetDefaultMaxSwapSpace(optional_ptr<DatabaseInstance> db) {
+	options.maximum_swap_space.SetDefault(0);
+	if (options.temporary_directory.empty()) {
+		return;
+	}
+	if (!db) {
+		return;
+	}
+	auto &fs = FileSystem::GetFileSystem(*db);
+	if (!fs.DirectoryExists(options.temporary_directory)) {
+		// Directory doesn't exist yet, we will look up the disk space once we have created the directory
+		// FIXME: do we want to proactively create the directory instead ???
+		return;
 	}
+	// Use the available disk space if temp directory is set
+	auto disk_space = FileSystem::GetAvailableDiskSpace(options.temporary_directory);
+	options.maximum_swap_space.SetDefault(disk_space);
 }
 
 void DBConfig::CheckLock(const string &name) {
diff --git a/src/main/database.cpp b/src/main/database.cpp
index 9744fc97d808..7a0ef6fcc99c 100644
--- a/src/main/database.cpp
+++ b/src/main/database.cpp
@@ -334,7 +334,7 @@ void DatabaseInstance::Configure(DBConfig &new_config) {
 		config.SetDefaultMaxMemory();
 	}
 	if (config.options.maximum_swap_space == (idx_t)-1) {
-		config.SetDefaultMaxSwapSpace();
+		config.SetDefaultMaxSwapSpace(this);
 	}
 	if (new_config.options.maximum_threads == (idx_t)-1) {
 		config.options.maximum_threads = config.GetSystemMaxThreads(*config.file_system);
diff --git a/src/main/settings/settings.cpp b/src/main/settings/settings.cpp
index bdda84adf07f..674188f3b714 100644
--- a/src/main/settings/settings.cpp
+++ b/src/main/settings/settings.cpp
@@ -952,11 +952,11 @@ Value MaximumMemorySetting::GetSetting(ClientContext &context) {
 // Maximum Temp Directory Size
 //===--------------------------------------------------------------------===//
 void MaximumTempDirectorySize::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
-	config.options.maximum_swap_space = DBConfig::ParseMemoryLimit(input.ToString());
+	config.options.maximum_swap_space.SetExplicit(DBConfig::ParseMemoryLimit(input.ToString()));
 }
 
 void MaximumTempDirectorySize::ResetGlobal(DatabaseInstance *db, DBConfig &config) {
-	config.SetDefaultMaxSwapSpace();
+	config.SetDefaultMaxSwapSpace(db);
 }
 
 Value MaximumTempDirectorySize::GetSetting(ClientContext &context) {
@@ -1229,6 +1229,12 @@ Value SecretDirectorySetting::GetSetting(ClientContext &context) {
 void TempDirectorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
 	config.options.temporary_directory = input.ToString();
 	config.options.use_temporary_directory = !config.options.temporary_directory.empty();
+	if (!config.options.temporary_directory.empty()) {
+		// Maximum swap space isn't set explicitly, initialize to default
+		if (!config.options.maximum_swap_space.ExplicitlySet()) {
+			config.SetDefaultMaxSwapSpace(db);
+		}
+	}
 	if (db) {
 		auto &buffer_manager = BufferManager::GetBufferManager(*db);
 		buffer_manager.SetTemporaryDirectory(config.options.temporary_directory);
diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp
index c97f73f197a1..90761f5d6010 100644
--- a/src/storage/temporary_file_manager.cpp
+++ b/src/storage/temporary_file_manager.cpp
@@ -194,8 +194,13 @@ TemporaryDirectoryHandle::TemporaryDirectoryHandle(DatabaseInstance &db, string
 	auto &fs = FileSystem::GetFileSystem(db);
 	if (!temp_directory.empty()) {
 		if (!fs.DirectoryExists(temp_directory)) {
+			auto &config = DBConfig::GetConfig(db);
 			fs.CreateDirectory(temp_directory);
 			created_directory = true;
+			// Maximum swap space isn't set explicitly, initialize to default
+			if (!config.options.maximum_swap_space.ExplicitlySet()) {
+				config.SetDefaultMaxSwapSpace(&db);
+			}
 		}
 	}
 }
diff --git a/test/sql/storage/max_swap_space.test b/test/sql/storage/max_swap_space.test
new file mode 100644
index 000000000000..2f52b3c7940a
--- /dev/null
+++ b/test/sql/storage/max_swap_space.test
@@ -0,0 +1,112 @@
+# name: test/sql/storage/max_swap_space.test
+# group: [storage]
+
+require skip_reload
+
+statement ok
+set temp_directory='';
+
+statement ok
+PRAGMA memory_limit='2MB'
+
+# --- Set by default to 0 when temp_directory is not set ---
+
+# If 'temp_directory' is not set, this defaults to 0
+query I
+select current_setting('max_temp_directory_size')
+----
+0 bytes
+
+# Set the max size explicitly
+
+statement ok
+set max_temp_directory_size='15gb'
+
+# Then reset it, default should be 0 again
+
+statement ok
+reset max_temp_directory_size;
+
+query I
+select current_setting('max_temp_directory_size')
+----
+0 bytes
+
+# --- Set by default to the available disk space when temp_directory exists ---
+
+statement ok
+set temp_directory = '__TEST_DIR__';
+
+# '__TEST_DIR__' is guaranteed to exist, we can get the disk space
+query I
+select current_setting('max_temp_directory_size') a where a == '0 bytes'
+----
+
+# So the reported max size should not be 0 bytes
+query I
+select current_setting('max_temp_directory_size') a where a == '0 bytes'
+----
+
+# --- Set explicitly by the user ---
+
+# If we set 'max_temp_directory_size' explicitly, it will not be overridden
+statement ok
+set max_temp_directory_size='15gb'
+
+# Reported size should not be 0, we set it explicitly
+query I
+select current_setting('max_temp_directory_size') a where a == '0 bytes'
+----
+
+query I nosort unchanged
+select current_setting('max_temp_directory_size')
+----
+
+# When we change the temp_directory to something that doesn't exist
+statement ok
+set temp_directory = '__TEST_DIR__/does_not_exist'
+
+query I nosort unchanged
+select current_setting('max_temp_directory_size')
+----
+
+# When we change the temp_directory to something that does exist
+statement ok
+set temp_directory = '__TEST_DIR__'
+
+query I nosort unchanged
+select current_setting('max_temp_directory_size')
+----
+
+# When we reset the temp_directory ..
+statement ok
+reset temp_directory;
+
+query I nosort unchanged
+select current_setting('max_temp_directory_size')
+----
+
+# --- Set to the available disk space when we create the (previously non-existent) 'temp_directory'
+
+statement ok
+reset max_temp_directory_size;
+
+# When we change the temp_directory to something that doesn't exist
+statement ok
+set temp_directory = '__TEST_DIR__/does_not_exist'
+
+query I
+select current_setting('max_temp_directory_size')
+----
+0 bytes
+
+statement ok
+CREATE TABLE t2 AS SELECT * FROM range(1000000);
+
+# Reported size should not be 0, the directory was created
+query I
+select current_setting('max_temp_directory_size') a where a == '0 bytes'
+----
+
+## TODO: test that the explicitly set value by the user does not get overridden when 'temp_directory' is set to a directory that doesn't exist yet
+# when the 'temp_directory' is created by us - the explicitly set value should not be overridden
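On POSIX systems the new default is derived from statvfs: the space available to an
unprivileged process is f_frsize (the fragment size) multiplied by f_bavail (the number
of available blocks). A quick sanity check of that arithmetic, with made-up numbers:

    SELECT 4096 * 2621440;  -- 4 KiB fragments x ~2.6M free blocks = 10737418240 bytes (10 GiB)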
From a0d98f41987e1500ba9814a721c8cfc37e780be2 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 8 Mar 2024 10:17:17 +0100
Subject: [PATCH 005/201] further thinking

---
 src/main/settings/settings.cpp       |  2 ++
 test/sql/storage/max_swap_space.test | 13 ++++++-------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/main/settings/settings.cpp b/src/main/settings/settings.cpp
index 674188f3b714..90ef20522a24 100644
--- a/src/main/settings/settings.cpp
+++ b/src/main/settings/settings.cpp
@@ -952,6 +952,8 @@ Value MaximumMemorySetting::GetSetting(ClientContext &context) {
 // Maximum Temp Directory Size
 //===--------------------------------------------------------------------===//
 void MaximumTempDirectorySize::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
+	// FIXME: should this not use 'SetExplicit' when the value is 0?
+	// So it acts as RESET instead when 0 is passed?
 	config.options.maximum_swap_space.SetExplicit(DBConfig::ParseMemoryLimit(input.ToString()));
 }
 
diff --git a/test/sql/storage/max_swap_space.test b/test/sql/storage/max_swap_space.test
index 2f52b3c7940a..0b200b8039c8 100644
--- a/test/sql/storage/max_swap_space.test
+++ b/test/sql/storage/max_swap_space.test
@@ -18,12 +18,15 @@ select current_setting('max_temp_directory_size')
 0 bytes
 
 # Set the max size explicitly
-
 statement ok
 set max_temp_directory_size='15gb'
 
-# Then reset it, default should be 0 again
+# Should not be 0 anymore
+query I
+select current_setting('max_temp_directory_size') a where a == '0 bytes'
+----
 
+# Then reset it, default should be 0 again
 statement ok
 reset max_temp_directory_size;
 
@@ -42,11 +45,6 @@ query I
 select current_setting('max_temp_directory_size') a where a == '0 bytes'
 ----
 
-# So the reported max size should not be 0 bytes
-query I
-select current_setting('max_temp_directory_size') a where a == '0 bytes'
-----
-
 # --- Set explicitly by the user ---
 
 # If we set 'max_temp_directory_size' explicitly, it will not be overridden
From 8c8ffe6ef20c7983f336c3e9d001a447b848367c Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 8 Mar 2024 10:27:13 +0100
Subject: [PATCH 006/201] test that explicitly set values are not overridden
 when we create the temp_directory

---
 .../{ => temp_directory}/max_swap_space.test  |  7 ++---
 .../max_swap_space_explicit.test              | 29 +++++++++++++++++++
 2 files changed, 31 insertions(+), 5 deletions(-)
 rename test/sql/storage/{ => temp_directory}/max_swap_space.test (88%)
 create mode 100644 test/sql/storage/temp_directory/max_swap_space_explicit.test

diff --git a/test/sql/storage/max_swap_space.test b/test/sql/storage/temp_directory/max_swap_space.test
similarity index 88%
rename from test/sql/storage/max_swap_space.test
rename to test/sql/storage/temp_directory/max_swap_space.test
index 0b200b8039c8..0c93cd2852c9 100644
--- a/test/sql/storage/max_swap_space.test
+++ b/test/sql/storage/temp_directory/max_swap_space.test
@@ -1,5 +1,5 @@
-# name: test/sql/storage/max_swap_space.test
-# group: [storage]
+# name: test/sql/storage/temp_directory/max_swap_space.test
+# group: [temp_directory]
 
 require skip_reload
 
@@ -106,6 +106,3 @@ CREATE TABLE t2 AS SELECT * FROM range(1000000);
 query I
 select current_setting('max_temp_directory_size') a where a == '0 bytes'
 ----
-
-## TODO: test that the explicitly set value by the user does not get overridden when 'temp_directory' is set to a directory that doesn't exist yet
-# when the 'temp_directory' is created by us - the explicitly set value should not be overridden
diff --git a/test/sql/storage/temp_directory/max_swap_space_explicit.test b/test/sql/storage/temp_directory/max_swap_space_explicit.test
new file mode 100644
index 000000000000..ea85ced65be3
--- /dev/null
+++ b/test/sql/storage/temp_directory/max_swap_space_explicit.test
@@ -0,0 +1,29 @@
+# name: test/sql/storage/temp_directory/max_swap_space_explicit.test
+# group: [temp_directory]
+
+require skip_reload
+
+statement ok
+PRAGMA memory_limit='2MB'
+
+# --- Not changed when set explicitly by the user
+
+# If we set 'max_temp_directory_size' explicitly, it will not be overridden
+statement ok
+set max_temp_directory_size='15gb'
+
+# When we change the temp_directory to something that doesn't exist
+statement ok
+set temp_directory = '__TEST_DIR__/this_directory_should_not_exist__swap_space'
+
+query I nosort explicitly_set
+select current_setting('max_temp_directory_size')
+----
+
+statement ok
+CREATE TABLE t2 AS SELECT * FROM range(1000000);
+
+# The 'temp_directory' was created, but the value of 'max_temp_directory_size' was set explicitly, so it was unaltered
+query I nosort explicitly_set
+select current_setting('max_temp_directory_size')
+----
From d3ecab62a5584241d0d4cce01c8fa0180d84be82 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 8 Mar 2024 12:50:36 +0100
Subject: [PATCH 007/201] add initial tests

---
 src/storage/temporary_file_manager.cpp        |  8 ++---
 .../temp_directory/max_swap_space_error.test  | 36 +++++++++++++++++++
 2 files changed, 40 insertions(+), 4 deletions(-)
 create mode 100644 test/sql/storage/temp_directory/max_swap_space_error.test

diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp
index 90761f5d6010..0fda4781124b 100644
--- a/src/storage/temporary_file_manager.cpp
+++ b/src/storage/temporary_file_manager.cpp
@@ -80,10 +80,10 @@ idx_t BlockIndexManager::GetNewBlockIndexInternal() {
 	if (free_indexes.empty()) {
 		auto new_index = max_index;
-		max_index++;
 		if (file_size_monitor) {
 			file_size_monitor->Increase(1);
 		}
+		max_index++;
 		return new_index;
 	}
 	auto entry = free_indexes.begin();
@@ -316,17 +316,17 @@ void TemporaryFileManager::IncreaseSizeOnDisk(idx_t bytes) {
 	auto max_swap_space = config.options.maximum_swap_space;
 
 	auto current_size_on_disk = size_on_disk.load();
-	size_on_disk += bytes;
-	if (size_on_disk.load() > max_swap_space) {
+	if (current_size_on_disk + bytes > max_swap_space) {
 		auto used = StringUtil::BytesToHumanReadableString(current_size_on_disk);
 		auto max = StringUtil::BytesToHumanReadableString(max_swap_space);
 		auto data_size = StringUtil::BytesToHumanReadableString(bytes);
 		throw OutOfMemoryException(R"(failed to offload data block of size %s (%s/%s used).
This limit was set by the 'max_temp_directory_size' setting.
-This defaults to twice the size of 'max_memory'.
+By default, this setting utilizes the available disk space on the drive where the 'temp_directory' is located.
You can adjust this setting, by using (for example) PRAGMA max_temp_directory_size='10GiB')",
 		                           data_size, used, max);
 	}
+	size_on_disk += bytes;
 }
 
 void TemporaryFileManager::DecreaseSizeOnDisk(idx_t bytes) {
diff --git a/test/sql/storage/temp_directory/max_swap_space_error.test b/test/sql/storage/temp_directory/max_swap_space_error.test
new file mode 100644
index 000000000000..80e75df91904
--- /dev/null
+++ b/test/sql/storage/temp_directory/max_swap_space_error.test
@@ -0,0 +1,36 @@
+# name: test/sql/storage/temp_directory/max_swap_space_error.test
+# group: [temp_directory]
+
+require skip_reload
+
+# Set a temp_directory to offload data
+statement ok
+set temp_directory='__TEST_DIR__/max_swap_space_reached'
+
+# Ensure the temp_directory is used
+statement ok
+PRAGMA memory_limit='2MB'
+
+
+# 0 blocks
+statement ok
+set max_temp_directory_size='0KiB'
+
+statement error
+CREATE OR REPLACE TABLE t2 AS SELECT * FROM range(1000000);
+----
+failed to offload data block of size 256.0 KiB (0 bytes/0 bytes used)
+
+# 1 block max
+statement ok
+set max_temp_directory_size='256KiB'
+
+statement error
+CREATE OR REPLACE TABLE t2 AS SELECT * FROM range(1000000);
+----
+failed to offload data block of size 256.0 KiB (256.0 KiB/256.0 KiB used)
+
+statement error
+CREATE OR REPLACE TABLE t2 AS SELECT * FROM range(1000000);
+----
+failed to offload data block of size 256.0 KiB (256.0 KiB/256.0 KiB used)
From a97bcb78c0008f7416eefea1512328a73064a7c2 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 8 Mar 2024 13:14:26 +0100
Subject: [PATCH 008/201] more tests with different max swap sizes

---
 .../temp_directory/max_swap_space_error.test  | 38 ++++++++++++++++++-
 1 file changed, 36 insertions(+), 2 deletions(-)

diff --git a/test/sql/storage/temp_directory/max_swap_space_error.test b/test/sql/storage/temp_directory/max_swap_space_error.test
index 80e75df91904..68c55f520665 100644
--- a/test/sql/storage/temp_directory/max_swap_space_error.test
+++ b/test/sql/storage/temp_directory/max_swap_space_error.test
@@ -9,8 +9,7 @@ set temp_directory='__TEST_DIR__/max_swap_space_reached'
 
 # Ensure the temp_directory is used
 statement ok
-PRAGMA memory_limit='2MB'
-
+PRAGMA memory_limit='1024KiB'
 
 # 0 blocks
 statement ok
@@ -21,6 +20,11 @@ CREATE OR REPLACE TABLE t2 AS SELECT * FROM range(1000000);
 ----
 failed to offload data block of size 256.0 KiB (0 bytes/0 bytes used)
 
+query I
+select "size" from duckdb_temporary_files()
+----
+0
+
 # 1 block max
 statement ok
 set max_temp_directory_size='256KiB'
@@ -34,3 +38,33 @@ statement error
 CREATE OR REPLACE TABLE t2 AS SELECT * FROM range(1000000);
 ----
 failed to offload data block of size 256.0 KiB (256.0 KiB/256.0 KiB used)
+
+query I
+select "size" from duckdb_temporary_files()
+----
+
+# 6 blocks
+statement ok
+set max_temp_directory_size='1536KiB'
+
+statement ok
+pragma threads=2;
+
+statement ok
+set preserve_insertion_order=true;
+
+# This is 1600000 bytes of BIGINT data, which is roughly 6.1 blocks
+# Because our memory limit is set at 1MiB (4 blocks) this works
+statement ok
+CREATE OR REPLACE TABLE t2 AS SELECT * FROM range(200000);
+
+# When we increase the size to 2400000 bytes of BIGINT data (9.1 blocks) this errors
+statement error
+CREATE OR REPLACE TABLE t2 AS SELECT * FROM range(300000);
+----
+failed to offload data block of size 256.0 KiB (1.5 MiB/1.5 MiB used)
+
+query I
+select "size" from duckdb_temporary_files()
+----
+1572864
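The block counts in these tests follow directly from the 256 KiB temporary block size.
Checking the arithmetic that the test comments rely on (illustrative queries only):

    SELECT 200000 * 8 / 262144.0;  -- ~6.1 blocks for 200k BIGINTs; fits under a 1536 KiB cap
    SELECT 300000 * 8 / 262144.0;  -- ~9.1 blocks for 300k BIGINTs; exceeds the 6-block cap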
From 0227e2d3ea730b465402d4b6b0d1583bfa7bfb52 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 8 Mar 2024 13:20:00 +0100
Subject: [PATCH 009/201] fix up comment

---
 src/include/duckdb/main/config.hpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/include/duckdb/main/config.hpp b/src/include/duckdb/main/config.hpp
index ec2658afd2bb..9bedd5d6160f 100644
--- a/src/include/duckdb/main/config.hpp
+++ b/src/include/duckdb/main/config.hpp
@@ -149,7 +149,7 @@ struct DBConfigOptions {
 	string autoinstall_extension_repo = "";
 	//! The maximum memory used by the database system (in bytes). Default: 80% of System available memory
 	idx_t maximum_memory = (idx_t)-1;
-	//! The maximum size of the 'temp_directory' folder when set (in bytes). Default 5x 'maximum_memory'
+	//! The maximum size of the 'temp_directory' folder when set (in bytes)
 	NumericSetting maximum_swap_space = NumericSetting();
 	//! The maximum amount of CPU threads used by the database system. Default: all available.
 	idx_t maximum_threads = (idx_t)-1;
From d56137adec45305b3a9c54f3ba8935d39504c975 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 8 Mar 2024 13:23:19 +0100
Subject: [PATCH 010/201] check if the config was set explicitly or not in
 DatabaseInstance::Configure

---
 src/main/database.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/main/database.cpp b/src/main/database.cpp
index 7a0ef6fcc99c..39d554147dbc 100644
--- a/src/main/database.cpp
+++ b/src/main/database.cpp
@@ -333,7 +333,7 @@ void DatabaseInstance::Configure(DBConfig &new_config) {
 	if (config.options.maximum_memory == (idx_t)-1) {
 		config.SetDefaultMaxMemory();
 	}
-	if (config.options.maximum_swap_space == (idx_t)-1) {
+	if (!config.options.maximum_swap_space.ExplicitlySet()) {
 		config.SetDefaultMaxSwapSpace(this);
 	}
 	if (new_config.options.maximum_threads == (idx_t)-1) {
From b37d1932f5e1778daacaccc47376f263c4772cb8 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 8 Mar 2024 15:02:10 +0100
Subject: [PATCH 011/201] avoid modifying the passed in DBConfig

---
 src/include/duckdb/main/database.hpp          |   2 +-
 src/main/database.cpp                         |  45 +++----
 src/storage/temporary_file_manager.cpp        |   2 +-
 .../max_swap_space_inmemory.test              | 112 ++++++++++++++++++
 ...ce.test => max_swap_space_persistent.test} |  11 +-
 5 files changed, 147 insertions(+), 25 deletions(-)
 create mode 100644 test/sql/storage/temp_directory/max_swap_space_inmemory.test
 rename test/sql/storage/temp_directory/{max_swap_space.test => max_swap_space_persistent.test} (90%)

diff --git a/src/include/duckdb/main/database.hpp b/src/include/duckdb/main/database.hpp
index a5798c8059b5..8bef82e99acd 100644
--- a/src/include/duckdb/main/database.hpp
+++ b/src/include/duckdb/main/database.hpp
@@ -62,7 +62,7 @@ class DatabaseInstance : public std::enable_shared_from_this<DatabaseInstance> {
 	void Initialize(const char *path, DBConfig *config);
 
 	void CreateMainDatabase();
-	void Configure(DBConfig &config);
+	void Configure(DBConfig &config, const char *path);
 
 private:
 	unique_ptr<BufferManager> buffer_manager;
diff --git a/src/main/database.cpp b/src/main/database.cpp
index 39d554147dbc..5b3c04db2983 100644
--- a/src/main/database.cpp
+++ b/src/main/database.cpp
@@ -185,27 +185,7 @@ void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_conf
 		config_ptr = user_config;
 	}
 
-	if (config_ptr->options.duckdb_api.empty()) {
-		config_ptr->SetOptionByName("duckdb_api", "cpp");
-	}
-
-	if (config_ptr->options.temporary_directory.empty() && database_path) {
-		// no directory specified: use default temp path
-		config_ptr->options.temporary_directory = string(database_path) + ".tmp";
-
-		// special treatment for in-memory mode
-		if (strcmp(database_path, IN_MEMORY_PATH) == 0) {
-			config_ptr->options.temporary_directory = ".tmp";
-		}
-	}
-
-	if (database_path) {
-		config_ptr->options.database_path = database_path;
-	} else {
-		config_ptr->options.database_path.clear();
-	}
-
-	Configure(*config_ptr);
+	Configure(*config_ptr, database_path);
 
 	if (user_config && !user_config->options.use_temporary_directory) {
 		// temporary directories explicitly disabled
@@ -296,8 +296,29 @@ Allocator &Allocator::Get(AttachedDatabase &db) {
 	return Allocator::Get(db.GetDatabase());
 }
 
-void DatabaseInstance::Configure(DBConfig &new_config) {
+void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path) {
 	config.options = new_config.options;
+
+	if (new_config.options.duckdb_api.empty()) {
+		config.SetOptionByName("duckdb_api", "cpp");
+	}
+
+	if (new_config.options.temporary_directory.empty() && database_path) {
+		// no directory specified: use default temp path
+		config.options.temporary_directory = string(database_path) + ".tmp";
+
+		// special treatment for in-memory mode
+		if (strcmp(database_path, IN_MEMORY_PATH) == 0) {
+			config.options.temporary_directory = ".tmp";
+		}
+	}
+
+	if (database_path) {
+		config.options.database_path = database_path;
+	} else {
+		config.options.database_path.clear();
+	}
+
 	if (config.options.access_mode == AccessMode::UNDEFINED) {
 		config.options.access_mode = AccessMode::READ_WRITE;
 	}
diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp
index 0fda4781124b..d4fa09470596 100644
--- a/src/storage/temporary_file_manager.cpp
+++ b/src/storage/temporary_file_manager.cpp
@@ -52,7 +52,7 @@ bool BlockIndexManager::RemoveIndex(idx_t index) {
 	if (max_index_in_use < max_index) {
 		// max index in use is lower than the max_index
 		// reduce the max_index
-		max_index = indexes_in_use.empty() ? 0 : max_index_in_use + 1;
+		max_index = max_index_in_use;
 		if (file_size_monitor) {
 			file_size_monitor->Decrease(old_max - max_index);
 		}
diff --git a/test/sql/storage/temp_directory/max_swap_space_inmemory.test b/test/sql/storage/temp_directory/max_swap_space_inmemory.test
new file mode 100644
index 000000000000..9de30a592cb0
--- /dev/null
+++ b/test/sql/storage/temp_directory/max_swap_space_inmemory.test
@@ -0,0 +1,112 @@
+# name: test/sql/storage/temp_directory/max_swap_space_inmemory.test
+# group: [temp_directory]
+
+require skip_reload
+
+# In in-memory mode, the 'temp_directory' defaults to '.tmp'
+# So there are no guarantees about the value of 'max_temp_directory_size'
+# (if .tmp exists, it's set to the available disk space | if it doesn't exist, it'll stay at 0 bytes)
+
+statement ok
+set temp_directory='';
+
+statement ok
+PRAGMA memory_limit='2MB'
+
+# --- Set by default to 0 when temp_directory is not set ---
+
+# If 'temp_directory' is not set, this defaults to 0
+query I
+select current_setting('max_temp_directory_size')
+----
+0 bytes
+
+# Set the max size explicitly
+statement ok
+set max_temp_directory_size='15gb'
+
+# Should not be 0 anymore
+query I
+select current_setting('max_temp_directory_size') a where a == '0 bytes'
+----
+
+# Then reset it, default should be 0 again
+statement ok
+reset max_temp_directory_size;
+
+query I
+select current_setting('max_temp_directory_size')
+----
+0 bytes
+
+# --- Set by default to the available disk space when temp_directory exists ---
+
+statement ok
+set temp_directory = '__TEST_DIR__';
+
+# '__TEST_DIR__' is guaranteed to exist, we can get the disk space
+query I
+select current_setting('max_temp_directory_size') a where a == '0 bytes'
+----
+
+# --- Set explicitly by the user ---
+
+# If we set 'max_temp_directory_size' explicitly, it will not be overridden
+statement ok
+set max_temp_directory_size='15gb'
+
+# Reported size should not be 0, we set it explicitly
+query I
+select current_setting('max_temp_directory_size') a where a == '0 bytes'
+----
+
+query I nosort unchanged
+select current_setting('max_temp_directory_size')
+----
+
+# When we change the temp_directory to something that doesn't exist
+statement ok
+set temp_directory = '__TEST_DIR__/does_not_exist'
+
+query I nosort unchanged
+select current_setting('max_temp_directory_size')
+----
+
+# When we change the temp_directory to something that does exist
+statement ok
+set temp_directory = '__TEST_DIR__'
+
+query I nosort unchanged
+select current_setting('max_temp_directory_size')
+----
+
+# When we reset the temp_directory ..
+statement ok
+reset temp_directory;
+
+query I nosort unchanged
+select current_setting('max_temp_directory_size')
+----
+
+# --- Set to the available disk space when we create the (previously non-existent) 'temp_directory'
+
+# Reset it so it's no longer set explicitly
+statement ok
+reset max_temp_directory_size;
+
+# When we change the temp_directory to something that doesn't exist
+statement ok
+set temp_directory = '__TEST_DIR__/does_not_exist'
+
+query I
+select current_setting('max_temp_directory_size')
+----
+0 bytes
+
+statement ok
+CREATE TABLE t2 AS SELECT * FROM range(1000000);
+
+# Reported size should not be 0, the directory was created
+query I
+select current_setting('max_temp_directory_size') a where a == '0 bytes'
+----
diff --git a/test/sql/storage/temp_directory/max_swap_space.test b/test/sql/storage/temp_directory/max_swap_space_persistent.test
similarity index 90%
rename from test/sql/storage/temp_directory/max_swap_space.test
rename to test/sql/storage/temp_directory/max_swap_space_persistent.test
index 0c93cd2852c9..354c65215e3a 100644
--- a/test/sql/storage/temp_directory/max_swap_space.test
+++ b/test/sql/storage/temp_directory/max_swap_space_persistent.test
@@ -1,8 +1,17 @@
-# name: test/sql/storage/temp_directory/max_swap_space.test
+# name: test/sql/storage/temp_directory/max_swap_space_persistent.test
 # group: [temp_directory]
 
 require skip_reload
 
+# Create a persistent database
+load __TEST_DIR__/max_swap_space.db
+
+## If 'temp_directory' is not set, this defaults to 0
+#query I
+#select current_setting('max_temp_directory_size')
+#----
+#0 bytes
+
 statement ok
 set temp_directory='';
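The persistent test relies on the default temp directory being derived from the
database path. That mapping can be observed directly (the path here is hypothetical;
the split trick mirrors what the test itself does):

    -- with a database opened at /some/dir/my.db, the default is /some/dir/my.db.tmp
    SELECT current_setting('temp_directory').split('/')[-1];  -- 'my.db.tmp'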
temp_directory=''; statement ok PRAGMA memory_limit='2MB' -# --- Set by default to 0 when temp_directory is not set --- +# --- Set by default to 0 when temp_directory is empty --- -# If 'temp_directory' is not set, this defaults to 0 query I select current_setting('max_temp_directory_size') ---- diff --git a/test/sql/storage/temp_directory/max_swap_space_persistent.test b/test/sql/storage/temp_directory/max_swap_space_persistent.test index 354c65215e3a..aa18759f7a83 100644 --- a/test/sql/storage/temp_directory/max_swap_space_persistent.test +++ b/test/sql/storage/temp_directory/max_swap_space_persistent.test @@ -6,6 +6,12 @@ require skip_reload # Create a persistent database load __TEST_DIR__/max_swap_space.db +# Default temp_directory for a persistent database is .tmp +query I +select current_setting('temp_directory').split('/')[-1] +---- +max_swap_space.db.tmp + ## If 'temp_directory' is not set, this defaults to 0 #query I #select current_setting('max_temp_directory_size') From 8b2b5cd0dde552c03abbcaa28585ad1d21eee22a Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 8 Mar 2024 15:26:58 +0100 Subject: [PATCH 013/201] make the in-memory database detection better --- src/main/database.cpp | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/main/database.cpp b/src/main/database.cpp index 5b3c04db2983..d8dd6cca8212 100644 --- a/src/main/database.cpp +++ b/src/main/database.cpp @@ -296,6 +296,24 @@ Allocator &Allocator::Get(AttachedDatabase &db) { return Allocator::Get(db.GetDatabase()); } +static bool IsInMemoryDatabase(const char *database_path) { + if (!database_path) { + // Entirely empty + return true; + } + if (strlen(database_path) == 0) { + // '' empty string + return true; + } + constexpr const char *IN_MEMORY_PATH_PREFIX = ":memory:"; + const idx_t PREFIX_LENGTH = strlen(IN_MEMORY_PATH_PREFIX); + if (strncmp(database_path, IN_MEMORY_PATH_PREFIX, PREFIX_LENGTH) == 0) { + // Starts with :memory:, i.e ':memory:named_conn' is valid + return true; + } + return false; +} + void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path) { config.options = new_config.options; @@ -308,7 +326,7 @@ void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path config.options.temporary_directory = string(database_path) + ".tmp"; // special treatment for in-memory mode - if (strcmp(database_path, IN_MEMORY_PATH) == 0) { + if (IsInMemoryDatabase(database_path)) { config.options.temporary_directory = ".tmp"; } } From 96fc46e187f3fed630364342829a199baffd7dbf Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 8 Mar 2024 19:58:37 +0100 Subject: [PATCH 014/201] initialize temp_directory to '.tmp' for every version of in-memory connection, add test for the behavior --- src/main/database.cpp | 7 +++---- test/api/test_api.cpp | 10 ++++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/main/database.cpp b/src/main/database.cpp index d8dd6cca8212..ac8df84f555f 100644 --- a/src/main/database.cpp +++ b/src/main/database.cpp @@ -321,13 +321,12 @@ void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path config.SetOptionByName("duckdb_api", "cpp"); } - if (new_config.options.temporary_directory.empty() && database_path) { + if (new_config.options.temporary_directory.empty()) { // no directory specified: use default temp path - config.options.temporary_directory = string(database_path) + ".tmp"; - - // special treatment for in-memory mode if (IsInMemoryDatabase(database_path)) { 
config.options.temporary_directory = ".tmp";
+		} else {
+			config.options.temporary_directory = string(database_path) + ".tmp";
 		}
 	}
 
diff --git a/test/api/test_api.cpp b/test/api/test_api.cpp
index ca41df7e3acc..5d17b2710e46 100644
--- a/test/api/test_api.cpp
+++ b/test/api/test_api.cpp
@@ -139,6 +139,16 @@ static void parallel_query(Connection *conn, bool *correct, size_t threadnr) {
 	}
 }
 
+TEST_CASE("Test temp_directory defaults", "[api][.]") {
+	const char *db_paths[] = {nullptr, "", ":memory:", ":memory:named_conn"};
+	for (auto &path : db_paths) {
+		auto db = make_uniq<DuckDB>(path);
+		auto conn = make_uniq<Connection>(*db);
+
+		REQUIRE(db->instance->config.options.temporary_directory == ".tmp");
+	}
+}
+
 TEST_CASE("Test parallel usage of single client", "[api][.]") {
 	auto db = make_uniq<DuckDB>(nullptr);
 	auto conn = make_uniq<Connection>(*db);

From 8ad835366e73b6bfa67e2725ae5bb5f62b6abbca Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 8 Mar 2024 20:06:31 +0100
Subject: [PATCH 015/201] use 90% of the available disk space by default

---
 src/main/config.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/main/config.cpp b/src/main/config.cpp
index 9aa8c50712cb..af47bd0d1623 100644
--- a/src/main/config.cpp
+++ b/src/main/config.cpp
@@ -276,7 +276,9 @@ void DBConfig::SetDefaultMaxSwapSpace(optional_ptr<DatabaseInstance> db) {
 	}
 	// Use the available disk space if temp directory is set
 	auto disk_space = FileSystem::GetAvailableDiskSpace(options.temporary_directory);
-	options.maximum_swap_space.SetDefault(disk_space);
+	// Only use 90% of the available disk space
+	auto default_value = disk_space == DConstants::INVALID_INDEX ? 0 : static_cast<double>(disk_space) * 0.9;
+	options.maximum_swap_space.SetDefault(default_value);
 }
 
 void DBConfig::CheckLock(const string &name) {
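A rough sketch of what this default means from the Python API (illustrative only: the directory path here is arbitrary, and the reported value depends on the free space of the filesystem backing the temp directory):

    import duckdb

    con = duckdb.connect()  # in-memory database
    con.execute("SET temp_directory = '/tmp/duckdb_swap'")
    # With the change above, the derived default is roughly 90% of the
    # available disk space instead of all of it (and 0 when it cannot be determined)
    print(con.sql("SELECT current_setting('max_temp_directory_size')").fetchone())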
From 29678c5af439bf68339f5b4ac9d4553750132093 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Mon, 11 Mar 2024 10:19:03 +0100
Subject: [PATCH 016/201] RESET temp_directory should use the same behavior as
 DatabaseInstance::Configure

---
 src/include/duckdb/main/config.hpp |  2 ++
 src/main/config.cpp                | 26 +++++++++++++++++++++
 src/main/database.cpp              | 31 ++++--------------------------
 src/main/settings/settings.cpp     |  3 ++-
 4 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/src/include/duckdb/main/config.hpp b/src/include/duckdb/main/config.hpp
index 9bedd5d6160f..fcfda6ea8c40 100644
--- a/src/include/duckdb/main/config.hpp
+++ b/src/include/duckdb/main/config.hpp
@@ -278,6 +278,7 @@ struct DBConfig {
 	DUCKDB_API static vector<ConfigurationOption> GetOptions();
 	DUCKDB_API static idx_t GetOptionCount();
 	DUCKDB_API static vector<string> GetOptionNames();
+	DUCKDB_API static bool IsInMemoryDatabase(const char *database_path);
 	DUCKDB_API void AddExtensionOption(const string &name, string description, LogicalType parameter,
 	                                   const Value &default_value = Value(), set_option_callback_t function = nullptr);
@@ -310,6 +311,7 @@
 	static idx_t GetSystemMaxThreads(FileSystem &fs);
 	void SetDefaultMaxMemory();
 	void SetDefaultMaxSwapSpace(optional_ptr<DatabaseInstance> db);
+	void SetDefaultTempDirectory();
 	OrderType ResolveOrder(OrderType order_type) const;
 	OrderByNullType ResolveNullOrder(OrderType order_type, OrderByNullType null_type) const;
diff --git a/src/main/config.cpp b/src/main/config.cpp
index af47bd0d1623..33310d20b1a7 100644
--- a/src/main/config.cpp
+++ b/src/main/config.cpp
@@ -245,6 +245,24 @@ void DBConfig::AddExtensionOption(const string &name, string description, Logica
 	}
 }
 
+bool DBConfig::IsInMemoryDatabase(const char *database_path) {
+	if (!database_path) {
+		// Entirely empty
+		return true;
+	}
+	if (strlen(database_path) == 0) {
+		// '' empty string
+		return true;
+	}
+	constexpr const char *IN_MEMORY_PATH_PREFIX = ":memory:";
+	const idx_t PREFIX_LENGTH = strlen(IN_MEMORY_PATH_PREFIX);
+	if (strncmp(database_path, IN_MEMORY_PATH_PREFIX, PREFIX_LENGTH) == 0) {
+		// Starts with :memory:, i.e. ':memory:named_conn' is valid
+		return true;
+	}
+	return false;
+}
+
 CastFunctionSet &DBConfig::GetCastFunctions() {
 	return *cast_functions;
 }
@@ -260,6 +278,14 @@ void DBConfig::SetDefaultMaxMemory() {
 	}
 }
 
+void DBConfig::SetDefaultTempDirectory() {
+	if (DBConfig::IsInMemoryDatabase(options.database_path.c_str())) {
+		options.temporary_directory = ".tmp";
+	} else {
+		options.temporary_directory = options.database_path + ".tmp";
+	}
+}
+
 void DBConfig::SetDefaultMaxSwapSpace(optional_ptr<DatabaseInstance> db) {
 	options.maximum_swap_space.SetDefault(0);
 	if (options.temporary_directory.empty()) {
diff --git a/src/main/database.cpp b/src/main/database.cpp
index e63469eb00f5..47da40517c75 100644
--- a/src/main/database.cpp
+++ b/src/main/database.cpp
@@ -297,24 +297,6 @@ Allocator &Allocator::Get(AttachedDatabase &db) {
 	return Allocator::Get(db.GetDatabase());
 }
 
-static bool IsInMemoryDatabase(const char *database_path) {
-	if (!database_path) {
-		// Entirely empty
-		return true;
-	}
-	if (strlen(database_path) == 0) {
-		// '' empty string
-		return true;
-	}
-	constexpr const char *IN_MEMORY_PATH_PREFIX = ":memory:";
-	const idx_t PREFIX_LENGTH = strlen(IN_MEMORY_PATH_PREFIX);
-	if (strncmp(database_path, IN_MEMORY_PATH_PREFIX, PREFIX_LENGTH) == 0) {
-		// Starts with :memory:, i.e ':memory:named_conn' is valid
-		return true;
-	}
-	return false;
-}
-
 void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path) {
 	config.options = new_config.options;
 
@@ -322,21 +304,16 @@ void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path
 		config.SetOptionByName("duckdb_api", "cpp");
 	}
 
-	if (new_config.options.temporary_directory.empty()) {
-		// no directory specified: use default temp path
-		if (IsInMemoryDatabase(database_path)) {
-			config.options.temporary_directory = ".tmp";
-		} else {
-			config.options.temporary_directory = string(database_path) + ".tmp";
-		}
-	}
-
 	if (database_path) {
 		config.options.database_path = database_path;
 	} else {
 		config.options.database_path.clear();
 	}
 
+	if (new_config.options.temporary_directory.empty()) {
+		config.SetDefaultTempDirectory();
+	}
+
 	if (config.options.access_mode == AccessMode::UNDEFINED) {
 		config.options.access_mode = AccessMode::READ_WRITE;
 	}
diff --git a/src/main/settings/settings.cpp b/src/main/settings/settings.cpp
index 1eb7ce7eb030..87bb424e79cc 100644
--- a/src/main/settings/settings.cpp
+++ b/src/main/settings/settings.cpp
@@ -1241,7 +1241,8 @@ void TempDirectorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, con
 }
 
 void TempDirectorySetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) {
-	config.options.temporary_directory = DBConfig().options.temporary_directory;
+	config.SetDefaultTempDirectory();
+	config.options.use_temporary_directory = DBConfig().options.use_temporary_directory;
 	if (db) {
 		auto &buffer_manager = BufferManager::GetBufferManager(*db);
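The intent of this patch, sketched from the Python client (illustrative; assumes an in-memory database, for which the derived default is '.tmp'):

    import duckdb

    con = duckdb.connect()
    con.execute("SET temp_directory = '/tmp/elsewhere'")
    con.execute("RESET temp_directory")
    # RESET now falls back to the same default that DatabaseInstance::Configure
    # would pick, so this should print ('.tmp',) rather than ('',)
    print(con.sql("SELECT current_setting('temp_directory')").fetchone())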
From 9e5d10fcfbd34a73891b6663762fbcbde153b7ce Mon Sep 17 00:00:00 2001
From: Tishj
Date: Mon, 11 Mar 2024 10:34:38 +0100
Subject: [PATCH 017/201] add missing PRAGMA statement, because of a bug the
 temp directory was left empty before

---
 test/sql/storage/test_buffer_manager.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/sql/storage/test_buffer_manager.cpp b/test/sql/storage/test_buffer_manager.cpp
index e730fe892b92..09991704d1bd 100644
--- a/test/sql/storage/test_buffer_manager.cpp
+++ b/test/sql/storage/test_buffer_manager.cpp
@@ -79,6 +79,7 @@ TEST_CASE("Modifying the buffer manager limit at runtime for an in-memory databa
 	Connection con(db);
 	REQUIRE_NO_FAIL(con.Query("PRAGMA threads=1"));
 	REQUIRE_NO_FAIL(con.Query("PRAGMA force_compression='uncompressed'"));
+	REQUIRE_NO_FAIL(con.Query("PRAGMA temp_directory=''"));
 
 	// initialize an in-memory database of size 10MB
 	uint64_t table_size = (1000 * 1000) / sizeof(int);

From dfc5e70cb01aac4c4453541799d653d7b57ee634 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Tue, 12 Mar 2024 12:53:02 +0100
Subject: [PATCH 018/201] the tight constraints we set are broken when
 --force-storage is used, so we disable it

---
 test/sql/storage/temp_directory/max_swap_space_error.test | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/sql/storage/temp_directory/max_swap_space_error.test b/test/sql/storage/temp_directory/max_swap_space_error.test
index 68c55f520665..6116c552ea7c 100644
--- a/test/sql/storage/temp_directory/max_swap_space_error.test
+++ b/test/sql/storage/temp_directory/max_swap_space_error.test
@@ -3,6 +3,8 @@
 
 require skip_reload
 
+require noforcestorage
+
 # Set a temp_directory to offload data
 statement ok
 set temp_directory='__TEST_DIR__/max_swap_space_reached'

From d10d3c8e440e88e84eb46a3127282dc1402f4ff6 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Sat, 16 Mar 2024 15:46:15 +0100
Subject: [PATCH 019/201] dynamically generate the wrapper methods instead of
 defining them all manually

---
 tools/pythonpkg/duckdb/__init__.py            | 223 +++--------
 tools/pythonpkg/duckdb_python.cpp             | 252 +++---------
 .../duckdb_python/connection_wrapper.hpp      | 198 ----------
 .../src/pyduckdb/connection_wrapper.cpp       | 371 ------------------
 4 files changed, 102 insertions(+), 942 deletions(-)
 delete mode 100644 tools/pythonpkg/src/include/duckdb_python/connection_wrapper.hpp
 delete mode 100644 tools/pythonpkg/src/pyduckdb/connection_wrapper.cpp

diff --git a/tools/pythonpkg/duckdb/__init__.py b/tools/pythonpkg/duckdb/__init__.py
index bef10a54bec1..3c33bd7ae4ed 100644
--- a/tools/pythonpkg/duckdb/__init__.py
+++ b/tools/pythonpkg/duckdb/__init__.py
@@ -3,6 +3,9 @@
 # Modules
 import duckdb.functional as functional
 import duckdb.typing as typing
+import inspect
+import functools
+
 _exported_symbols.extend([
     "typing",
     "functional"
@@ -38,6 +41,35 @@
     "CaseExpression",
 ])
 
+# ---- Wrap the connection methods
+
+def is_dunder_method(method_name: str) -> bool:
+    if len(method_name) < 4:
+        return False
+    return method_name[:2] == '__' and method_name[:-3:-1] == '__'
+
+def create_connection_wrapper(name):
+    # Define a decorator function that forwards attribute lookup to the default connection
+    @functools.wraps(getattr(DuckDBPyConnection, name))
+    def decorator(*args, **kwargs):
+        connection = duckdb.connect(':default:')
+        if 'connection' in kwargs:
+            connection = kwargs.pop('connection')
+        return getattr(connection, name)(*args, **kwargs)
+    # Set docstring for the wrapper function
+    decorator.__doc__ = getattr(DuckDBPyConnection, name).__doc__
+    return decorator
+
+methods = inspect.getmembers_static(
+    DuckDBPyConnection,
+    predicate=inspect.isfunction
+)
+methods = [method for method in dir(DuckDBPyConnection) if not is_dunder_method(method)]
+for name in methods:
+    wrapper_function = create_connection_wrapper(name)
+    globals()[name] = wrapper_function # Define the wrapper function in the 
module namespace + _exported_symbols.append(name) + # Enums from .duckdb import ( ANALYZE, @@ -54,19 +86,6 @@ "STANDARD" ]) -# Type-creation methods -from .duckdb import ( - struct_type, - list_type, - array_type, - decimal_type -) -_exported_symbols.extend([ - "struct_type", - "list_type", - "array_type", - "decimal_type" -]) # read-only properties from .duckdb import ( @@ -106,171 +125,29 @@ "tokenize" ]) + from .duckdb import ( - filter, - project, - aggregate, - distinct, - limit, - query_df, - order, - alias, connect, - write_csv + #project, + #aggregate, + #distinct, + #limit, + #query_df, + #order, + #alias, + #write_csv ) -_exported_symbols.extend([ - "filter", - "project", - "aggregate", - "distinct", - "limit", - "query_df", - "order", - "alias", - "connect", - "write_csv" -]) -# TODO: might be worth seeing if these methods can be replaced with a pure-python solution -# Connection methods -from .duckdb import ( - append, - array_type, - arrow, - begin, - close, - commit, - create_function, - cursor, - decimal_type, - description, - df, - dtype, - duplicate, - enum_type, - execute, - executemany, - extract_statements, - fetch_arrow_table, - fetch_df, - fetch_df_chunk, - fetch_record_batch, - fetchall, - fetchdf, - fetchmany, - fetchnumpy, - fetchone, - filesystem_is_registered, - from_arrow, - from_csv_auto, - from_df, - from_parquet, - from_query, - from_substrait, - from_substrait_json, - get_substrait, - get_substrait_json, - get_table_names, - install_extension, - interrupt, - list_filesystems, - list_type, - load_extension, - map_type, - pl, - query, - read_csv, - read_json, - read_parquet, - register, - register_filesystem, - remove_function, - rollback, - row_type, - rowcount, - sql, - sqltype, - string_type, - struct_type, - table, - table_function, - tf, - torch, - type, - union_type, - unregister, - unregister_filesystem, - values, - view -) _exported_symbols.extend([ - "append", - "array_type", - "arrow", - "begin", - "close", - "commit", - "create_function", - "cursor", - "decimal_type", - "description", - "df", - "dtype", - "duplicate", - "enum_type", - "execute", - "executemany", - "fetch_arrow_table", - "fetch_df", - "fetch_df_chunk", - "fetch_record_batch", - "fetchall", - "fetchdf", - "fetchmany", - "fetchnumpy", - "fetchone", - "filesystem_is_registered", - "from_arrow", - "from_csv_auto", - "from_df", - "from_parquet", - "from_query", - "from_substrait", - "from_substrait_json", - "get_substrait", - "get_substrait_json", - "get_table_names", - "install_extension", - "interrupt", - "list_filesystems", - "list_type", - "load_extension", - "map_type", - "pl", - "query", - "read_csv", - "read_json", - "read_parquet", - "register", - "register_filesystem", - "remove_function", - "rollback", - "row_type", - "rowcount", - "sql", - "sqltype", - "string_type", - "struct_type", - "table", - "table_function", - "tf", - "torch", - "type", - "union_type", - "unregister", - "unregister_filesystem", - "values", - "view" + "connect", + #"project", + #"aggregate", + #"distinct", + #"limit", + #"query_df", + #"order", + #"alias", + #"write_csv" ]) # Exceptions diff --git a/tools/pythonpkg/duckdb_python.cpp b/tools/pythonpkg/duckdb_python.cpp index 46a05f6996c9..69d9fe5376ef 100644 --- a/tools/pythonpkg/duckdb_python.cpp +++ b/tools/pythonpkg/duckdb_python.cpp @@ -13,7 +13,6 @@ #include "duckdb_python/pybind11/exceptions.hpp" #include "duckdb_python/typing.hpp" #include "duckdb_python/functional.hpp" -#include "duckdb_python/connection_wrapper.hpp" #include 
"duckdb_python/pybind11/conversions/pyconnection_default.hpp" #include "duckdb/common/box_renderer.hpp" #include "duckdb/function/function.hpp" @@ -72,205 +71,58 @@ static py::list PyTokenize(const string &query) { } static void InitializeConnectionMethods(py::module_ &m) { - m.def("cursor", &PyConnectionWrapper::Cursor, "Create a duplicate of the current connection", - py::arg("connection") = py::none()) - .def("duplicate", &PyConnectionWrapper::Cursor, "Create a duplicate of the current connection", - py::arg("connection") = py::none()); - m.def("create_function", &PyConnectionWrapper::RegisterScalarUDF, - "Create a DuckDB function out of the passing in python function so it can be used in queries", - py::arg("name"), py::arg("function"), py::arg("return_type") = py::none(), py::arg("parameters") = py::none(), - py::kw_only(), py::arg("type") = PythonUDFType::NATIVE, py::arg("null_handling") = 0, - py::arg("exception_handling") = 0, py::arg("side_effects") = false, py::arg("connection") = py::none()); - - m.def("remove_function", &PyConnectionWrapper::UnregisterUDF, "Remove a previously created function", - py::arg("name"), py::arg("connection") = py::none()); - - DefineMethod({"sqltype", "dtype", "type"}, m, &PyConnectionWrapper::Type, "Create a type object from 'type_str'", - py::arg("type_str"), py::arg("connection") = py::none()); - DefineMethod({"struct_type", "row_type"}, m, &PyConnectionWrapper::StructType, - "Create a struct type object from 'fields'", py::arg("fields"), py::arg("connection") = py::none()); - m.def("union_type", &PyConnectionWrapper::UnionType, "Create a union type object from 'members'", - py::arg("members").none(false), py::arg("connection") = py::none()) - .def("string_type", &PyConnectionWrapper::StringType, "Create a string type with an optional collation", - py::arg("collation") = string(), py::arg("connection") = py::none()) - .def("enum_type", &PyConnectionWrapper::EnumType, - "Create an enum type of underlying 'type', consisting of the list of 'values'", py::arg("name"), - py::arg("type"), py::arg("values"), py::arg("connection") = py::none()) - .def("decimal_type", &PyConnectionWrapper::DecimalType, "Create a decimal type with 'width' and 'scale'", - py::arg("width"), py::arg("scale"), py::arg("connection") = py::none()); - m.def("array_type", &PyConnectionWrapper::ArrayType, "Create an array type object of 'type'", - py::arg("type").none(false), py::arg("size").none(false), py::arg("connection") = py::none()); - m.def("list_type", &PyConnectionWrapper::ListType, "Create a list type object of 'type'", - py::arg("type").none(false), py::arg("connection") = py::none()); - m.def("map_type", &PyConnectionWrapper::MapType, "Create a map type object from 'key_type' and 'value_type'", - py::arg("key").none(false), py::arg("value").none(false), py::arg("connection") = py::none()) - .def("execute", &PyConnectionWrapper::Execute, - "Execute the given SQL query, optionally using prepared statements with parameters set", py::arg("query"), - py::arg("parameters") = py::none(), py::arg("multiple_parameter_sets") = false, - py::arg("connection") = py::none()) - .def("executemany", &PyConnectionWrapper::ExecuteMany, - "Execute the given prepared statement multiple times using the list of parameter sets in parameters", - py::arg("query"), py::arg("parameters") = py::none(), py::arg("connection") = py::none()) - .def("close", &PyConnectionWrapper::Close, "Close the connection", py::arg("connection") = py::none()) - .def("interrupt", &PyConnectionWrapper::Interrupt, 
"Interrupt pending operations", - py::arg("connection") = py::none()) - .def("fetchone", &PyConnectionWrapper::FetchOne, "Fetch a single row from a result following execute", - py::arg("connection") = py::none()) - .def("fetchmany", &PyConnectionWrapper::FetchMany, "Fetch the next set of rows from a result following execute", - py::arg("size") = 1, py::arg("connection") = py::none()) - .def("fetchall", &PyConnectionWrapper::FetchAll, "Fetch all rows from a result following execute", - py::arg("connection") = py::none()) - .def("fetchnumpy", &PyConnectionWrapper::FetchNumpy, "Fetch a result as list of NumPy arrays following execute", - py::arg("connection") = py::none()) - .def("fetchdf", &PyConnectionWrapper::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), - py::arg("date_as_object") = false, py::arg("connection") = py::none()) - .def("fetch_df", &PyConnectionWrapper::FetchDF, "Fetch a result as DataFrame following execute()", - py::kw_only(), py::arg("date_as_object") = false, py::arg("connection") = py::none()) - .def("fetch_df_chunk", &PyConnectionWrapper::FetchDFChunk, - "Fetch a chunk of the result as DataFrame following execute()", py::arg("vectors_per_chunk") = 1, - py::kw_only(), py::arg("date_as_object") = false, py::arg("connection") = py::none()) - .def("df", &PyConnectionWrapper::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), - py::arg("date_as_object") = false, py::arg("connection") = py::none()) - .def("fetch_arrow_table", &PyConnectionWrapper::FetchArrow, "Fetch a result as Arrow table following execute()", - py::arg("rows_per_batch") = 1000000, py::arg("connection") = py::none()) - .def("torch", &PyConnectionWrapper::FetchPyTorch, - "Fetch a result as dict of PyTorch Tensors following execute()", py::arg("connection") = py::none()) - .def("tf", &PyConnectionWrapper::FetchTF, "Fetch a result as dict of TensorFlow Tensors following execute()", - py::arg("connection") = py::none()) - .def("fetch_record_batch", &PyConnectionWrapper::FetchRecordBatchReader, - "Fetch an Arrow RecordBatchReader following execute()", py::arg("rows_per_batch") = 1000000, - py::arg("connection") = py::none()) - .def("arrow", &PyConnectionWrapper::FetchArrow, "Fetch a result as Arrow table following execute()", - py::arg("rows_per_batch") = 1000000, py::arg("connection") = py::none()) - .def("pl", &PyConnectionWrapper::FetchPolars, "Fetch a result as Polars DataFrame following execute()", - py::arg("rows_per_batch") = 1000000, py::arg("connection") = py::none()) - .def("begin", &PyConnectionWrapper::Begin, "Start a new transaction", py::arg("connection") = py::none()) - .def("commit", &PyConnectionWrapper::Commit, "Commit changes performed within a transaction", - py::arg("connection") = py::none()) - .def("rollback", &PyConnectionWrapper::Rollback, "Roll back changes performed within a transaction", - py::arg("connection") = py::none()) - .def("read_json", &PyConnectionWrapper::ReadJSON, "Create a relation object from the JSON file in 'name'", - py::arg("name"), py::arg("connection") = py::none(), py::arg("columns") = py::none(), - py::arg("sample_size") = py::none(), py::arg("maximum_depth") = py::none(), - py::arg("records") = py::none(), py::arg("format") = py::none()); - - m.def("values", &PyConnectionWrapper::Values, "Create a relation object from the passed values", py::arg("values"), - py::arg("connection") = py::none()); - m.def("from_substrait", &PyConnectionWrapper::FromSubstrait, "Creates a query object from the substrait plan", - 
py::arg("proto"), py::arg("connection") = py::none()); - m.def("get_substrait", &PyConnectionWrapper::GetSubstrait, "Serialize a query object to protobuf", py::arg("query"), - py::arg("connection") = py::none(), py::kw_only(), py::arg("enable_optimizer") = true); - m.def("get_substrait_json", &PyConnectionWrapper::GetSubstraitJSON, "Serialize a query object to protobuf", - py::arg("query"), py::arg("connection") = py::none(), py::kw_only(), py::arg("enable_optimizer") = true); - m.def("from_substrait_json", &PyConnectionWrapper::FromSubstraitJSON, "Serialize a query object to protobuf", - py::arg("json"), py::arg("connection") = py::none()); - m.def("df", &PyConnectionWrapper::FromDF, "Create a relation object from the DataFrame df", py::arg("df"), - py::arg("connection") = py::none()); - m.def("from_df", &PyConnectionWrapper::FromDF, "Create a relation object from the DataFrame df", py::arg("df"), - py::arg("connection") = py::none()); - m.def("from_arrow", &PyConnectionWrapper::FromArrow, "Create a relation object from an Arrow object", - py::arg("arrow_object"), py::arg("connection") = py::none()); - m.def("arrow", &PyConnectionWrapper::FromArrow, "Create a relation object from an Arrow object", - py::arg("arrow_object"), py::arg("connection") = py::none()); - m.def("filter", &PyConnectionWrapper::FilterDf, "Filter the DataFrame df by the filter in filter_expr", - py::arg("df"), py::arg("filter_expr"), py::arg("connection") = py::none()); - m.def("project", &PyConnectionWrapper::ProjectDf, "Project the DataFrame df by the projection in project_expr", - py::arg("df"), py::arg("project_expr"), py::arg("connection") = py::none()); - m.def("alias", &PyConnectionWrapper::AliasDF, "Create a relation from DataFrame df with the passed alias", - py::arg("df"), py::arg("alias"), py::arg("connection") = py::none()); - m.def("order", &PyConnectionWrapper::OrderDf, "Reorder the DataFrame df by order_expr", py::arg("df"), - py::arg("order_expr"), py::arg("connection") = py::none()); - m.def("aggregate", &PyConnectionWrapper::AggregateDF, - "Compute the aggregate aggr_expr by the optional groups group_expr on DataFrame df", py::arg("df"), - py::arg("aggr_expr"), py::arg("group_expr") = "", py::arg("connection") = py::none()); - m.def("distinct", &PyConnectionWrapper::DistinctDF, "Compute the distinct rows from DataFrame df ", py::arg("df"), - py::arg("connection") = py::none()); - m.def("limit", &PyConnectionWrapper::LimitDF, "Retrieve the first n rows from the DataFrame df", py::arg("df"), - py::arg("n"), py::arg("connection") = py::none()); - - m.def("query_df", &PyConnectionWrapper::QueryDF, - "Run the given SQL query in sql_query on the view named virtual_table_name that contains the content of " - "DataFrame df", - py::arg("df"), py::arg("virtual_table_name"), py::arg("sql_query"), py::arg("connection") = py::none()); - - m.def("write_csv", &PyConnectionWrapper::WriteCsvDF, "Write the DataFrame df to a CSV file in file_name", - py::arg("df"), py::arg("file_name"), py::arg("connection") = py::none()); - - DefineMethod( - {"read_csv", "from_csv_auto"}, m, &PyConnectionWrapper::ReadCSV, - "Create a relation object from the CSV file in 'name'", py::arg("name"), py::arg("connection") = py::none(), - py::arg("header") = py::none(), py::arg("compression") = py::none(), py::arg("sep") = py::none(), - py::arg("delimiter") = py::none(), py::arg("dtype") = py::none(), py::arg("na_values") = py::none(), - py::arg("skiprows") = py::none(), py::arg("quotechar") = py::none(), py::arg("escapechar") = py::none(), 
- py::arg("encoding") = py::none(), py::arg("parallel") = py::none(), py::arg("date_format") = py::none(), - py::arg("timestamp_format") = py::none(), py::arg("sample_size") = py::none(), - py::arg("all_varchar") = py::none(), py::arg("normalize_names") = py::none(), py::arg("filename") = py::none(), - py::arg("null_padding") = py::none(), py::arg("names") = py::none()); - - m.def("append", &PyConnectionWrapper::Append, "Append the passed DataFrame to the named table", - py::arg("table_name"), py::arg("df"), py::kw_only(), py::arg("by_name") = false, - py::arg("connection") = py::none()) - .def("register", &PyConnectionWrapper::RegisterPythonObject, - "Register the passed Python Object value for querying with a view", py::arg("view_name"), - py::arg("python_object"), py::arg("connection") = py::none()) - .def("unregister", &PyConnectionWrapper::UnregisterPythonObject, "Unregister the view name", - py::arg("view_name"), py::arg("connection") = py::none()) - .def("table", &PyConnectionWrapper::Table, "Create a relation object for the name'd table", - py::arg("table_name"), py::arg("connection") = py::none()) - .def("view", &PyConnectionWrapper::View, "Create a relation object for the name'd view", py::arg("view_name"), - py::arg("connection") = py::none()) - .def("values", &PyConnectionWrapper::Values, "Create a relation object from the passed values", - py::arg("values"), py::arg("connection") = py::none()) - .def("table_function", &PyConnectionWrapper::TableFunction, - "Create a relation object from the name'd table function with given parameters", py::arg("name"), - py::arg("parameters") = py::none(), py::arg("connection") = py::none()) - .def("extract_statements", &PyConnectionWrapper::ExtractStatements, - "Parse the query string and extract the Statement object(s) produced", py::arg("query"), - py::arg("connection") = py::none()); - - DefineMethod({"sql", "query", "from_query"}, m, &PyConnectionWrapper::RunQuery, - "Run a SQL query. 
If it is a SELECT statement, create a relation object from the given SQL query, " - "otherwise run the query as-is.", - py::arg("query"), py::arg("alias") = "", py::arg("connection") = py::none()); - - DefineMethod({"from_parquet", "read_parquet"}, m, &PyConnectionWrapper::FromParquet, - "Create a relation object from the Parquet files in file_glob", py::arg("file_glob"), - py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, - py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, - py::arg("compression") = py::none(), py::arg("connection") = py::none()); - - DefineMethod({"from_parquet", "read_parquet"}, m, &PyConnectionWrapper::FromParquets, - "Create a relation object from the Parquet files in file_globs", py::arg("file_globs"), - py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, - py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, - py::arg("compression") = py::none(), py::arg("connection") = py::none()); - - m.def("from_substrait", &PyConnectionWrapper::FromSubstrait, "Create a query object from protobuf plan", - py::arg("proto"), py::arg("connection") = py::none()) - .def("get_substrait", &PyConnectionWrapper::GetSubstrait, "Serialize a query to protobuf", py::arg("query"), - py::arg("connection") = py::none(), py::kw_only(), py::arg("enable_optimizer") = true) - .def("get_substrait_json", &PyConnectionWrapper::GetSubstraitJSON, - "Serialize a query to protobuf on the JSON format", py::arg("query"), py::arg("connection") = py::none(), - py::kw_only(), py::arg("enable_optimizer") = true) - .def("get_table_names", &PyConnectionWrapper::GetTableNames, "Extract the required table names from a query", - py::arg("query"), py::arg("connection") = py::none()) - .def("description", &PyConnectionWrapper::GetDescription, "Get result set attributes, mainly column names", - py::arg("connection") = py::none()) - .def("rowcount", &PyConnectionWrapper::GetRowcount, "Get result set row count", - py::arg("connection") = py::none()) - .def("install_extension", &PyConnectionWrapper::InstallExtension, "Install an extension by name", - py::arg("extension"), py::kw_only(), py::arg("force_install") = false, py::arg("connection") = py::none()) - .def("load_extension", &PyConnectionWrapper::LoadExtension, "Load an installed extension", py::arg("extension"), - py::arg("connection") = py::none()) - .def("register_filesystem", &PyConnectionWrapper::RegisterFilesystem, "Register a fsspec compliant filesystem", - py::arg("filesystem"), py::arg("connection") = py::none()) - .def("unregister_filesystem", &PyConnectionWrapper::UnregisterFilesystem, "Unregister a filesystem", - py::arg("name"), py::arg("connection") = py::none()) - .def("list_filesystems", &PyConnectionWrapper::ListFilesystems, - "List registered filesystems, including builtin ones", py::arg("connection") = py::none()) - .def("filesystem_is_registered", &PyConnectionWrapper::FileSystemIsRegistered, - "Check if a filesystem with the provided name is currently registered", py::arg("name"), - py::arg("connection") = py::none()); + m.def("project", + [](const PandasDataFrame &df, const py::object &expr, + shared_ptr conn) -> unique_ptr { + // FIXME: if we want to support passing in DuckDBPyExpressions here + // we could also accept 'expr' as a List[DuckDBPyExpression], without changing the signature + if (!py::isinstance(expr)) { + throw InvalidInputException("Please provide 'expr' as a 
string"); + } + return conn->FromDF(df)->Project(expr); + }); + + // unique_ptr PyConnectionWrapper::DistinctDF(const PandasDataFrame &df, + // shared_ptr conn) { + // return conn->FromDF(df)->Distinct(); + //} + + // void PyConnectionWrapper::WriteCsvDF(const PandasDataFrame &df, const string &file, + // shared_ptr conn) { + // return conn->FromDF(df)->ToCSV(file); + //} + + // unique_ptr PyConnectionWrapper::QueryDF(const PandasDataFrame &df, const string &view_name, + // const string &sql_query, + // shared_ptr conn) { + // return conn->FromDF(df)->Query(view_name, sql_query); + //} + + // unique_ptr PyConnectionWrapper::AggregateDF(const PandasDataFrame &df, const string &expr, + // const string &groups, + // shared_ptr conn) { + // return conn->FromDF(df)->Aggregate(expr, groups); + //} + + // unique_ptr PyConnectionWrapper::AliasDF(const PandasDataFrame &df, const string &expr, + // shared_ptr conn) { + // return conn->FromDF(df)->SetAlias(expr); + //} + + // unique_ptr PyConnectionWrapper::FilterDf(const PandasDataFrame &df, const string &expr, + // shared_ptr conn) { + // return conn->FromDF(df)->FilterFromExpression(expr); + //} + + // unique_ptr PyConnectionWrapper::LimitDF(const PandasDataFrame &df, int64_t n, + // shared_ptr conn) { + // return conn->FromDF(df)->Limit(n); + //} + + // unique_ptr PyConnectionWrapper::OrderDf(const PandasDataFrame &df, const string &expr, + // shared_ptr conn) { + // return conn->FromDF(df)->Order(expr); + //} } static void RegisterStatementType(py::handle &m) { diff --git a/tools/pythonpkg/src/include/duckdb_python/connection_wrapper.hpp b/tools/pythonpkg/src/include/duckdb_python/connection_wrapper.hpp deleted file mode 100644 index f40476fdea3e..000000000000 --- a/tools/pythonpkg/src/include/duckdb_python/connection_wrapper.hpp +++ /dev/null @@ -1,198 +0,0 @@ -#pragma once - -#include "duckdb_python/pyconnection/pyconnection.hpp" -#include "duckdb_python/pyrelation.hpp" -#include "duckdb_python/python_objects.hpp" - -namespace duckdb { - -class PyConnectionWrapper { -public: - PyConnectionWrapper() = delete; - -public: - static shared_ptr ExecuteMany(const py::object &query, py::object params = py::list(), - shared_ptr conn = nullptr); - - static unique_ptr DistinctDF(const PandasDataFrame &df, - shared_ptr conn = nullptr); - - static unique_ptr QueryDF(const PandasDataFrame &df, const string &view_name, - const string &sql_query, shared_ptr conn = nullptr); - - static void WriteCsvDF(const PandasDataFrame &df, const string &file, - shared_ptr conn = nullptr); - - static unique_ptr AggregateDF(const PandasDataFrame &df, const string &expr, - const string &groups = "", - shared_ptr conn = nullptr); - - static shared_ptr Execute(const py::object &query, py::object params = py::list(), - bool many = false, shared_ptr conn = nullptr); - - static shared_ptr - RegisterScalarUDF(const string &name, const py::function &udf, const py::object &arguments = py::none(), - const shared_ptr &return_type = nullptr, PythonUDFType type = PythonUDFType::NATIVE, - FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, - PythonExceptionHandling exception_handling = PythonExceptionHandling::FORWARD_ERROR, - bool side_effects = false, shared_ptr conn = nullptr); - - static shared_ptr UnregisterUDF(const string &name, - shared_ptr conn = nullptr); - - static py::list ExtractStatements(const string &query, shared_ptr conn = nullptr); - - static shared_ptr ArrayType(const shared_ptr &type, idx_t size, - shared_ptr conn = nullptr); - static 
shared_ptr ListType(const shared_ptr &type, - shared_ptr conn = nullptr); - static shared_ptr MapType(const shared_ptr &key, const shared_ptr &value, - shared_ptr conn = nullptr); - static shared_ptr StructType(const py::object &fields, - const shared_ptr conn = nullptr); - static shared_ptr UnionType(const py::object &members, shared_ptr conn = nullptr); - static shared_ptr EnumType(const string &name, const shared_ptr &type, - const py::list &values_p, shared_ptr conn = nullptr); - static shared_ptr DecimalType(int width, int scale, shared_ptr conn = nullptr); - static shared_ptr StringType(const string &collation = string(), - shared_ptr conn = nullptr); - static shared_ptr Type(const string &type_str, shared_ptr conn = nullptr); - - static shared_ptr Append(const string &name, PandasDataFrame value, bool by_name, - shared_ptr conn = nullptr); - - static shared_ptr RegisterPythonObject(const string &name, py::object python_object, - shared_ptr conn = nullptr); - - static void InstallExtension(const string &extension, bool force_install = false, - shared_ptr conn = nullptr); - - static void LoadExtension(const string &extension, shared_ptr conn = nullptr); - - static unique_ptr RunQuery(const py::object &query, const string &alias = "query_relation", - shared_ptr conn = nullptr); - - static unique_ptr Table(const string &tname, shared_ptr conn = nullptr); - - static unique_ptr Values(py::object params = py::none(), - shared_ptr conn = nullptr); - - static unique_ptr View(const string &vname, shared_ptr conn = nullptr); - - static unique_ptr TableFunction(const string &fname, py::object params = py::list(), - shared_ptr conn = nullptr); - - static unique_ptr FromParquet(const string &file_glob, bool binary_as_string, - bool file_row_number, bool filename, bool hive_partitioning, - bool union_by_name, const py::object &compression = py::none(), - shared_ptr conn = nullptr); - - static unique_ptr FromParquets(const vector &file_globs, bool binary_as_string, - bool file_row_number, bool filename, bool hive_partitioning, - bool union_by_name, const py::object &compression = py::none(), - shared_ptr conn = nullptr); - - static unique_ptr FromArrow(py::object &arrow_object, - shared_ptr conn = nullptr); - - static unique_ptr GetSubstrait(const string &query, shared_ptr conn = nullptr, - bool enable_optimizer = true); - - static unique_ptr - GetSubstraitJSON(const string &query, shared_ptr conn = nullptr, bool enable_optimizer = true); - - static unordered_set GetTableNames(const string &query, shared_ptr conn = nullptr); - - static shared_ptr UnregisterPythonObject(const string &name, - shared_ptr conn = nullptr); - - static shared_ptr Begin(shared_ptr conn = nullptr); - - static shared_ptr Commit(shared_ptr conn = nullptr); - - static shared_ptr Rollback(shared_ptr conn = nullptr); - - static void Close(shared_ptr conn = nullptr); - - static void Interrupt(shared_ptr conn = nullptr); - - static shared_ptr Cursor(shared_ptr conn = nullptr); - - static Optional GetDescription(shared_ptr conn = nullptr); - - static int GetRowcount(shared_ptr conn = nullptr); - - static Optional FetchOne(shared_ptr conn = nullptr); - - static py::list FetchMany(idx_t size, shared_ptr conn = nullptr); - - static unique_ptr ReadJSON(const string &filename, shared_ptr conn = nullptr, - const Optional &columns = py::none(), - const Optional &sample_size = py::none(), - const Optional &maximum_depth = py::none(), - const Optional &records = py::none(), - const Optional &format = py::none()); - static unique_ptr - 
ReadCSV(const py::object &name, shared_ptr conn, const py::object &header = py::none(), - const py::object &compression = py::none(), const py::object &sep = py::none(), - const py::object &delimiter = py::none(), const py::object &dtype = py::none(), - const py::object &na_values = py::none(), const py::object &skiprows = py::none(), - const py::object "echar = py::none(), const py::object &escapechar = py::none(), - const py::object &encoding = py::none(), const py::object ¶llel = py::none(), - const py::object &date_format = py::none(), const py::object ×tamp_format = py::none(), - const py::object &sample_size = py::none(), const py::object &all_varchar = py::none(), - const py::object &normalize_names = py::none(), const py::object &filename = py::none(), - const py::object &null_padding = py::none(), const py::object &names = py::none()); - - static py::list FetchAll(shared_ptr conn = nullptr); - - static py::dict FetchNumpy(shared_ptr conn = nullptr); - - static PandasDataFrame FetchDF(bool date_as_object, shared_ptr conn = nullptr); - - static PandasDataFrame FetchDFChunk(const idx_t vectors_per_chunk = 1, bool date_as_object = false, - shared_ptr conn = nullptr); - - static duckdb::pyarrow::Table FetchArrow(idx_t rows_per_batch, shared_ptr conn = nullptr); - - static py::dict FetchPyTorch(shared_ptr conn = nullptr); - - static py::dict FetchTF(shared_ptr conn = nullptr); - - static duckdb::pyarrow::RecordBatchReader FetchRecordBatchReader(const idx_t rows_per_batch, - shared_ptr conn = nullptr); - - static PolarsDataFrame FetchPolars(idx_t rows_per_batch, shared_ptr conn = nullptr); - - static void RegisterFilesystem(AbstractFileSystem file_system, shared_ptr conn); - static void UnregisterFilesystem(const py::str &name, shared_ptr conn); - static py::list ListFilesystems(shared_ptr conn); - static bool FileSystemIsRegistered(const string &name, shared_ptr conn); - - static unique_ptr FromDF(const PandasDataFrame &df, - shared_ptr conn = nullptr); - - static unique_ptr FromSubstrait(py::bytes &proto, shared_ptr conn = nullptr); - - static unique_ptr FromSubstraitJSON(const string &json, - shared_ptr conn = nullptr); - - static unique_ptr FromParquetDefault(const string &filename, - shared_ptr conn = nullptr); - - static unique_ptr ProjectDf(const PandasDataFrame &df, const py::object &expr, - shared_ptr conn = nullptr); - - static unique_ptr AliasDF(const PandasDataFrame &df, const string &expr, - shared_ptr conn = nullptr); - - static unique_ptr FilterDf(const PandasDataFrame &df, const string &expr, - shared_ptr conn = nullptr); - - static unique_ptr LimitDF(const PandasDataFrame &df, int64_t n, - shared_ptr conn = nullptr); - - static unique_ptr OrderDf(const PandasDataFrame &df, const string &expr, - shared_ptr conn = nullptr); -}; -} // namespace duckdb diff --git a/tools/pythonpkg/src/pyduckdb/connection_wrapper.cpp b/tools/pythonpkg/src/pyduckdb/connection_wrapper.cpp deleted file mode 100644 index c35942cacb4d..000000000000 --- a/tools/pythonpkg/src/pyduckdb/connection_wrapper.cpp +++ /dev/null @@ -1,371 +0,0 @@ -#include "duckdb_python/connection_wrapper.hpp" -#include "duckdb/common/constants.hpp" - -namespace duckdb { - -shared_ptr PyConnectionWrapper::UnionType(const py::object &members, - shared_ptr conn) { - if (!conn) { - conn = DuckDBPyConnection::DefaultConnection(); - } - return conn->UnionType(members); -} - -py::list PyConnectionWrapper::ExtractStatements(const string &query, shared_ptr conn) { - if (!conn) { - conn = DuckDBPyConnection::DefaultConnection(); - 
} - return conn->ExtractStatements(query); -} - -shared_ptr PyConnectionWrapper::EnumType(const string &name, const shared_ptr &type, - const py::list &values, shared_ptr conn) { - if (!conn) { - conn = DuckDBPyConnection::DefaultConnection(); - } - return conn->EnumType(name, type, values); -} - -shared_ptr PyConnectionWrapper::DecimalType(int width, int scale, shared_ptr conn) { - if (!conn) { - conn = DuckDBPyConnection::DefaultConnection(); - } - return conn->DecimalType(width, scale); -} - -shared_ptr PyConnectionWrapper::StringType(const string &collation, shared_ptr conn) { - if (!conn) { - conn = DuckDBPyConnection::DefaultConnection(); - } - return conn->StringType(collation); -} - -shared_ptr PyConnectionWrapper::ArrayType(const shared_ptr &type, idx_t size, - shared_ptr conn) { - if (!conn) { - conn = DuckDBPyConnection::DefaultConnection(); - } - return conn->ArrayType(type, size); -} - -shared_ptr PyConnectionWrapper::ListType(const shared_ptr &type, - shared_ptr conn) { - if (!conn) { - conn = DuckDBPyConnection::DefaultConnection(); - } - return conn->ListType(type); -} - -shared_ptr PyConnectionWrapper::MapType(const shared_ptr &key, - const shared_ptr &value, - shared_ptr conn) { - if (!conn) { - conn = DuckDBPyConnection::DefaultConnection(); - } - return conn->MapType(key, value); -} - -shared_ptr PyConnectionWrapper::StructType(const py::object &fields, - shared_ptr conn) { - if (!conn) { - conn = DuckDBPyConnection::DefaultConnection(); - } - return conn->StructType(fields); -} - -shared_ptr PyConnectionWrapper::Type(const string &type_str, shared_ptr conn) { - if (!conn) { - conn = DuckDBPyConnection::DefaultConnection(); - } - return conn->Type(type_str); -} - -shared_ptr PyConnectionWrapper::ExecuteMany(const py::object &query, py::object params, - shared_ptr conn) { - return conn->ExecuteMany(query, params); -} - -unique_ptr PyConnectionWrapper::DistinctDF(const PandasDataFrame &df, - shared_ptr conn) { - return conn->FromDF(df)->Distinct(); -} - -void PyConnectionWrapper::WriteCsvDF(const PandasDataFrame &df, const string &file, - shared_ptr conn) { - return conn->FromDF(df)->ToCSV(file); -} - -unique_ptr PyConnectionWrapper::QueryDF(const PandasDataFrame &df, const string &view_name, - const string &sql_query, - shared_ptr conn) { - return conn->FromDF(df)->Query(view_name, sql_query); -} - -unique_ptr PyConnectionWrapper::AggregateDF(const PandasDataFrame &df, const string &expr, - const string &groups, - shared_ptr conn) { - return conn->FromDF(df)->Aggregate(expr, groups); -} - -shared_ptr PyConnectionWrapper::Execute(const py::object &query, py::object params, bool many, - shared_ptr conn) { - return conn->Execute(query, params, many); -} - -shared_ptr PyConnectionWrapper::UnregisterUDF(const string &name, - shared_ptr conn) { - return conn->UnregisterUDF(name); -} - -shared_ptr -PyConnectionWrapper::RegisterScalarUDF(const string &name, const py::function &udf, const py::object ¶meters_p, - const shared_ptr &return_type_p, PythonUDFType type, - FunctionNullHandling null_handling, PythonExceptionHandling exception_handling, - bool side_effects, shared_ptr conn) { - return conn->RegisterScalarUDF(name, udf, parameters_p, return_type_p, type, null_handling, exception_handling, - side_effects); -} - -shared_ptr PyConnectionWrapper::Append(const string &name, PandasDataFrame value, bool by_name, - shared_ptr conn) { - return conn->Append(name, value, by_name); -} - -shared_ptr PyConnectionWrapper::RegisterPythonObject(const string &name, py::object python_object, 
- shared_ptr conn) { - return conn->RegisterPythonObject(name, python_object); -} - -void PyConnectionWrapper::InstallExtension(const string &extension, bool force_install, - shared_ptr conn) { - conn->InstallExtension(extension, force_install); -} - -void PyConnectionWrapper::LoadExtension(const string &extension, shared_ptr conn) { - conn->LoadExtension(extension); -} - -unique_ptr PyConnectionWrapper::Table(const string &tname, shared_ptr conn) { - return conn->Table(tname); -} - -unique_ptr PyConnectionWrapper::View(const string &vname, shared_ptr conn) { - return conn->View(vname); -} - -unique_ptr PyConnectionWrapper::TableFunction(const string &fname, py::object params, - shared_ptr conn) { - return conn->TableFunction(fname, params); -} - -unique_ptr PyConnectionWrapper::FromDF(const PandasDataFrame &value, - shared_ptr conn) { - return conn->FromDF(value); -} - -unique_ptr PyConnectionWrapper::FromParquet(const string &file_glob, bool binary_as_string, - bool file_row_number, bool filename, - bool hive_partitioning, bool union_by_name, - const py::object &compression, - shared_ptr conn) { - return conn->FromParquet(file_glob, binary_as_string, file_row_number, filename, hive_partitioning, union_by_name, - compression); -} - -unique_ptr PyConnectionWrapper::FromParquets(const vector &file_globs, bool binary_as_string, - bool file_row_number, bool filename, - bool hive_partitioning, bool union_by_name, - const py::object &compression, - shared_ptr conn) { - return conn->FromParquets(file_globs, binary_as_string, file_row_number, filename, hive_partitioning, union_by_name, - compression); -} - -unique_ptr PyConnectionWrapper::FromArrow(py::object &arrow_object, - shared_ptr conn) { - return conn->FromArrow(arrow_object); -} - -unique_ptr PyConnectionWrapper::FromSubstrait(py::bytes &proto, shared_ptr conn) { - return conn->FromSubstrait(proto); -} - -unique_ptr PyConnectionWrapper::FromSubstraitJSON(const string &json, - shared_ptr conn) { - return conn->FromSubstraitJSON(json); -} - -unique_ptr PyConnectionWrapper::GetSubstrait(const string &query, shared_ptr conn, - bool enable_optimizer) { - return conn->GetSubstrait(query, enable_optimizer); -} - -unique_ptr -PyConnectionWrapper::GetSubstraitJSON(const string &query, shared_ptr conn, bool enable_optimizer) { - return conn->GetSubstraitJSON(query, enable_optimizer); -} - -unordered_set PyConnectionWrapper::GetTableNames(const string &query, shared_ptr conn) { - return conn->GetTableNames(query); -} - -shared_ptr PyConnectionWrapper::UnregisterPythonObject(const string &name, - shared_ptr conn) { - return conn->UnregisterPythonObject(name); -} - -shared_ptr PyConnectionWrapper::Begin(shared_ptr conn) { - return conn->Begin(); -} - -shared_ptr PyConnectionWrapper::Commit(shared_ptr conn) { - return conn->Commit(); -} - -shared_ptr PyConnectionWrapper::Rollback(shared_ptr conn) { - return conn->Rollback(); -} - -void PyConnectionWrapper::Close(shared_ptr conn) { - conn->Close(); -} - -void PyConnectionWrapper::Interrupt(shared_ptr conn) { - conn->Interrupt(); -} - -shared_ptr PyConnectionWrapper::Cursor(shared_ptr conn) { - return conn->Cursor(); -} - -Optional PyConnectionWrapper::GetDescription(shared_ptr conn) { - return conn->GetDescription(); -} - -int PyConnectionWrapper::GetRowcount(shared_ptr conn) { - return conn->GetRowcount(); -} - -Optional PyConnectionWrapper::FetchOne(shared_ptr conn) { - return conn->FetchOne(); -} - -unique_ptr PyConnectionWrapper::ReadJSON(const string &filename, shared_ptr conn, - const Optional 
&columns, - const Optional &sample_size, - const Optional &maximum_depth, - const Optional &records, - const Optional &format) { - - return conn->ReadJSON(filename, columns, sample_size, maximum_depth, records, format); -} - -unique_ptr -PyConnectionWrapper::ReadCSV(const py::object &name, shared_ptr conn, const py::object &header, - const py::object &compression, const py::object &sep, const py::object &delimiter, - const py::object &dtype, const py::object &na_values, const py::object &skiprows, - const py::object "echar, const py::object &escapechar, const py::object &encoding, - const py::object ¶llel, const py::object &date_format, - const py::object ×tamp_format, const py::object &sample_size, - const py::object &all_varchar, const py::object &normalize_names, - const py::object &filename, const py::object &null_padding, const py::object &names) { - return conn->ReadCSV(name, header, compression, sep, delimiter, dtype, na_values, skiprows, quotechar, escapechar, - encoding, parallel, date_format, timestamp_format, sample_size, all_varchar, normalize_names, - filename, null_padding, names); -} - -py::list PyConnectionWrapper::FetchMany(idx_t size, shared_ptr conn) { - return conn->FetchMany(size); -} - -py::list PyConnectionWrapper::FetchAll(shared_ptr conn) { - return conn->FetchAll(); -} - -py::dict PyConnectionWrapper::FetchNumpy(shared_ptr conn) { - return conn->FetchNumpy(); -} - -PandasDataFrame PyConnectionWrapper::FetchDF(bool date_as_object, shared_ptr conn) { - return conn->FetchDF(date_as_object); -} - -PandasDataFrame PyConnectionWrapper::FetchDFChunk(const idx_t vectors_per_chunk, bool date_as_object, - shared_ptr conn) { - return conn->FetchDFChunk(vectors_per_chunk, date_as_object); -} - -duckdb::pyarrow::Table PyConnectionWrapper::FetchArrow(idx_t rows_per_batch, shared_ptr conn) { - return conn->FetchArrow(rows_per_batch); -} - -py::dict PyConnectionWrapper::FetchPyTorch(shared_ptr conn) { - return conn->FetchPyTorch(); -} - -py::dict PyConnectionWrapper::FetchTF(shared_ptr conn) { - return conn->FetchTF(); -} - -PolarsDataFrame PyConnectionWrapper::FetchPolars(idx_t rows_per_batch, shared_ptr conn) { - return conn->FetchPolars(rows_per_batch); -} - -duckdb::pyarrow::RecordBatchReader PyConnectionWrapper::FetchRecordBatchReader(const idx_t rows_per_batch, - shared_ptr conn) { - return conn->FetchRecordBatchReader(rows_per_batch); -} - -void PyConnectionWrapper::RegisterFilesystem(AbstractFileSystem file_system, shared_ptr conn) { - return conn->RegisterFilesystem(std::move(file_system)); -} -void PyConnectionWrapper::UnregisterFilesystem(const py::str &name, shared_ptr conn) { - return conn->UnregisterFilesystem(name); -} -py::list PyConnectionWrapper::ListFilesystems(shared_ptr conn) { - return conn->ListFilesystems(); -} -bool PyConnectionWrapper::FileSystemIsRegistered(const string &name, shared_ptr conn) { - return conn->FileSystemIsRegistered(name); -} - -unique_ptr PyConnectionWrapper::Values(py::object values, shared_ptr conn) { - return conn->Values(std::move(values)); -} - -unique_ptr PyConnectionWrapper::RunQuery(const py::object &query, const string &alias, - shared_ptr conn) { - return conn->RunQuery(query, alias); -} - -unique_ptr PyConnectionWrapper::ProjectDf(const PandasDataFrame &df, const py::object &expr, - shared_ptr conn) { - // FIXME: if we want to support passing in DuckDBPyExpressions here - // we could also accept 'expr' as a List[DuckDBPyExpression], without changing the signature - if (!py::isinstance(expr)) { - throw 
InvalidInputException("Please provide 'expr' as a string"); - } - return conn->FromDF(df)->Project(expr); -} - -unique_ptr PyConnectionWrapper::AliasDF(const PandasDataFrame &df, const string &expr, - shared_ptr conn) { - return conn->FromDF(df)->SetAlias(expr); -} - -unique_ptr PyConnectionWrapper::FilterDf(const PandasDataFrame &df, const string &expr, - shared_ptr conn) { - return conn->FromDF(df)->FilterFromExpression(expr); -} - -unique_ptr PyConnectionWrapper::LimitDF(const PandasDataFrame &df, int64_t n, - shared_ptr conn) { - return conn->FromDF(df)->Limit(n); -} - -unique_ptr PyConnectionWrapper::OrderDf(const PandasDataFrame &df, const string &expr, - shared_ptr conn) { - return conn->FromDF(df)->Order(expr); -} - -} // namespace duckdb From 831c9ee2bdcda4ec49a094ad9016878aafc7d4cc Mon Sep 17 00:00:00 2001 From: Tishj Date: Sat, 16 Mar 2024 20:41:28 +0100 Subject: [PATCH 020/201] all tests passing again --- tools/pythonpkg/duckdb/__init__.py | 88 ++++++++++++++----- tools/pythonpkg/duckdb_python.cpp | 82 +++++++---------- .../tests/fast/test_non_default_conn.py | 2 +- 3 files changed, 99 insertions(+), 73 deletions(-) diff --git a/tools/pythonpkg/duckdb/__init__.py b/tools/pythonpkg/duckdb/__init__.py index 3c33bd7ae4ed..d8d1ae36299e 100644 --- a/tools/pythonpkg/duckdb/__init__.py +++ b/tools/pythonpkg/duckdb/__init__.py @@ -3,7 +3,6 @@ # Modules import duckdb.functional as functional import duckdb.typing as typing -import inspect import functools _exported_symbols.extend([ @@ -15,7 +14,7 @@ from .duckdb import ( DuckDBPyRelation, DuckDBPyConnection, - Statement, + Statement, ExplainType, StatementType, ExpectedResultType, @@ -48,27 +47,76 @@ def is_dunder_method(method_name: str) -> bool: return False return method_name[:2] == '__' and method_name[:-3:-1] == '__' -def create_connection_wrapper(name): - # Define a decorator function that forwards attribute lookup to the default connection - @functools.wraps(getattr(DuckDBPyConnection, name)) - def decorator(*args, **kwargs): +# Takes the function to execute on the 'connection' +def create_wrapper(func): + def _wrapper(*args, **kwargs): connection = duckdb.connect(':default:') if 'connection' in kwargs: connection = kwargs.pop('connection') - return getattr(connection, name)(*args, **kwargs) - # Set docstring for the wrapper function - decorator.__doc__ = getattr(DuckDBPyConnection, name).__doc__ - return decorator + return func(connection, *args, **kwargs) + return _wrapper -methods = inspect.getmembers_static( - DuckDBPyConnection, - predicate=inspect.isfunction +# Takes the name of a DuckDBPyConnection function to wrap (copying signature, docs, etc) +# The 'func' is what gets executed when the function is called +def create_connection_wrapper(name, func): + # Define a decorator function that forwards attribute lookup to the default connection + return functools.wraps(getattr(DuckDBPyConnection, name))(create_wrapper(func)) + +# These are overloaded twice, we define them inside of C++ so pybind can deal with it +EXCLUDED_METHODS = [ + 'df', + 'arrow' +] +_exported_symbols.extend(EXCLUDED_METHODS) +from .duckdb import ( + df, + arrow ) -methods = [method for method in dir(DuckDBPyConnection) if not is_dunder_method(method)] -for name in methods: - wrapper_function = create_connection_wrapper(name) - globals()[name] = wrapper_function # Define the wrapper function in the module namespace - _exported_symbols.append(name) + +methods = [method for method in dir(DuckDBPyConnection) if not is_dunder_method(method) and method not in 
EXCLUDED_METHODS] +for method_name in methods: + def create_method_wrapper(method_name): + def call_method(conn, *args, **kwargs): + return getattr(conn, method_name)(*args, **kwargs) + return call_method + wrapper_function = create_connection_wrapper(method_name, create_method_wrapper(method_name)) + globals()[method_name] = wrapper_function # Define the wrapper function in the module namespace + _exported_symbols.append(method_name) + + +# Specialized "wrapper" methods + +SPECIAL_METHODS = [ + 'project', + 'distinct', + 'write_csv', + 'aggregate', + 'alias', + 'filter', + 'limit', + 'order', + 'query_df' +] + +for method_name in SPECIAL_METHODS: + def create_method_wrapper(name): + def _closure(name=name): + mapping = { + 'alias': 'set_alias', + 'query_df': 'query' + } + def call_method(con, df, *args, **kwargs): + if name in mapping: + mapped_name = mapping[name] + else: + mapped_name = name + return getattr(con.from_df(df), mapped_name)(*args, **kwargs) + return call_method + return _closure(name) + + wrapper_function = create_wrapper(create_method_wrapper(method_name)) + globals()[method_name] = wrapper_function # Define the wrapper function in the module namespace + _exported_symbols.append(method_name) # Enums from .duckdb import ( @@ -76,8 +124,8 @@ def decorator(*args, **kwargs): DEFAULT, RETURN_NULL, STANDARD, - COLUMNS, - ROWS + COLUMNS, + ROWS ) _exported_symbols.extend([ "ANALYZE", diff --git a/tools/pythonpkg/duckdb_python.cpp b/tools/pythonpkg/duckdb_python.cpp index 69d9fe5376ef..70e82c03970a 100644 --- a/tools/pythonpkg/duckdb_python.cpp +++ b/tools/pythonpkg/duckdb_python.cpp @@ -71,58 +71,36 @@ static py::list PyTokenize(const string &query) { } static void InitializeConnectionMethods(py::module_ &m) { - m.def("project", - [](const PandasDataFrame &df, const py::object &expr, - shared_ptr conn) -> unique_ptr { - // FIXME: if we want to support passing in DuckDBPyExpressions here - // we could also accept 'expr' as a List[DuckDBPyExpression], without changing the signature - if (!py::isinstance(expr)) { - throw InvalidInputException("Please provide 'expr' as a string"); - } - return conn->FromDF(df)->Project(expr); - }); - - // unique_ptr PyConnectionWrapper::DistinctDF(const PandasDataFrame &df, - // shared_ptr conn) { - // return conn->FromDF(df)->Distinct(); - //} - - // void PyConnectionWrapper::WriteCsvDF(const PandasDataFrame &df, const string &file, - // shared_ptr conn) { - // return conn->FromDF(df)->ToCSV(file); - //} - - // unique_ptr PyConnectionWrapper::QueryDF(const PandasDataFrame &df, const string &view_name, - // const string &sql_query, - // shared_ptr conn) { - // return conn->FromDF(df)->Query(view_name, sql_query); - //} - - // unique_ptr PyConnectionWrapper::AggregateDF(const PandasDataFrame &df, const string &expr, - // const string &groups, - // shared_ptr conn) { - // return conn->FromDF(df)->Aggregate(expr, groups); - //} - - // unique_ptr PyConnectionWrapper::AliasDF(const PandasDataFrame &df, const string &expr, - // shared_ptr conn) { - // return conn->FromDF(df)->SetAlias(expr); - //} - - // unique_ptr PyConnectionWrapper::FilterDf(const PandasDataFrame &df, const string &expr, - // shared_ptr conn) { - // return conn->FromDF(df)->FilterFromExpression(expr); - //} - - // unique_ptr PyConnectionWrapper::LimitDF(const PandasDataFrame &df, int64_t n, - // shared_ptr conn) { - // return conn->FromDF(df)->Limit(n); - //} - - // unique_ptr PyConnectionWrapper::OrderDf(const PandasDataFrame &df, const string &expr, - // shared_ptr conn) { - // 
return conn->FromDF(df)->Order(expr);
-	//}
+	// We define these "wrapper" methods inside of C++ because they are overloaded;
+	// every other wrapper method is defined inside of __init__.py
+	m.def(
+	    "arrow",
+	    [](idx_t rows_per_batch, shared_ptr<DuckDBPyConnection> conn) -> duckdb::pyarrow::Table {
+		    return conn->FetchArrow(rows_per_batch);
+	    },
+	    "Fetch a result as Arrow table following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(),
+	    py::arg("connection") = py::none());
+	m.def(
+	    "arrow",
+	    [](py::object &arrow_object, shared_ptr<DuckDBPyConnection> conn) -> unique_ptr<DuckDBPyRelation> {
+		    return conn->FromArrow(arrow_object);
+	    },
+	    "Create a relation object from an Arrow object", py::arg("arrow_object"), py::kw_only(),
+	    py::arg("connection") = py::none());
+	m.def(
+	    "df",
+	    [](bool date_as_object, shared_ptr<DuckDBPyConnection> conn) -> PandasDataFrame {
+		    return conn->FetchDF(date_as_object);
+	    },
+	    "Fetch a result as DataFrame following execute()", py::kw_only(), py::arg("date_as_object") = false,
+	    py::arg("connection") = py::none());
+	m.def(
+	    "df",
+	    [](const PandasDataFrame &value, shared_ptr<DuckDBPyConnection> conn) -> unique_ptr<DuckDBPyRelation> {
+		    return conn->FromDF(value);
+	    },
+	    "Create a relation object from the DataFrame df", py::arg("df"), py::kw_only(),
+	    py::arg("connection") = py::none());
 }
 
 static void RegisterStatementType(py::handle &m) {
diff --git a/tools/pythonpkg/tests/fast/test_non_default_conn.py b/tools/pythonpkg/tests/fast/test_non_default_conn.py
index 0745f63d4e3e..bc9fa5f094e1 100644
--- a/tools/pythonpkg/tests/fast/test_non_default_conn.py
+++ b/tools/pythonpkg/tests/fast/test_non_default_conn.py
@@ -8,7 +8,7 @@ class TestNonDefaultConn(object):
     def test_values(self, duckdb_cursor):
         duckdb_cursor.execute("create table t (a integer)")
-        duckdb.values([1], duckdb_cursor).insert_into("t")
+        duckdb.values([1], connection=duckdb_cursor).insert_into("t")
         assert duckdb_cursor.execute("select count(*) from t").fetchall()[0] == (1,)
 
     def test_query(self, duckdb_cursor):

From d513065397f9f429dd04b1baf128956b7e9ce30a Mon Sep 17 00:00:00 2001
From: Tishj
Date: Sat, 16 Mar 2024 20:43:36 +0100
Subject: [PATCH 021/201] remove commented out code

---
 tools/pythonpkg/duckdb/__init__.py | 20 ++------------------
 1 file changed, 2 insertions(+), 18 deletions(-)

diff --git a/tools/pythonpkg/duckdb/__init__.py b/tools/pythonpkg/duckdb/__init__.py
index d8d1ae36299e..8f01da43bc4f 100644
--- a/tools/pythonpkg/duckdb/__init__.py
+++ b/tools/pythonpkg/duckdb/__init__.py
@@ -175,27 +175,11 @@ def call_method(con, df, *args, **kwargs):
 
 
 from .duckdb import (
-    connect,
-    #project,
-    #aggregate,
-    #distinct,
-    #limit,
-    #query_df,
-    #order,
-    #alias,
-    #write_csv
+    connect
 )
 
 _exported_symbols.extend([
-    "connect",
-    #"project",
-    #"aggregate",
-    #"distinct",
-    #"limit",
-    #"query_df",
-    #"order",
-    #"alias",
-    #"write_csv"
+    "connect"
 ])
 
 # Exceptions

From 02e94ed0a5a986c2240e3b663619de0725e4d431 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Mon, 18 Mar 2024 10:23:30 +0100
Subject: [PATCH 022/201] this is entirely redundant, because we have an
 implicit conversion from None to default connection handled by pybind, but
 tidy will not be happy about passing by value if I don't use it, and its
 suggestion to make it a reference will break the implicit conversion rules
 we set up - so the *only* purpose of this change is to make tidy happy

---
 tools/pythonpkg/duckdb_python.cpp | 12 ++++++++++++
 1 file changed, 12 insertions(+)
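Before the diff itself, a minimal sketch of the shape all twelve added lines follow; the 'fetchall' wrapper here is an illustrative stand-in, not a line from the patch:

	// Sketch: module-level function with a keyword-only, defaulted connection.
	// pybind11's implicit None -> default-connection conversion already handles
	// the fallback, but the explicit check below is what keeps clang-tidy quiet
	// about the otherwise-unused by-value parameter.
	m.def(
	    "fetchall",
	    [](shared_ptr<DuckDBPyConnection> conn) -> py::list {
		    if (!conn) {
			    conn = DuckDBPyConnection::DefaultConnection();
		    }
		    return conn->FetchAll();
	    },
	    "Fetch all rows from a result following execute()", py::kw_only(),
	    py::arg("connection") = py::none());

diff --git a/tools/pythonpkg/duckdb_python.cpp b/tools/pythonpkg/duckdb_python.cpp
index 70e82c03970a..f5eb42d446ea 100644
--- a/tools/pythonpkg/duckdb_python.cpp
+++ 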
b/tools/pythonpkg/duckdb_python.cpp @@ -76,6 +76,9 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "arrow", [](idx_t rows_per_batch, shared_ptr conn) -> duckdb::pyarrow::Table { + if (!connection) { + connection = DuckDBPyConnection::DefaultConnection(); + } return conn->FetchArrow(rows_per_batch); }, "Fetch a result as Arrow table following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), @@ -83,6 +86,9 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "arrow", [](py::object &arrow_object, shared_ptr conn) -> unique_ptr { + if (!connection) { + connection = DuckDBPyConnection::DefaultConnection(); + } return conn->FromArrow(arrow_object); }, "Create a relation object from an Arrow object", py::arg("arrow_object"), py::kw_only(), @@ -90,6 +96,9 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "df", [](bool date_as_object, shared_ptr conn) -> PandasDataFrame { + if (!connection) { + connection = DuckDBPyConnection::DefaultConnection(); + } return conn->FetchDF(date_as_object); }, "Fetch a result as DataFrame following execute()", py::kw_only(), py::arg("date_as_object") = false, @@ -97,6 +106,9 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "df", [](const PandasDataFrame &value, shared_ptr conn) -> unique_ptr { + if (!connection) { + connection = DuckDBPyConnection::DefaultConnection(); + } return conn->FromDF(value); }, "Create a relation object from the DataFrame df", py::arg("df"), py::kw_only(), From 482aa695be2cc9bc00f977794126f4d80a23b3db Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 18 Mar 2024 13:50:47 +0100 Subject: [PATCH 023/201] . --- tools/pythonpkg/duckdb_python.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tools/pythonpkg/duckdb_python.cpp b/tools/pythonpkg/duckdb_python.cpp index f5eb42d446ea..1289c5743833 100644 --- a/tools/pythonpkg/duckdb_python.cpp +++ b/tools/pythonpkg/duckdb_python.cpp @@ -76,8 +76,8 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "arrow", [](idx_t rows_per_batch, shared_ptr conn) -> duckdb::pyarrow::Table { - if (!connection) { - connection = DuckDBPyConnection::DefaultConnection(); + if (!conn) { + conn = DuckDBPyConnection::DefaultConnection(); } return conn->FetchArrow(rows_per_batch); }, @@ -86,8 +86,8 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "arrow", [](py::object &arrow_object, shared_ptr conn) -> unique_ptr { - if (!connection) { - connection = DuckDBPyConnection::DefaultConnection(); + if (!conn) { + conn = DuckDBPyConnection::DefaultConnection(); } return conn->FromArrow(arrow_object); }, @@ -96,8 +96,8 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "df", [](bool date_as_object, shared_ptr conn) -> PandasDataFrame { - if (!connection) { - connection = DuckDBPyConnection::DefaultConnection(); + if (!conn) { + conn = DuckDBPyConnection::DefaultConnection(); } return conn->FetchDF(date_as_object); }, @@ -106,8 +106,8 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "df", [](const PandasDataFrame &value, shared_ptr conn) -> unique_ptr { - if (!connection) { - connection = DuckDBPyConnection::DefaultConnection(); + if (!conn) { + conn = DuckDBPyConnection::DefaultConnection(); } return conn->FromDF(value); }, From d0ff7c077a177b274da9a8f737c06453f5e9f271 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 21 Mar 2024 11:24:00 +0100 Subject: [PATCH 024/201] started on giant json file containing all the information to generate stubs, 
pybind11 def and init.py --- .../pythonpkg/scripts/connection_methods.json | 856 ++++++++++++++++++ .../scripts/generate_function_definition.py | 3 + .../pyconnection/pyconnection.hpp | 2 +- tools/pythonpkg/src/pyconnection.cpp | 18 +- 4 files changed, 867 insertions(+), 12 deletions(-) create mode 100644 tools/pythonpkg/scripts/connection_methods.json create mode 100644 tools/pythonpkg/scripts/generate_function_definition.py diff --git a/tools/pythonpkg/scripts/connection_methods.json b/tools/pythonpkg/scripts/connection_methods.json new file mode 100644 index 000000000000..dedaef4cb967 --- /dev/null +++ b/tools/pythonpkg/scripts/connection_methods.json @@ -0,0 +1,856 @@ +[ + { + "name": "cursor", + "function": "Cursor", + "docs": "Create a duplicate of the current connection", + "return": "DuckDBPyConnection" + }, + { + "name": "register_filesystem", + "function": "RegisterFilesystem", + "docs": "Register a fsspec compliant filesystem", + "args": [ + { + "name": "filesystem", + "type": "str" + } + ] + }, + { + "name": "unregister_filesystem", + "function": "UnregisterFilesystem", + "docs": "Unregister a filesystem", + "args": [ + { + "name": "name", + "type": "str" + } + ] + }, + { + "name": "list_filesystems", + "function": "ListFilesystems", + "docs": "List registered filesystems, including builtin ones", + "return": "list" + }, + { + "name": "filesystem_is_registered", + "function": "FileSystemIsRegistered", + "docs": "Check if a filesystem with the provided name is currently registered", + "args": [ + { + "name": "name" + } + ], + "return": "bool" + }, + { + "name": "create_function", + "function": "RegisterScalarUDF", + "docs": "Create a DuckDB function out of the passing in Python function so it can be used in queries", + "args": [ + { + "name": "name", + "type": "str" + }, + { + "name": "function", + "type": "function" + }, + { + "name": "parameters", + "type": "Optional[List[DuckDBPyType]]", + "default": "None" + }, + { + "name": "return_type", + "type": "Optional[DuckDBPyType]", + "default": "None" + } + ], + "kwargs": [ + { + "name": "type", + "type": "Optional[PythonUDFType]", + "default": "PythonUDFType.NATIVE" + }, + { + "name": "null_handling", + "type": "Optional[FunctionNullHandling]", + "default": "FunctionNullHandling.DEFAULT" + }, + { + "name": "exception_handling", + "type": "Optional[PythonExceptionHandling]", + "default": "PythonExceptionHandling.DEFAULT" + }, + { + "name": "side_effects", + "type": "bool", + "default": "False" + } + ], + "return": "DuckDBPyConnection" + }, + { + "name": "remove_function", + "function": "UnregisterUDF", + "docs": "Remove a previously created function", + "args": [ + { + "name": "name", + "type": "str" + } + ], + "return": "DuckDBPyConnection" + }, + + { + "name": ["sqltype", "dtype", "type"], + "function": "Type", + "docs": "Create a type object by parsing the 'type_str' string", + "args": [ + { + "name": "type_str" + } + ], + "return": + }, + { + "name": "array_type", + "function": "ArrayType", + "docs": "Create an array type object of 'type'", + "args": [ + + ] + { + "name": "type" + .none(false) + }, + { + "name": "size" + .none(false + }, + }, + { + "name": "list_type", + "function": "ListType", + "docs": "Create a list type object of 'type'", + "args": [ + + ] + { + "name": "type" + .none(false + }, + }, + { + "name": "union_type", + "function": "UnionType", + "docs": "Create a union type object from 'members'", + "args": [ + + ] + { + "name": "members" + .none(false) + }, + }, + { + "name": "string_type", + "function": 
"StringType", + "docs": "Create a string type with an optional collation", + "args": [ + + ] + { + "name": "collation" + = string() + }, + }, + { + "name": "enum_type", + "function": "EnumType", + "docs": "Create an enum type of underlying 'type', consisting of the list of 'values'", + "args": [ + + ] + { + "name": "name" + }, + { + "name": "type" + }, + { + "name": "values" + }, + }, + { + "name": "decimal_type", + "function": "DecimalType", + "docs": "Create a decimal type with 'width' and 'scale'", + "args": [ + + ] + { + "name": "width" + }, + { + "name": "scale" + }, + }, + { + "name": ["struct_type", "row_type"], + "function": "StructType", + "docs": "Create a struct type object from 'fields'", + { + "name": "fields" + }, + }, + { + "name": "map_type", + "function": "MapType", + "docs": "Create a map type object from 'key_type' and 'value_type'", + "args": [ + + ] + { + "name": "key" + .none(false) + }, + { + "name": "value" + .none(false) + }, + }, + { + "name": "duplicate", + "function": "Cursor", + "docs": "Create a duplicate of the current connection", + }, + { + "name": "execute", + "function": "Execute", + "docs": "Execute the given SQL query, optionally using prepared statements with parameters set", + "args": [ + + ] + { + "name": "query" + }, + { + "name": "parameters" + = py::none() + }, + { + "name": "multiple_parameter_sets" + = false + }, + }, + { + "name": "executemany", + "function": "ExecuteMany", + "docs": "Execute the given prepared statement multiple times using the list of parameter sets in parameters", + "args": [ + + ] + { + "name": "query" + }, + { + "name": "parameters" + = py::none() + }, + }, + { + "name": "close", + "function": "Close", + "docs": "Close the connection" + }, + { + "name": "interrupt", + "function": "Interrupt", + "docs": "Interrupt pending operations" + }, + { + "name": "fetchone", + "function": "FetchOne", + "docs": "Fetch a single row from a result following execute" + }, + { + "name": "fetchmany", + "function": "FetchMany", + "docs": "Fetch the next set of rows from a result following execute" + { + "name": "size" + = 1 + }, + }, + { + "name": "fetchall", + "function": "FetchAll", + "docs": "Fetch all rows from a result following execute" + }, + { + "name": "fetchnumpy", + "function": "FetchNumpy", + "docs": "Fetch a result as list of NumPy arrays following execute" + }, + { + "name": "fetchdf", + "function": "FetchDF", + "docs": "Fetch a result as DataFrame following execute()", + py::kw_only(), + { + "name": "date_as_object" + = false + }, + }, + { + "name": "fetch_df", + "function": "FetchDF", + "docs": "Fetch a result as DataFrame following execute()", + py::kw_only(), + { + "name": "date_as_object" + = false + }, + }, + { + "name": "fetch_df_chunk", + "function": "FetchDFChunk", + "docs": "Fetch a chunk of the result as Data.Frame following execute()", + "args": [ + + ] + { + "name": "vectors_per_chunk" + = 1 + }, + py::kw_only(), + { + "name": "date_as_object" + = false + }, + }, + { + "name": "df", + "function": "FetchDF", + "docs": "Fetch a result as DataFrame following execute()", + py::kw_only(), + { + "name": "date_as_object" + = false + }, + }, + { + "name": "pl", + "function": "FetchPolars", + "docs": "Fetch a result as Polars DataFrame following execute()", + "args": [ + + ] + { + "name": "rows_per_batch" + = 1000000 + }, + }, + { + "name": "fetch_arrow_table", + "function": "FetchArrow", + "docs": "Fetch a result as Arrow table following execute()", + "args": [ + + ] + { + "name": "rows_per_batch" + = 1000000 + }, + }, + { 
+ "name": "fetch_record_batch", + "function": "FetchRecordBatchReader", + "docs": "Fetch an Arrow RecordBatchReader following execute()", + "args": [ + + ] + { + "name": "rows_per_batch" + = 1000000 + }, + }, + { + "name": "arrow", + "function": "FetchArrow", + "docs": "Fetch a result as Arrow table following execute()", + "args": [ + + ] + { + "name": "rows_per_batch" + = 1000000 + }, + }, + { + "name": "torch", + "function": "FetchPyTorch", + "docs": "Fetch a result as dict of PyTorch Tensors following execute()" + + }, + { + "name": "tf", + "function": "FetchTF", + "docs": "Fetch a result as dict of TensorFlow Tensors following execute()" + + }, + { + "name": "begin", + "function": "Begin", + "docs": "Start a new transaction" + + }, + { + "name": "commit", + "function": "Commit", + "docs": "Commit changes performed within a transaction" + }, + { + "name": "rollback", + "function": "Rollback", + "docs": "Roll back changes performed within a transaction" + }, + { + "name": "append", + "function": "Append", + "docs": "Append the passed DataFrame to the named table", + "args": [ + + ] + { + "name": "table_name" + }, + { + "name": "df" + }, + py::kw_only(), + { + "name": "by_name" + = false + }, + }, + { + "name": "register", + "function": "RegisterPythonObject", + "docs": "Register the passed Python Object value for querying with a view", + "args": [ + + ] + { + "name": "view_name" + }, + { + "name": "python_object" + }, + }, + { + "name": "unregister", + "function": "UnregisterPythonObject", + "docs": "Unregister the view name", + "args": [ + + ] + { + "name": "view_name" + }, + }, + { + "name": "table", + "function": "Table", + "docs": "Create a relation object for the name'd table", + "args": [ + + ] + { + "name": "table_name" + }, + }, + { + "name": "view", + "function": "View", + "docs": "Create a relation object for the name'd view", + "args": [ + + ] + { + "name": "view_name" + }, + }, + { + "name": "values", + "function": "Values", + "docs": "Create a relation object from the passed values", + "args": [ + + ] + { + "name": "values" + }, + }, + { + "name": "table_function", + "function": "TableFunction", + "docs": "Create a relation object from the name'd table function with given parameters", + "args": [ + + ] + { + "name": "name" + }, + { + "name": "parameters" + = py::none() + }, + }, + { + "name": "read_json", + "function": "ReadJSON", + "docs": "Create a relation object from the JSON file in 'name'", + "args": [ + + ] + { + "name": "name" + }, + py::kw_only(), + { + "name": "columns" + = py::none() + }, + { + "name": "sample_size" + = py::none() + }, + { + "name": "maximum_depth" + = py::none() + }, + { + "name": "records" + = py::none() + }, + { + "name": "format" + = py::none() + }, + }, + { + "name": "extract_statements", + "function": "ExtractStatements", + "docs": "Parse the query string and extract the Statement object(s) produced", + "args": [ + + ] + { + "name": "query" + }, + }, + + { + "name": ["sql", "query", "from_query"], + "function": "RunQuery", + "docs": "Run a SQL query. 
If it is a SELECT statement, create a relation object from the given SQL query, otherwise run the query as-is.", + "args": [ + + ] + { + "name": "query" + }, + py::kw_only(), + { + "name": "alias" + = "" + }, + { + "name": "params" + = py::none( + }, + }, + + { + "name": ["read_csv", "from_csv_auto"], + "function": "ReadCSV", + "docs": "Create a relation object from the CSV file in 'name'", + "args": [ + + ] + { + "name": "name" + }, + py::kw_only(), + { + "name": "header" + = py::none() + }, + { + "name": "compression" + = py::none() + }, + { + "name": "sep" + = py::none() + }, + { + "name": "delimiter" + = py::none() + }, + { + "name": "dtype" + = py::none() + }, + { + "name": "na_values" + = py::none() + }, + { + "name": "skiprows" + = py::none() + }, + { + "name": "quotechar" + = py::none() + }, + { + "name": "escapechar" + = py::none() + }, + { + "name": "encoding" + = py::none() + }, + { + "name": "parallel" + = py::none() + }, + { + "name": "date_format" + = py::none() + }, + { + "name": "timestamp_format" + = py::none() + }, + { + "name": "sample_size" + = py::none() + }, + { + "name": "all_varchar" + = py::none() + }, + { + "name": "normalize_names" + = py::none() + }, + { + "name": "filename" + = py::none() + }, + { + "name": "null_padding" + = py::none() + }, + { + "name": "names" + = py::none( + }, + }, + { + "name": "from_df", + "function": "FromDF", + "docs": "Create a relation object from the Data.Frame in df", + "args": [ + + ] + { + "name": "df" + = py::none() + }, + }, + { + "name": "from_arrow", + "function": "FromArrow", + "docs": "Create a relation object from an Arrow object", + "args": [ + + ] + { + "name": "arrow_object" + }, + }, + + { + "name": ["from_parquet", "read_parquet"], + "function": "FromParquet", + "docs": "Create a relation object from the Parquet files in file_glob", + "args": [ + + ] + { + "name": "file_glob" + }, + { + "name": "binary_as_string" + = false + }, + py::kw_only(), + { + "name": "file_row_number" + = false + }, + { + "name": "filename" + = false + }, + { + "name": "hive_partitioning" + = false + }, + { + "name": "union_by_name" + = false + }, + { + "name": "compression" + = py::none( + }, + }, + { + "name": ["from_parquet", "read_parquet"], + "function": "FromParquets", + "docs": "Create a relation object from the Parquet files in file_globs", + "args": [ + + ] + { + "name": "file_globs" + }, + { + "name": "binary_as_string" + = false + }, + py::kw_only(), + { + "name": "file_row_number" + = false + }, + { + "name": "filename" + = false + }, + { + "name": "hive_partitioning" + = false + }, + { + "name": "union_by_name" + = false + }, + { + "name": "compression" + = py::none( + }, + }, + { + "name": "from_substrait", + "function": "FromSubstrait", + "docs": "Create a query object from protobuf plan", + "args": [ + + ] + { + "name": "proto" + }, + }, + { + "name": "get_substrait", + "function": "GetSubstrait", + "docs": "Serialize a query to protobuf", + "args": [ + + ] + { + "name": "query" + }, + py::kw_only(), + { + "name": "enable_optimizer" + = true + }, + }, + { + "name": "get_substrait_json", + "function": "GetSubstraitJSON", + "docs": "Serialize a query to protobuf on the JSON format", + "args": [ + + ] + { + "name": "query" + }, + py::kw_only(), + { + "name": "enable_optimizer" + = true + }, + }, + { + "name": "from_substrait_json", + "function": "FromSubstraitJSON", + "docs": "Create a query object from a JSON protobuf plan", + "args": [ + + ] + { + "name": "json" + }, + }, + { + "name": "get_table_names", + "function": 
"GetTableNames", + "docs": "Extract the required table names from a query", + "args": [ + + ] + { + "name": "query" + }, + }, + + { + "name": "install_extension", + "function": "InstallExtension", + "docs": "Install an extension by name", + "args": [ + + ] + { + "name": "extension" + }, + py::kw_only(), + { + "name": "force_install" + = false + }, + }, + { + "name": "load_extension", + "function": "LoadExtension", + "docs": "Load an installed extension", + "args": [ + + ] + { + "name": "extension" + }, + } +] diff --git a/tools/pythonpkg/scripts/generate_function_definition.py b/tools/pythonpkg/scripts/generate_function_definition.py new file mode 100644 index 000000000000..aca68675eb43 --- /dev/null +++ b/tools/pythonpkg/scripts/generate_function_definition.py @@ -0,0 +1,3 @@ +class FunctionDefinition: + def __init__(self): + pass diff --git a/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp b/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp index a0d572dc9a91..bf1d92f9c00a 100644 --- a/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp @@ -116,7 +116,7 @@ struct DuckDBPyConnection : public std::enable_shared_from_this Execute(const py::object &query, py::object params = py::list(), bool many = false); - shared_ptr Execute(const string &query); + shared_ptr ExecuteFromString(const string &query); shared_ptr Append(const string &name, const PandasDataFrame &value, bool by_name); diff --git a/tools/pythonpkg/src/pyconnection.cpp b/tools/pythonpkg/src/pyconnection.cpp index 661813422c1e..f4cd9100a142 100644 --- a/tools/pythonpkg/src/pyconnection.cpp +++ b/tools/pythonpkg/src/pyconnection.cpp @@ -160,13 +160,9 @@ static void InitializeConnectionMethods(py::class_> DuckDBPyConnection::GetStatements(const py::obj throw InvalidInputException("Please provide either a DuckDBPyStatement or a string representing the query"); } -shared_ptr DuckDBPyConnection::Execute(const string &query) { +shared_ptr DuckDBPyConnection::ExecuteFromString(const string &query) { return Execute(py::str(query)); } @@ -1254,7 +1250,7 @@ shared_ptr DuckDBPyConnection::UnregisterPythonObject(const } shared_ptr DuckDBPyConnection::Begin() { - Execute("BEGIN TRANSACTION"); + ExecuteFromString("BEGIN TRANSACTION"); return shared_from_this(); } @@ -1262,12 +1258,12 @@ shared_ptr DuckDBPyConnection::Commit() { if (connection->context->transaction.IsAutoCommit()) { return shared_from_this(); } - Execute("COMMIT"); + ExecuteFromString("COMMIT"); return shared_from_this(); } shared_ptr DuckDBPyConnection::Rollback() { - Execute("ROLLBACK"); + ExecuteFromString("ROLLBACK"); return shared_from_this(); } From 4848acf50ec10ec905dda249978055c9efbd278c Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 21 Mar 2024 15:29:35 +0100 Subject: [PATCH 025/201] move the setting into the buffer manager + temporary directory handle --- src/common/file_system.cpp | 16 ++-- src/include/duckdb/common/file_system.hpp | 5 +- src/include/duckdb/main/config.hpp | 42 ++-------- src/include/duckdb/storage/buffer_manager.hpp | 19 ++++- .../storage/standard_buffer_manager.hpp | 29 +++++-- .../duckdb/storage/temporary_file_manager.hpp | 8 +- src/main/config.cpp | 30 +------ src/main/database.cpp | 7 +- src/main/settings/settings.cpp | 43 ++++++++-- src/storage/buffer_manager.cpp | 8 +- src/storage/standard_buffer_manager.cpp | 84 ++++++++++++------- src/storage/temporary_file_manager.cpp | 44 +++++++--- 
.../max_swap_space_inmemory.test              |  5 +-
 .../max_swap_space_persistent.test            |  5 +-
 14 files changed, 197 insertions(+), 148 deletions(-)

diff --git a/src/common/file_system.cpp b/src/common/file_system.cpp
index f21e6c46dd79..c58bb588ee07 100644
--- a/src/common/file_system.cpp
+++ b/src/common/file_system.cpp
@@ -84,7 +84,7 @@ void FileSystem::SetWorkingDirectory(const string &path) {
 	}
 }
 
-idx_t FileSystem::GetAvailableMemory() {
+optional_idx FileSystem::GetAvailableMemory() {
 	errno = 0;
 
 #ifdef __MVS__
@@ -95,16 +95,16 @@ idx_t FileSystem::GetAvailableMemory() {
 	idx_t max_memory = MinValue<idx_t>((idx_t)sysconf(_SC_PHYS_PAGES) * (idx_t)sysconf(_SC_PAGESIZE), UINTPTR_MAX);
 #endif
 	if (errno != 0) {
-		return DConstants::INVALID_INDEX;
+		return optional_idx();
 	}
 	return max_memory;
 }
 
-idx_t FileSystem::GetAvailableDiskSpace(const string &path) {
+optional_idx FileSystem::GetAvailableDiskSpace(const string &path) {
 	struct statvfs vfs;
 
 	if (statvfs(path.c_str(), &vfs) == -1) {
-		return DConstants::INVALID_INDEX;
+		return optional_idx();
 	}
 	auto block_size = vfs.f_frsize;
 	// These are the blocks available for creating new files or extending existing ones
 	auto available_blocks = vfs.f_bavail;
 	idx_t available_disk_space = DConstants::INVALID_INDEX;
 	if (!TryMultiplyOperator::Operation(static_cast<idx_t>(block_size), static_cast<idx_t>(available_blocks),
 	                                    available_disk_space)) {
-		return DConstants::INVALID_INDEX;
+		return optional_idx();
 	}
 	return available_disk_space;
 }
@@ -214,15 +214,15 @@ idx_t FileSystem::GetAvailableMemory() {
 	if (GlobalMemoryStatusEx(&mem_state)) {
 		return MinValue<idx_t>(mem_state.ullTotalPhys, UINTPTR_MAX);
 	}
-	return DConstants::INVALID_INDEX;
+	return optional_idx();
 }
 
-idx_t FileSystem::GetAvailableDiskSpace(const string &path) {
+optional_idx FileSystem::GetAvailableDiskSpace(const string &path) {
 	ULARGE_INTEGER available_bytes, total_bytes, free_bytes;
 	auto unicode_path = WindowsUtil::UTF8ToUnicode(path.c_str());
 	if (!GetDiskFreeSpaceExW(unicode_path.c_str(), &available_bytes, &total_bytes, &free_bytes)) {
-		return DConstants::INVALID_INDEX;
+		return optional_idx();
 	}
 	(void)total_bytes;
 	(void)free_bytes;
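Between the implementation and the header, a short caller-side sketch of the new convention: callers test validity instead of comparing against a sentinel. ComputeDefaultMaxMemory and its 1 GiB fallback are illustrative assumptions, not patch code:

	// Sketch: consuming the new optional_idx return values.
	idx_t ComputeDefaultMaxMemory() {
		auto memory = FileSystem::GetAvailableMemory();
		if (!memory.IsValid()) {
			// The probe failed; fall back to a conservative fixed default.
			return idx_t(1) << 30; // 1 GiB, illustrative only
		}
		// 80% of physical memory, matching DBConfig::SetDefaultMaxMemory.
		return memory.GetIndex() * 8 / 10;
	}

diff --git a/src/include/duckdb/common/file_system.hpp b/src/include/duckdb/common/file_system.hpp
index e8b11d1802e7..d39abff5362e 100644
--- a/src/include/duckdb/common/file_system.hpp
+++ b/src/include/duckdb/common/file_system.hpp
@@ -16,6 +16,7 @@
 #include "duckdb/common/vector.hpp"
 #include "duckdb/common/enums/file_glob_options.hpp"
 #include "duckdb/common/optional_ptr.hpp"
+#include "duckdb/common/optional_idx.hpp"
 #include
 
 #undef CreateDirectory
@@ -187,9 +188,9 @@ class FileSystem {
 	//! Expands a given path, including e.g. expanding the home directory of the user
 	DUCKDB_API virtual string ExpandPath(const string &path);
 	//! Returns the system-available memory in bytes. Returns DConstants::INVALID_INDEX if the system function fails.
-	DUCKDB_API static idx_t GetAvailableMemory();
+	DUCKDB_API static optional_idx GetAvailableMemory();
 	//! Returns the space available on the disk. Returns DConstants::INVALID_INDEX if the information was not available.
-	DUCKDB_API static idx_t GetAvailableDiskSpace(const string &path);
+	DUCKDB_API static optional_idx GetAvailableDiskSpace(const string &path);
 	//! Path separator for path
 	DUCKDB_API virtual string PathSeparator(const string &path);
 	//! 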
Checks if path is starts with separator (i.e., '/' on UNIX '\\' on Windows) diff --git a/src/include/duckdb/main/config.hpp b/src/include/duckdb/main/config.hpp index fcfda6ea8c40..1441a6d1a986 100644 --- a/src/include/duckdb/main/config.hpp +++ b/src/include/duckdb/main/config.hpp @@ -61,37 +61,6 @@ typedef void (*reset_global_function_t)(DatabaseInstance *db, DBConfig &config); typedef void (*reset_local_function_t)(ClientContext &context); typedef Value (*get_setting_function_t)(ClientContext &context); -struct NumericSetting { -public: - NumericSetting() : value(0), set_by_user(false) { - } - -public: - NumericSetting &operator=(idx_t val) = delete; - -public: - operator idx_t() { - return value; - } - -public: - bool ExplicitlySet() const { - return set_by_user; - } - void SetDefault(idx_t val) { - value = val; - set_by_user = false; - } - void SetExplicit(idx_t val) { - value = val; - set_by_user = true; - } - -private: - idx_t value; - bool set_by_user; -}; - struct ConfigurationOption { const char *name; const char *description; @@ -145,14 +114,14 @@ struct DBConfigOptions { #endif //! Override for the default extension repository string custom_extension_repo = ""; - //! Override for the default autoload extensoin repository + //! Override for the default autoload extension repository string autoinstall_extension_repo = ""; //! The maximum memory used by the database system (in bytes). Default: 80% of System available memory - idx_t maximum_memory = (idx_t)-1; - //! The maximum size of the 'temp_directory' folder when set (in bytes) - NumericSetting maximum_swap_space = NumericSetting(); + idx_t maximum_memory = DConstants::INVALID_INDEX; + //! The maximum size of the 'temp_directory' folder when set (in bytes). Default: 90% of available disk space. + idx_t maximum_swap_space = DConstants::INVALID_INDEX; //! The maximum amount of CPU threads used by the database system. Default: all available. - idx_t maximum_threads = (idx_t)-1; + idx_t maximum_threads = DConstants::INVALID_INDEX; //! The number of external threads that work on DuckDB tasks. Default: 1. //! Must be smaller or equal to maximum_threads. idx_t external_threads = 1; @@ -310,7 +279,6 @@ struct DBConfig { DUCKDB_API IndexTypeSet &GetIndexTypes(); static idx_t GetSystemMaxThreads(FileSystem &fs); void SetDefaultMaxMemory(); - void SetDefaultMaxSwapSpace(optional_ptr db); void SetDefaultTempDirectory(); OrderType ResolveOrder(OrderType order_type) const; diff --git a/src/include/duckdb/storage/buffer_manager.hpp b/src/include/duckdb/storage/buffer_manager.hpp index 7238be7844df..19e29aa89365 100644 --- a/src/include/duckdb/storage/buffer_manager.hpp +++ b/src/include/duckdb/storage/buffer_manager.hpp @@ -41,10 +41,17 @@ class BufferManager { virtual void ReAllocate(shared_ptr &handle, idx_t block_size) = 0; virtual BufferHandle Pin(shared_ptr &handle) = 0; virtual void Unpin(shared_ptr &handle) = 0; + //! Returns the currently allocated memory virtual idx_t GetUsedMemory() const = 0; //! Returns the maximum available memory virtual idx_t GetMaxMemory() const = 0; + //! Returns the currently used swap space + virtual idx_t GetUsedSwap() = 0; + //! Returns the maximum swap space that can be used + virtual optional_idx GetMaxSwap() = 0; + + //! 
Returns a new block of memory that is smaller than Storage::BLOCK_SIZE virtual shared_ptr RegisterSmallMemory(idx_t block_size); virtual DUCKDB_API Allocator &GetBufferAllocator(); virtual DUCKDB_API void ReserveMemory(idx_t size); @@ -52,20 +59,21 @@ class BufferManager { virtual vector GetMemoryUsageInfo() const = 0; //! Set a new memory limit to the buffer manager, throws an exception if the new limit is too low and not enough //! blocks can be evicted - virtual void SetLimit(idx_t limit = (idx_t)-1); + virtual void SetMemoryLimit(idx_t limit = (idx_t)-1); + virtual void SetSwapLimit(optional_idx limit = optional_idx()); + virtual vector GetTemporaryFiles(); virtual const string &GetTemporaryDirectory(); virtual void SetTemporaryDirectory(const string &new_dir); - virtual DatabaseInstance &GetDatabase(); virtual bool HasTemporaryDirectory() const; + //! Construct a managed buffer. virtual unique_ptr ConstructManagedBuffer(idx_t size, unique_ptr &&source, FileBufferType type = FileBufferType::MANAGED_BUFFER); //! Get the underlying buffer pool responsible for managing the buffers virtual BufferPool &GetBufferPool() const; - //! Get the manager that assigns reservations for temporary memory, e.g., for query intermediates - virtual TemporaryMemoryManager &GetTemporaryMemoryManager(); + virtual DatabaseInstance &GetDatabase(); // Static methods DUCKDB_API static BufferManager &GetBufferManager(DatabaseInstance &db); DUCKDB_API static BufferManager &GetBufferManager(ClientContext &context); @@ -77,6 +85,9 @@ class BufferManager { //! Returns the maximum available memory for a given query idx_t GetQueryMaxMemory() const; + //! Get the manager that assigns reservations for temporary memory, e.g., for query intermediates + virtual TemporaryMemoryManager &GetTemporaryMemoryManager(); + protected: virtual void PurgeQueue() = 0; virtual void AddToEvictionQueue(shared_ptr &handle); diff --git a/src/include/duckdb/storage/standard_buffer_manager.hpp b/src/include/duckdb/storage/standard_buffer_manager.hpp index de36f9477ad5..fe64355748c3 100644 --- a/src/include/duckdb/storage/standard_buffer_manager.hpp +++ b/src/include/duckdb/storage/standard_buffer_manager.hpp @@ -50,6 +50,8 @@ class StandardBufferManager : public BufferManager { idx_t GetUsedMemory() const final override; idx_t GetMaxMemory() const final override; + idx_t GetUsedSwap() final override; + optional_idx GetMaxSwap() final override; //! Allocate an in-memory buffer with a single pin. //! The allocated memory is released when the buffer handle is destroyed. @@ -64,7 +66,8 @@ class StandardBufferManager : public BufferManager { //! Set a new memory limit to the buffer manager, throws an exception if the new limit is too low and not enough //! blocks can be evicted - void SetLimit(idx_t limit = (idx_t)-1) final override; + void SetMemoryLimit(idx_t limit = (idx_t)-1) final override; + void SetSwapLimit(optional_idx limit = optional_idx()) final override; //! Returns informaton about memory usage vector GetMemoryUsageInfo() const override; @@ -73,7 +76,7 @@ class StandardBufferManager : public BufferManager { vector GetTemporaryFiles() final override; const string &GetTemporaryDirectory() final override { - return temp_directory; + return temporary_directory.path; } void SetTemporaryDirectory(const string &new_dir) final override; @@ -136,17 +139,27 @@ class StandardBufferManager : public BufferManager { //! overwrites the data within with garbage. 
Any readers that do not hold the pin will notice void VerifyZeroReaders(shared_ptr &handle); +protected: + // These are stored here because temp_directory creation is lazy + // so we need to store information related to the temporary directory before it's created + struct TemporaryFileData { + //! The directory name where temporary files are stored + string path; + //! Lock for creating the temp handle + mutex lock; + //! Handle for the temporary directory + unique_ptr handle; + //! The maximum swap space that can be used + optional_idx maximum_swap_space = optional_idx(); + }; + protected: //! The database instance DatabaseInstance &db; //! The buffer pool BufferPool &buffer_pool; - //! The directory name where temporary files are stored - string temp_directory; - //! Lock for creating the temp handle - mutex temp_handle_lock; - //! Handle for the temporary directory - unique_ptr temp_directory_handle; + //! The variables related to temporary file management + TemporaryFileData temporary_directory; //! The temporary id used for managed buffers atomic temporary_id; //! Allocator associated with the buffer manager, that passes all allocations through this buffer manager diff --git a/src/include/duckdb/storage/temporary_file_manager.hpp b/src/include/duckdb/storage/temporary_file_manager.hpp index 7c118fb4cf76..9ab77e42e958 100644 --- a/src/include/duckdb/storage/temporary_file_manager.hpp +++ b/src/include/duckdb/storage/temporary_file_manager.hpp @@ -128,7 +128,7 @@ class TemporaryFileHandle { class TemporaryDirectoryHandle { public: - TemporaryDirectoryHandle(DatabaseInstance &db, string path_p); + TemporaryDirectoryHandle(DatabaseInstance &db, string path_p, optional_idx max_swap_space); ~TemporaryDirectoryHandle(); TemporaryFileManager &GetTempFile(); @@ -146,7 +146,7 @@ class TemporaryDirectoryHandle { class TemporaryFileManager { public: - TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p); + TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p, optional_idx max_swap_space); ~TemporaryFileManager(); public: @@ -164,6 +164,8 @@ class TemporaryFileManager { void DeleteTemporaryBuffer(block_id_t id); vector GetTemporaryFiles(); idx_t GetTotalUsedSpaceInBytes(); + optional_idx GetMaxSwapSpace() const; + void SetMaxSwapSpace(optional_idx limit); //! Register temporary file size growth void IncreaseSizeOnDisk(idx_t amount); //! Register temporary file size decrease @@ -189,6 +191,8 @@ class TemporaryFileManager { BlockIndexManager index_manager; //! The size in bytes of the temporary files that are currently alive atomic size_on_disk; + //! 
The max amount of disk space that can be used + idx_t max_swap_space; }; } // namespace duckdb diff --git a/src/main/config.cpp b/src/main/config.cpp index f8f8ca1d1101..f280cf31b6c5 100644 --- a/src/main/config.cpp +++ b/src/main/config.cpp @@ -256,10 +256,7 @@ bool DBConfig::IsInMemoryDatabase(const char *database_path) { // '' empty string return true; } - constexpr const char *IN_MEMORY_PATH_PREFIX = ":memory:"; - const idx_t PREFIX_LENGTH = strlen(IN_MEMORY_PATH_PREFIX); - if (strncmp(database_path, IN_MEMORY_PATH_PREFIX, PREFIX_LENGTH) == 0) { - // Starts with :memory:, i.e ':memory:named_conn' is valid + if (strcmp(database_path, ":memory:") == 0) { return true; } return false; @@ -275,8 +272,8 @@ IndexTypeSet &DBConfig::GetIndexTypes() { void DBConfig::SetDefaultMaxMemory() { auto memory = FileSystem::GetAvailableMemory(); - if (memory != DConstants::INVALID_INDEX) { - options.maximum_memory = memory * 8 / 10; + if (memory.IsValid()) { + options.maximum_memory = memory.GetIndex() * 8 / 10; } } @@ -288,27 +285,6 @@ void DBConfig::SetDefaultTempDirectory() { } } -void DBConfig::SetDefaultMaxSwapSpace(optional_ptr db) { - options.maximum_swap_space.SetDefault(0); - if (options.temporary_directory.empty()) { - return; - } - if (!db) { - return; - } - auto &fs = FileSystem::GetFileSystem(*db); - if (!fs.DirectoryExists(options.temporary_directory)) { - // Directory doesnt exist yet, we will look up the disk space once we have created the directory - // FIXME: do we want to proactively create the directory instead ??? - return; - } - // Use the available disk space if temp directory is set - auto disk_space = FileSystem::GetAvailableDiskSpace(options.temporary_directory); - // Only use 90% of the available disk space - auto default_value = disk_space == DConstants::INVALID_INDEX ? 
0 : static_cast(disk_space) * 0.9; - options.maximum_swap_space.SetDefault(default_value); -} - void DBConfig::CheckLock(const string &name) { if (!options.lock_configuration) { // not locked diff --git a/src/main/database.cpp b/src/main/database.cpp index e561eeab5a50..21598a6522b0 100644 --- a/src/main/database.cpp +++ b/src/main/database.cpp @@ -334,13 +334,10 @@ void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path if (new_config.secret_manager) { config.secret_manager = std::move(new_config.secret_manager); } - if (config.options.maximum_memory == (idx_t)-1) { + if (config.options.maximum_memory == DConstants::INVALID_INDEX) { config.SetDefaultMaxMemory(); } - if (!config.options.maximum_swap_space.ExplicitlySet()) { - config.SetDefaultMaxSwapSpace(this); - } - if (new_config.options.maximum_threads == (idx_t)-1) { + if (new_config.options.maximum_threads == DConstants::INVALID_INDEX) { config.options.maximum_threads = config.GetSystemMaxThreads(*config.file_system); } config.allocator = std::move(new_config.allocator); diff --git a/src/main/settings/settings.cpp b/src/main/settings/settings.cpp index d237fb7e882b..21ef43e50ee9 100644 --- a/src/main/settings/settings.cpp +++ b/src/main/settings/settings.cpp @@ -959,7 +959,7 @@ Value MaximumExpressionDepthSetting::GetSetting(ClientContext &context) { void MaximumMemorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { config.options.maximum_memory = DBConfig::ParseMemoryLimit(input.ToString()); if (db) { - BufferManager::GetBufferManager(*db).SetLimit(config.options.maximum_memory); + BufferManager::GetBufferManager(*db).SetMemoryLimit(config.options.maximum_memory); } } @@ -976,18 +976,46 @@ Value MaximumMemorySetting::GetSetting(ClientContext &context) { // Maximum Temp Directory Size //===--------------------------------------------------------------------===// void MaximumTempDirectorySize::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - // FIXME: should this not use 'SetExplicit' when the value is 0? - // So it acts as RESET instead when 0 is passed? 
- config.options.maximum_swap_space.SetExplicit(DBConfig::ParseMemoryLimit(input.ToString())); + idx_t maximum_swap_space = DConstants::INVALID_INDEX; + if (input.ToString() != "-1") { + maximum_swap_space = DBConfig::ParseMemoryLimit(input.ToString()); + } + config.options.maximum_swap_space = maximum_swap_space; + if (!db) { + return; + } + auto &buffer_manager = BufferManager::GetBufferManager(*db); + if (maximum_swap_space == DConstants::INVALID_INDEX) { + buffer_manager.SetSwapLimit(); + } else { + buffer_manager.SetSwapLimit(maximum_swap_space); + } } void MaximumTempDirectorySize::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.SetDefaultMaxSwapSpace(db); + config.options.maximum_swap_space = DConstants::INVALID_INDEX; + if (!db) { + return; + } + auto &buffer_manager = BufferManager::GetBufferManager(*db); + buffer_manager.SetSwapLimit(); } Value MaximumTempDirectorySize::GetSetting(ClientContext &context) { auto &config = DBConfig::GetConfig(context); - return Value(StringUtil::BytesToHumanReadableString(config.options.maximum_swap_space)); + if (config.options.maximum_swap_space != DConstants::INVALID_INDEX) { + // Explicitly set by the user + return Value(StringUtil::BytesToHumanReadableString(config.options.maximum_swap_space)); + } + auto &buffer_manager = BufferManager::GetBufferManager(context); + // Database is initialized, use the setting from the temporary directory + auto max_swap = buffer_manager.GetMaxSwap(); + if (max_swap.IsValid()) { + return Value(StringUtil::BytesToHumanReadableString(max_swap.GetIndex())); + } else { + // The temp directory has not been used yet + return Value(StringUtil::BytesToHumanReadableString(0)); + } } //===--------------------------------------------------------------------===// @@ -1271,9 +1299,6 @@ Value SecretDirectorySetting::GetSetting(ClientContext &context) { void TempDirectorySetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { config.options.temporary_directory = input.ToString(); config.options.use_temporary_directory = !config.options.temporary_directory.empty(); - if (!config.options.maximum_swap_space.ExplicitlySet()) { - config.SetDefaultMaxSwapSpace(db); - } if (db) { auto &buffer_manager = BufferManager::GetBufferManager(*db); buffer_manager.SetTemporaryDirectory(config.options.temporary_directory); diff --git a/src/storage/buffer_manager.cpp b/src/storage/buffer_manager.cpp index 5b0a89296c9c..64eceba4807a 100644 --- a/src/storage/buffer_manager.cpp +++ b/src/storage/buffer_manager.cpp @@ -27,8 +27,12 @@ void BufferManager::FreeReservedMemory(idx_t size) { throw NotImplementedException("This type of BufferManager can not free reserved memory"); } -void BufferManager::SetLimit(idx_t limit) { - throw NotImplementedException("This type of BufferManager can not set a limit"); +void BufferManager::SetMemoryLimit(idx_t limit) { + throw NotImplementedException("This type of BufferManager can not set a memory limit"); +} + +void BufferManager::SetSwapLimit(optional_idx limit) { + throw NotImplementedException("This type of BufferManager can not set a swap limit"); } vector BufferManager::GetTemporaryFiles() { diff --git a/src/storage/standard_buffer_manager.cpp b/src/storage/standard_buffer_manager.cpp index ab5ad560ba12..feb61f75a536 100644 --- a/src/storage/standard_buffer_manager.cpp +++ b/src/storage/standard_buffer_manager.cpp @@ -36,16 +36,18 @@ unique_ptr StandardBufferManager::ConstructManagedBuffer(idx_t size, } void StandardBufferManager::SetTemporaryDirectory(const string 
&new_dir) { - if (temp_directory_handle) { + lock_guard guard(temporary_directory.lock); + if (temporary_directory.handle) { throw NotImplementedException("Cannot switch temporary directory after the current one has been used"); } - this->temp_directory = new_dir; + temporary_directory.path = new_dir; } StandardBufferManager::StandardBufferManager(DatabaseInstance &db, string tmp) - : BufferManager(), db(db), buffer_pool(db.GetBufferPool()), temp_directory(std::move(tmp)), - temporary_id(MAXIMUM_BLOCK), buffer_allocator(BufferAllocatorAllocate, BufferAllocatorFree, - BufferAllocatorRealloc, make_uniq(*this)) { + : BufferManager(), db(db), buffer_pool(db.GetBufferPool()), temporary_id(MAXIMUM_BLOCK), + buffer_allocator(BufferAllocatorAllocate, BufferAllocatorFree, BufferAllocatorRealloc, + make_uniq(*this)) { + temporary_directory.path = std::move(tmp); temp_block_manager = make_uniq(*this); for (idx_t i = 0; i < MEMORY_TAG_COUNT; i++) { evicted_data_per_tag[i] = 0; @@ -70,6 +72,22 @@ idx_t StandardBufferManager::GetMaxMemory() const { return buffer_pool.GetMaxMemory(); } +idx_t StandardBufferManager::GetUsedSwap() { + lock_guard guard(temporary_directory.lock); + if (!temporary_directory.handle) { + return 0; + } + return temporary_directory.handle->GetTempFile().GetTotalUsedSpaceInBytes(); +} + +optional_idx StandardBufferManager::GetMaxSwap() { + lock_guard guard(temporary_directory.lock); + if (!temporary_directory.handle) { + return optional_idx(); + } + return temporary_directory.handle->GetTempFile().GetMaxSwapSpace(); +} + template TempBufferPoolReservation StandardBufferManager::EvictBlocksOrThrow(MemoryTag tag, idx_t memory_delta, unique_ptr *buffer, ARGS... args) { @@ -232,10 +250,19 @@ void StandardBufferManager::Unpin(shared_ptr &handle) { } } -void StandardBufferManager::SetLimit(idx_t limit) { +void StandardBufferManager::SetMemoryLimit(idx_t limit) { buffer_pool.SetLimit(limit, InMemoryWarning()); } +void StandardBufferManager::SetSwapLimit(optional_idx limit) { + lock_guard guard(temporary_directory.lock); + if (temporary_directory.handle) { + temporary_directory.handle->GetTempFile().SetMaxSwapSpace(limit); + } else { + temporary_directory.maximum_swap_space = limit; + } +} + vector StandardBufferManager::GetMemoryUsageInfo() const { vector result; for (idx_t k = 0; k < MEMORY_TAG_COUNT; k++) { @@ -259,19 +286,20 @@ unique_ptr StandardBufferManager::ReadTemporaryBufferInternal(Buffer string StandardBufferManager::GetTemporaryPath(block_id_t id) { auto &fs = FileSystem::GetFileSystem(db); - return fs.JoinPath(temp_directory, "duckdb_temp_block-" + to_string(id) + ".block"); + return fs.JoinPath(temporary_directory.path, "duckdb_temp_block-" + to_string(id) + ".block"); } void StandardBufferManager::RequireTemporaryDirectory() { - if (temp_directory.empty()) { + if (temporary_directory.path.empty()) { throw InvalidInputException( "Out-of-memory: cannot write buffer because no temporary directory is specified!\nTo enable " "temporary buffer eviction set a temporary directory using PRAGMA temp_directory='/path/to/tmp.tmp'"); } - lock_guard temp_handle_guard(temp_handle_lock); - if (!temp_directory_handle) { + lock_guard guard(temporary_directory.lock); + if (!temporary_directory.handle) { // temp directory has not been created yet: initialize it - temp_directory_handle = make_uniq(db, temp_directory); + temporary_directory.handle = + make_uniq(db, temporary_directory.path, temporary_directory.maximum_swap_space); } } @@ -279,7 +307,7 @@ void 
StandardBufferManager::WriteTemporaryBuffer(MemoryTag tag, block_id_t block RequireTemporaryDirectory(); if (buffer.size == Storage::BLOCK_SIZE) { evicted_data_per_tag[uint8_t(tag)] += Storage::BLOCK_SIZE; - temp_directory_handle->GetTempFile().WriteTemporaryBuffer(block_id, buffer); + temporary_directory.handle->GetTempFile().WriteTemporaryBuffer(block_id, buffer); return; } evicted_data_per_tag[uint8_t(tag)] += buffer.size; @@ -295,11 +323,11 @@ void StandardBufferManager::WriteTemporaryBuffer(MemoryTag tag, block_id_t block unique_ptr StandardBufferManager::ReadTemporaryBuffer(MemoryTag tag, block_id_t id, unique_ptr reusable_buffer) { - D_ASSERT(!temp_directory.empty()); - D_ASSERT(temp_directory_handle.get()); - if (temp_directory_handle->GetTempFile().HasTemporaryBuffer(id)) { + D_ASSERT(!temporary_directory.path.empty()); + D_ASSERT(temporary_directory.handle.get()); + if (temporary_directory.handle->GetTempFile().HasTemporaryBuffer(id)) { evicted_data_per_tag[uint8_t(tag)] -= Storage::BLOCK_SIZE; - return temp_directory_handle->GetTempFile().ReadTemporaryBuffer(id, std::move(reusable_buffer)); + return temporary_directory.handle->GetTempFile().ReadTemporaryBuffer(id, std::move(reusable_buffer)); } idx_t block_size; // open the temporary file and read the size @@ -318,20 +346,20 @@ unique_ptr StandardBufferManager::ReadTemporaryBuffer(MemoryTag tag, } void StandardBufferManager::DeleteTemporaryFile(block_id_t id) { - if (temp_directory.empty()) { + if (temporary_directory.path.empty()) { // no temporary directory specified: nothing to delete return; } { - lock_guard temp_handle_guard(temp_handle_lock); - if (!temp_directory_handle) { + lock_guard guard(temporary_directory.lock); + if (!temporary_directory.handle) { // temporary directory was not initialized yet: nothing to delete return; } } // check if we should delete the file from the shared pool of files, or from the general file system - if (temp_directory_handle->GetTempFile().HasTemporaryBuffer(id)) { - temp_directory_handle->GetTempFile().DeleteTemporaryBuffer(id); + if (temporary_directory.handle->GetTempFile().HasTemporaryBuffer(id)) { + temporary_directory.handle->GetTempFile().DeleteTemporaryBuffer(id); return; } auto &fs = FileSystem::GetFileSystem(db); @@ -342,22 +370,22 @@ void StandardBufferManager::DeleteTemporaryFile(block_id_t id) { } bool StandardBufferManager::HasTemporaryDirectory() const { - return !temp_directory.empty(); + return !temporary_directory.path.empty(); } vector StandardBufferManager::GetTemporaryFiles() { vector result; - if (temp_directory.empty()) { + if (temporary_directory.path.empty()) { return result; } { - lock_guard temp_handle_guard(temp_handle_lock); - if (temp_directory_handle) { - result = temp_directory_handle->GetTempFile().GetTemporaryFiles(); + lock_guard temp_handle_guard(temporary_directory.lock); + if (temporary_directory.handle) { + result = temporary_directory.handle->GetTempFile().GetTemporaryFiles(); } } auto &fs = FileSystem::GetFileSystem(db); - fs.ListFiles(temp_directory, [&](const string &name, bool is_dir) { + fs.ListFiles(temporary_directory.path, [&](const string &name, bool is_dir) { if (is_dir) { return; } @@ -375,7 +403,7 @@ vector StandardBufferManager::GetTemporaryFiles() { } const char *StandardBufferManager::InMemoryWarning() { - if (!temp_directory.empty()) { + if (!temporary_directory.path.empty()) { return ""; } return "\nDatabase is launched in in-memory mode and no temporary directory is specified." 
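Because the TemporaryDirectoryHandle is created lazily, GetMaxSwap() stays invalid until the first buffer actually spills to disk. A short caller-side sketch of the two accessors added above; PrintSwapUsage itself is illustrative, not part of the patch:

	// Sketch: reporting swap usage through the new BufferManager accessors.
	void PrintSwapUsage(DatabaseInstance &db) {
		auto &buffer_manager = BufferManager::GetBufferManager(db);
		auto max_swap = buffer_manager.GetMaxSwap();
		if (!max_swap.IsValid()) {
			// No temporary directory handle yet: nothing has spilled to disk.
			Printer::Print("swap: unused");
			return;
		}
		Printer::Print(StringUtil::Format("swap: %s used of %s",
		                                  StringUtil::BytesToHumanReadableString(buffer_manager.GetUsedSwap()),
		                                  StringUtil::BytesToHumanReadableString(max_swap.GetIndex())));
	}
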
diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp index d4fa09470596..d16cc23f1d0f 100644 --- a/src/storage/temporary_file_manager.cpp +++ b/src/storage/temporary_file_manager.cpp @@ -189,18 +189,14 @@ idx_t TemporaryFileHandle::GetPositionInFile(idx_t index) { // TemporaryDirectoryHandle //===--------------------------------------------------------------------===// -TemporaryDirectoryHandle::TemporaryDirectoryHandle(DatabaseInstance &db, string path_p) - : db(db), temp_directory(std::move(path_p)), temp_file(make_uniq(db, temp_directory)) { +TemporaryDirectoryHandle::TemporaryDirectoryHandle(DatabaseInstance &db, string path_p, optional_idx max_swap_space) + : db(db), temp_directory(std::move(path_p)), + temp_file(make_uniq(db, temp_directory, max_swap_space)) { auto &fs = FileSystem::GetFileSystem(db); if (!temp_directory.empty()) { if (!fs.DirectoryExists(temp_directory)) { - auto &config = DBConfig::GetConfig(db); fs.CreateDirectory(temp_directory); created_directory = true; - // Maximum swap space isn't set explicitly, initialize to default - if (!config.options.maximum_swap_space.ExplicitlySet()) { - config.SetDefaultMaxSwapSpace(&db); - } } } } @@ -258,8 +254,23 @@ bool TemporaryFileIndex::IsValid() const { // TemporaryFileManager //===--------------------------------------------------------------------===// -TemporaryFileManager::TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p) - : db(db), temp_directory(temp_directory_p), size_on_disk(0) { +static idx_t GetDefaultMax(const string &path) { + // Use the available disk space + auto disk_space = FileSystem::GetAvailableDiskSpace(path); + idx_t default_value = 0; + if (disk_space.IsValid()) { + // Only use 90% of the available disk space + default_value = static_cast(static_cast(disk_space.GetIndex()) * 0.9); + } + return default_value; +} + +TemporaryFileManager::TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p, + optional_idx max_swap_space) + : db(db), temp_directory(temp_directory_p), size_on_disk(0), max_swap_space(GetDefaultMax(temp_directory_p)) { + if (max_swap_space.IsValid()) { + this->max_swap_space = max_swap_space.GetIndex(); + } } TemporaryFileManager::~TemporaryFileManager() { @@ -311,10 +322,19 @@ idx_t TemporaryFileManager::GetTotalUsedSpaceInBytes() { return size_on_disk.load(); } -void TemporaryFileManager::IncreaseSizeOnDisk(idx_t bytes) { - auto &config = DBConfig::GetConfig(db); - auto max_swap_space = config.options.maximum_swap_space; +optional_idx TemporaryFileManager::GetMaxSwapSpace() const { + return max_swap_space; +} +void TemporaryFileManager::SetMaxSwapSpace(optional_idx limit) { + if (limit.IsValid()) { + max_swap_space = limit.GetIndex(); + } else { + max_swap_space = GetDefaultMax(temp_directory); + } +} + +void TemporaryFileManager::IncreaseSizeOnDisk(idx_t bytes) { auto current_size_on_disk = size_on_disk.load(); if (current_size_on_disk + bytes > max_swap_space) { auto used = StringUtil::BytesToHumanReadableString(current_size_on_disk); diff --git a/test/sql/storage/temp_directory/max_swap_space_inmemory.test b/test/sql/storage/temp_directory/max_swap_space_inmemory.test index e6aad865c921..f02c6fada818 100644 --- a/test/sql/storage/temp_directory/max_swap_space_inmemory.test +++ b/test/sql/storage/temp_directory/max_swap_space_inmemory.test @@ -44,10 +44,11 @@ select current_setting('max_temp_directory_size') statement ok set temp_directory = '__TEST_DIR__'; -# '__TEST_DIR__' is guaranteed to exist, we 
can get the disk space +# even though '__TEST_DIR__' exists, we haven't used the temporary directory so the max size is still 0 query I -select current_setting('max_temp_directory_size') a where a == '0 bytes' +select current_setting('max_temp_directory_size') ---- +0 bytes # --- Set explicitly by the user --- diff --git a/test/sql/storage/temp_directory/max_swap_space_persistent.test b/test/sql/storage/temp_directory/max_swap_space_persistent.test index aa18759f7a83..3b4093948199 100644 --- a/test/sql/storage/temp_directory/max_swap_space_persistent.test +++ b/test/sql/storage/temp_directory/max_swap_space_persistent.test @@ -55,10 +55,11 @@ select current_setting('max_temp_directory_size') statement ok set temp_directory = '__TEST_DIR__'; -# '__TEST_DIR__' is guaranteed to exist, we can get the disk space +# Even though __TEST_DIR__ exists, the handle is not created so the size is still 0 (unknown) query I -select current_setting('max_temp_directory_size') a where a == '0 bytes' +select current_setting('max_temp_directory_size') ---- +0 bytes # --- Set explicitly by the user --- From 17086122b885319f0359dc4bc7d08aff8189ec75 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 21 Mar 2024 16:10:40 +0100 Subject: [PATCH 026/201] get rid of FileSizeMonitor, just pass along the TemporaryFileManager & --- .../duckdb/storage/temporary_file_manager.hpp | 9 ++- src/storage/temporary_file_manager.cpp | 56 +++++++++---------- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/src/include/duckdb/storage/temporary_file_manager.hpp b/src/include/duckdb/storage/temporary_file_manager.hpp index 9ab77e42e958..f6a8a17e81a0 100644 --- a/src/include/duckdb/storage/temporary_file_manager.hpp +++ b/src/include/duckdb/storage/temporary_file_manager.hpp @@ -33,16 +33,14 @@ struct FileSizeMonitor { FileSizeMonitor(TemporaryFileManager &manager); public: - void Increase(idx_t blocks); - void Decrease(idx_t blocks); - private: TemporaryFileManager &manager; }; struct BlockIndexManager { public: - BlockIndexManager(unique_ptr file_size_monitor = nullptr); + BlockIndexManager(TemporaryFileManager &manager); + BlockIndexManager(); public: //! 
Obtains a new block index from the index manager @@ -54,13 +52,14 @@ struct BlockIndexManager { bool HasFreeBlocks(); private: + void SetMaxIndex(idx_t blocks); idx_t GetNewBlockIndexInternal(); private: idx_t max_index; set free_indexes; set indexes_in_use; - unique_ptr file_size_monitor; + optional_ptr manager; }; //===--------------------------------------------------------------------===// diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp index d16cc23f1d0f..463210a405e7 100644 --- a/src/storage/temporary_file_manager.cpp +++ b/src/storage/temporary_file_manager.cpp @@ -5,28 +5,13 @@ namespace duckdb { //===--------------------------------------------------------------------===// -// FileSizeMonitor +// BlockIndexManager //===--------------------------------------------------------------------===// -FileSizeMonitor::FileSizeMonitor(TemporaryFileManager &manager) : manager(manager) { -} - -void FileSizeMonitor::Increase(idx_t blocks) { - auto size_on_disk = blocks * TEMPFILE_BLOCK_SIZE; - manager.IncreaseSizeOnDisk(size_on_disk); -} - -void FileSizeMonitor::Decrease(idx_t blocks) { - auto size_on_disk = blocks * TEMPFILE_BLOCK_SIZE; - manager.DecreaseSizeOnDisk(size_on_disk); +BlockIndexManager::BlockIndexManager(TemporaryFileManager &manager) : max_index(0), manager(&manager) { } -//===--------------------------------------------------------------------===// -// BlockIndexManager -//===--------------------------------------------------------------------===// - -BlockIndexManager::BlockIndexManager(unique_ptr file_size_monitor) - : max_index(0), file_size_monitor(std::move(file_size_monitor)) { +BlockIndexManager::BlockIndexManager() : max_index(0), manager(nullptr) { } idx_t BlockIndexManager::GetNewBlockIndex() { @@ -45,17 +30,12 @@ bool BlockIndexManager::RemoveIndex(idx_t index) { free_indexes.insert(index); // check if we can truncate the file - auto old_max = max_index; - // get the max_index in use right now auto max_index_in_use = indexes_in_use.empty() ? 
0 : *indexes_in_use.rbegin() + 1; if (max_index_in_use < max_index) { // max index in use is lower than the max_index // reduce the max_index - max_index = max_index_in_use; - if (file_size_monitor) { - file_size_monitor->Decrease(old_max - max_index); - } + SetMaxIndex(max_index_in_use); // we can remove any free_indexes that are larger than the current max_index while (!free_indexes.empty()) { auto max_entry = *free_indexes.rbegin(); @@ -77,13 +57,31 @@ bool BlockIndexManager::HasFreeBlocks() { return !free_indexes.empty(); } +void BlockIndexManager::SetMaxIndex(idx_t new_index) { + static constexpr idx_t TEMPFILE_BLOCK_SIZE = Storage::BLOCK_ALLOC_SIZE; + if (!manager) { + max_index = new_index; + } else { + auto old = max_index; + if (new_index < old) { + max_index = new_index; + auto difference = old - new_index; + auto size_on_disk = difference * TEMPFILE_BLOCK_SIZE; + manager->DecreaseSizeOnDisk(size_on_disk); + } else if (new_index > old) { + auto difference = new_index - old; + auto size_on_disk = difference * TEMPFILE_BLOCK_SIZE; + manager->IncreaseSizeOnDisk(size_on_disk); + // IncreaseSizeOnDisk can throw, so max_index is only updated after the increase succeeds + max_index = new_index; + } + } +} + idx_t BlockIndexManager::GetNewBlockIndexInternal() { if (free_indexes.empty()) { auto new_index = max_index; - if (file_size_monitor) { - file_size_monitor->Increase(1); - } - max_index++; + SetMaxIndex(max_index + 1); return new_index; } auto entry = free_indexes.begin(); @@ -100,7 +98,7 @@ TemporaryFileHandle::TemporaryFileHandle(idx_t temp_file_count, DatabaseInstance idx_t index, TemporaryFileManager &manager) : max_allowed_index((1 << temp_file_count) * MAX_ALLOWED_INDEX_BASE), db(db), file_index(index), path(FileSystem::GetFileSystem(db).JoinPath(temp_directory, "duckdb_temp_storage-" + to_string(index) + ".tmp")), - index_manager(make_uniq(manager)) { + index_manager(manager) { } TemporaryFileHandle::TemporaryFileLock::TemporaryFileLock(mutex &mutex) : lock(mutex) {
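A note on the SetMaxIndex change in the patch above: IncreaseSizeOnDisk can throw once the temp directory would exceed its swap-space limit, so the new max_index is committed only after the byte reservation succeeds, while shrinking commits first and releases bytes afterwards. A minimal self-contained sketch of this reserve-then-commit pattern (Accountant, BLOCK_SIZE and the exception type are illustrative stand-ins, not DuckDB API):

#include <cstdint>
#include <stdexcept>

// Stand-in for TemporaryFileManager's byte accounting against a swap limit.
struct Accountant {
	uint64_t used = 0;
	uint64_t limit;

	explicit Accountant(uint64_t limit_p) : limit(limit_p) {
	}
	// May throw, like TemporaryFileManager::IncreaseSizeOnDisk.
	void Increase(uint64_t bytes) {
		if (used + bytes > limit) {
			throw std::runtime_error("temp directory size limit exceeded");
		}
		used += bytes;
	}
	// Never throws.
	void Decrease(uint64_t bytes) {
		used -= bytes;
	}
};

static constexpr uint64_t BLOCK_SIZE = 256 * 1024; // stand-in for Storage::BLOCK_ALLOC_SIZE

struct IndexManager {
	uint64_t max_index = 0;
	Accountant *accountant = nullptr; // optional, mirroring optional_ptr<TemporaryFileManager>

	void SetMaxIndex(uint64_t new_index) {
		if (!accountant) {
			max_index = new_index;
		} else if (new_index < max_index) {
			// Shrinking cannot fail: commit the index first, then release the bytes.
			auto freed = (max_index - new_index) * BLOCK_SIZE;
			max_index = new_index;
			accountant->Decrease(freed);
		} else if (new_index > max_index) {
			// Growing can throw: reserve the bytes first, commit only on success.
			accountant->Increase((new_index - max_index) * BLOCK_SIZE);
			max_index = new_index;
		}
	}
};

If Increase throws, max_index is left untouched, so a failed spill keeps the in-memory accounting consistent with what is actually on disk.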
From 03fc90e02bfa71c2900cedc43a2c6038bd7c7532 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 21 Mar 2024 16:13:25 +0100 Subject: [PATCH 027/201] remove named connection, should be stripped when it gets into the database instance constructor --- test/api/test_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/api/test_api.cpp b/test/api/test_api.cpp index 7302627eb8cd..23d4bb3d02ef 100644 --- a/test/api/test_api.cpp +++ b/test/api/test_api.cpp @@ -140,7 +140,7 @@ static void parallel_query(Connection *conn, bool *correct, size_t threadnr) { } TEST_CASE("Test temp_directory defaults", "[api][.]") { - const char *db_paths[] = {nullptr, "", ":memory:", ":memory:named_conn"}; + const char *db_paths[] = {nullptr, "", ":memory:"}; for (auto &path : db_paths) { auto db = make_uniq(path); auto conn = make_uniq(*db); From eebc24c215ada39552a52041c48761914ebc92c5 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 21 Mar 2024 17:30:58 +0100 Subject: [PATCH 028/201] test error when setting a limit that's too low --- src/main/settings/settings.cpp | 3 ++- src/storage/temporary_file_manager.cpp | 18 +++++++++++-- .../temp_directory/max_swap_space_error.test | 25 +++++++++++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/src/main/settings/settings.cpp b/src/main/settings/settings.cpp index 21ef43e50ee9..ffbfcf7322d8 100644 --- a/src/main/settings/settings.cpp +++ b/src/main/settings/settings.cpp @@ -980,8 +980,8 @@ void MaximumTempDirectorySize::SetGlobal(DatabaseInstance *db, DBConfig &config, if (input.ToString() != "-1") { maximum_swap_space = DBConfig::ParseMemoryLimit(input.ToString()); } - config.options.maximum_swap_space = maximum_swap_space; if (!db) { + config.options.maximum_swap_space = maximum_swap_space; return; } auto &buffer_manager = BufferManager::GetBufferManager(*db); @@ -990,6 +990,7 @@ void MaximumTempDirectorySize::SetGlobal(DatabaseInstance *db, DBConfig &config, } else { buffer_manager.SetSwapLimit(maximum_swap_space); } + config.options.maximum_swap_space = maximum_swap_space; } void MaximumTempDirectorySize::ResetGlobal(DatabaseInstance *db, DBConfig &config) { diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp index 463210a405e7..c4b8643d4e84 100644 --- a/src/storage/temporary_file_manager.cpp +++ b/src/storage/temporary_file_manager.cpp @@ -325,11 +325,25 @@ optional_idx TemporaryFileManager::GetMaxSwapSpace() const { } void TemporaryFileManager::SetMaxSwapSpace(optional_idx limit) { + idx_t new_limit; if (limit.IsValid()) { - max_swap_space = limit.GetIndex(); + new_limit = limit.GetIndex(); } else { - max_swap_space = GetDefaultMax(temp_directory); + new_limit = GetDefaultMax(temp_directory); } + + auto current_size_on_disk = size_on_disk.load(); + if (current_size_on_disk > new_limit) { + auto used = StringUtil::BytesToHumanReadableString(current_size_on_disk); + auto max = StringUtil::BytesToHumanReadableString(new_limit); + throw OutOfMemoryException( + R"(failed to adjust the 'max_temp_directory_size', currently used space (%s) exceeds the new limit (%s) +Please increase the limit or destroy the buffers stored in the temp directory by e.g. removing temporary tables.
To get usage information of the temp_directory, use 'CALL duckdb_temporary_files();'
 )", + used, max); + } + max_swap_space = new_limit; } void TemporaryFileManager::IncreaseSizeOnDisk(idx_t bytes) { diff --git a/test/sql/storage/temp_directory/max_swap_space_error.test b/test/sql/storage/temp_directory/max_swap_space_error.test index 6116c552ea7c..d9512df2230d 100644 --- a/test/sql/storage/temp_directory/max_swap_space_error.test +++ b/test/sql/storage/temp_directory/max_swap_space_error.test @@ -70,3 +70,28 @@ query I select "size" from duckdb_temporary_files() ---- 1572864 + +# Lower the limit +statement error +set max_temp_directory_size='256KiB' +---- +failed to adjust the 'max_temp_directory_size', currently used space (1.5 MiB) exceeds the new limit (256.0 KiB) + +# Lower the limit +statement error +set max_temp_directory_size='256KiB' +---- +failed to adjust the 'max_temp_directory_size', currently used space (1.5 MiB) exceeds the new limit (256.0 KiB) + +query I +select current_setting('max_temp_directory_size') +---- +1.5 MiB + +statement ok +set max_temp_directory_size='2550KiB' + +query I +select current_setting('max_temp_directory_size') +---- +2.4 MiB From fcc1af60ca3b286dbf6cdcdceb4de060ff3e36e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Fri, 22 Mar 2024 14:53:42 +0100 Subject: [PATCH 029/201] first batch of changes to comply with -Wsign-conversion -Wsign-compare --- extension/parquet/column_writer.cpp | 2 +- src/CMakeLists.txt | 2 +- src/common/types/cast_helpers.cpp | 2 +- src/core_functions/lambda_functions.cpp | 4 ++-- src/function/function.cpp | 4 ++-- .../duckdb/common/multi_file_reader.hpp | 4 ++-- src/include/duckdb/common/string_util.hpp | 15 ++++++------ .../duckdb/common/types/cast_helpers.hpp | 23 +++++++++----------- src/include/duckdb/common/types/datetime.hpp | 6 ++---
src/include/duckdb/common/types/hash.hpp | 2 +- src/include/duckdb/common/vector.hpp | 7 ++++++ src/include/duckdb/function/scalar/regexp.hpp | 2 +- .../duckdb/optimizer/matcher/set_matcher.hpp | 3 ++- .../duckdb/storage/buffer/block_handle.hpp | 2 +- src/optimizer/common_aggregate_optimizer.cpp | 2 +- src/optimizer/filter_combiner.cpp | 6 ++--- .../join_order/cardinality_estimator.cpp | 2 +- src/optimizer/join_order/plan_enumerator.cpp | 4 ++-- .../join_order/query_graph_manager.cpp | 2 +- src/optimizer/pushdown/pushdown_aggregate.cpp | 2 +- src/optimizer/pushdown/pushdown_left_join.cpp | 2 +- src/optimizer/pushdown/pushdown_mark_join.cpp | 6 ++--- .../pushdown/pushdown_single_join.cpp | 2 +- src/optimizer/remove_duplicate_groups.cpp | 2 +- src/optimizer/remove_unused_columns.cpp | 2 +- .../rule/arithmetic_simplification.cpp | 2 +- src/optimizer/rule/case_simplification.cpp | 4 ++-- .../rule/conjunction_simplification.cpp | 2 +- src/optimizer/rule/distributivity.cpp | 2 +- .../expression/propagate_conjunction.cpp | 2 +- .../expression/propagate_operator.cpp | 8 +++---- .../statistics/operator/propagate_filter.cpp | 6 ++--- .../statistics/operator/propagate_get.cpp | 2 +- .../statistics/operator/propagate_join.cpp | 5 ++-- .../operator/propagate_set_operation.cpp | 3 ++- src/optimizer/topn_optimizer.cpp | 2 +- src/parser/parser.cpp | 9 ++++---- .../expression/transform_boolean_test.cpp | 4 ++-- .../transform/expression/transform_cast.cpp | 2 +- .../expression/transform_param_ref.cpp | 2 +- .../transform_positional_reference.cpp | 2 +- .../transform/helpers/transform_typename.cpp | 10 ++++---- src/parser/transformer.cpp | 4 ++-- .../binder/statement/bind_copy_database.cpp | 2 +- src/planner/binder/statement/bind_export.cpp | 2 +- src/planner/binder/statement/bind_insert.cpp | 2 +- src/planner/binder/tableref/bind_pivot.cpp | 6 ++--- src/planner/bound_result_modifier.cpp | 2 +- src/planner/operator/logical_top_n.cpp | 2 +- src/planner/table_binding.cpp | 2 +- src/storage/data_table.cpp | 2 +- src/storage/local_storage.cpp | 2 +- src/storage/table/row_group_collection.cpp | 2 +- src/storage/table_index_list.cpp | 2 +- src/transaction/duck_transaction_manager.cpp | 2 +- src/transaction/meta_transaction.cpp | 2 +- third_party/utf8proc/include/utf8proc.hpp | 2 +- 57 files changed, 112 insertions(+), 101 deletions(-) diff --git a/extension/parquet/column_writer.cpp b/extension/parquet/column_writer.cpp index 47ff6c93d6e9..c90ab1f5b26d 100644 --- a/extension/parquet/column_writer.cpp +++ b/extension/parquet/column_writer.cpp @@ -477,7 +477,7 @@ void BasicColumnWriter::BeginWrite(ColumnWriterState &state_p) { auto &page_info = state.page_info[page_idx]; if (page_info.row_count == 0) { D_ASSERT(page_idx + 1 == state.page_info.size()); - state.page_info.erase(state.page_info.begin() + page_idx); + state.page_info.erase_at(page_idx); break; } PageWriteInformation write_info; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d45ae7fb65fd..4c69853abd3c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,7 +24,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") set(EXIT_TIME_DESTRUCTORS_WARNING TRUE) set(CMAKE_CXX_FLAGS_DEBUG - "${CMAKE_CXX_FLAGS_DEBUG} -Wexit-time-destructors -Wimplicit-int-conversion -Wshorten-64-to-32 -Wnarrowing" + "${CMAKE_CXX_FLAGS_DEBUG} -Wexit-time-destructors -Wimplicit-int-conversion -Wshorten-64-to-32 -Wnarrowing -Wsign-conversion -Wsign-compare" ) endif() diff --git a/src/common/types/cast_helpers.cpp 
b/src/common/types/cast_helpers.cpp index f37fbaa971e9..1e37fbc79ea6 100644 --- a/src/common/types/cast_helpers.cpp +++ b/src/common/types/cast_helpers.cpp @@ -67,7 +67,7 @@ int NumericHelper::UnsignedLength(uint32_t value) { } template <> -int NumericHelper::UnsignedLength(uint64_t value) { +idx_t NumericHelper::UnsignedLength(uint64_t value) { if (value >= 10000000000ULL) { if (value >= 1000000000000000ULL) { int length = 16; diff --git a/src/core_functions/lambda_functions.cpp b/src/core_functions/lambda_functions.cpp index 3b67f8809305..ee78be581cd7 100644 --- a/src/core_functions/lambda_functions.cpp +++ b/src/core_functions/lambda_functions.cpp @@ -154,7 +154,7 @@ struct ListFilterFunctor { // slice the input chunk's corresponding vector to get the new lists // and append them to the result - auto source_list_idx = execute_info.has_index ? 1 : 0; + idx_t source_list_idx = execute_info.has_index ? 1 : 0; Vector result_lists(execute_info.input_chunk.data[source_list_idx], sel, count); ListVector::Append(result, result_lists, count, 0); } @@ -353,7 +353,7 @@ void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &result) { // set the index vector if (info.has_index) { - index_vector.SetValue(elem_cnt, Value::BIGINT(child_idx + 1)); + index_vector.SetValue(elem_cnt, Value::BIGINT(NumericCast(child_idx + 1))); } elem_cnt++; diff --git a/src/function/function.cpp b/src/function/function.cpp index 3d76a7257afc..9427f445b740 100644 --- a/src/function/function.cpp +++ b/src/function/function.cpp @@ -158,8 +158,8 @@ void Function::EraseArgument(SimpleFunction &bound_function, vectorGetFileName()); if (entry == file_set.end()) { - data.union_readers.erase(data.union_readers.begin() + r); + data.union_readers.erase_at(r); r--; continue; } diff --git a/src/include/duckdb/common/string_util.hpp b/src/include/duckdb/common/string_util.hpp index e5eb9cd6269a..0a2335e7ed81 100644 --- a/src/include/duckdb/common/string_util.hpp +++ b/src/include/duckdb/common/string_util.hpp @@ -10,8 +10,9 @@ #include "duckdb/common/constants.hpp" #include "duckdb/common/exception.hpp" -#include "duckdb/common/vector.hpp" +#include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/set.hpp" +#include "duckdb/common/vector.hpp" #include @@ -40,22 +41,22 @@ class StringUtil { static uint8_t GetHexValue(char c) { if (c >= '0' && c <= '9') { - return c - '0'; + return UnsafeNumericCast(c - '0'); } if (c >= 'a' && c <= 'f') { - return c - 'a' + 10; + return UnsafeNumericCast(c - 'a' + 10); } if (c >= 'A' && c <= 'F') { - return c - 'A' + 10; + return UnsafeNumericCast(c - 'A' + 10); } - throw InvalidInputException("Invalid input for hex digit: %s", string(c, 1)); + throw InvalidInputException("Invalid input for hex digit: %s", string(1, c)); } static uint8_t GetBinaryValue(char c) { if (c >= '0' && c <= '1') { - return c - '0'; + return UnsafeNumericCast(c - '0'); } - throw InvalidInputException("Invalid input for binary digit: %s", string(c, 1)); + throw InvalidInputException("Invalid input for binary digit: %s", string(1, c)); } static bool CharacterIsSpace(char c) { diff --git a/src/include/duckdb/common/types/cast_helpers.hpp b/src/include/duckdb/common/types/cast_helpers.hpp index cac8803f3cb4..5e38665a4d37 100644 --- a/src/include/duckdb/common/types/cast_helpers.hpp +++ b/src/include/duckdb/common/types/cast_helpers.hpp @@ -64,7 +64,7 @@ class NumericHelper { int sign = -(value < 0); UNSIGNED unsigned_value = UnsafeNumericCast(UNSIGNED(value ^ sign) - sign); int length = 
UnsignedLength(unsigned_value) - sign; - string_t result = StringVector::EmptyString(vector, length); + string_t result = StringVector::EmptyString(vector, NumericCast(length)); auto dataptr = result.GetDataWriteable(); auto endptr = dataptr + length; endptr = FormatUnsigned(unsigned_value, endptr); @@ -149,7 +149,7 @@ struct DecimalToString { template static string_t Format(SIGNED value, uint8_t width, uint8_t scale, Vector &vector) { int len = DecimalLength(value, width, scale); - string_t result = StringVector::EmptyString(vector, len); + string_t result = StringVector::EmptyString(vector, NumericCast(len)); FormatDecimal(value, width, scale, result.GetDataWriteable(), len); result.Finalize(); return result; @@ -260,7 +260,7 @@ struct HugeintToStringCast { Hugeint::NegateInPlace(value); } int length = UnsignedLength(value) + negative; - string_t result = StringVector::EmptyString(vector, length); + string_t result = StringVector::EmptyString(vector, NumericCast(length)); auto dataptr = result.GetDataWriteable(); auto endptr = dataptr + length; if (value.upper == 0) { @@ -339,7 +339,7 @@ struct HugeintToStringCast { static string_t FormatDecimal(hugeint_t value, uint8_t width, uint8_t scale, Vector &vector) { int length = DecimalLength(value, width, scale); - string_t result = StringVector::EmptyString(vector, length); + string_t result = StringVector::EmptyString(vector, NumericCast(length)); auto dst = result.GetDataWriteable(); @@ -417,9 +417,9 @@ struct DateToStringCast { struct TimeToStringCast { //! Format microseconds to a buffer of length 6. Returns the number of trailing zeros - static int32_t FormatMicros(uint32_t microseconds, char micro_buffer[]) { + static int32_t FormatMicros(int32_t microseconds, char micro_buffer[]) { char *endptr = micro_buffer + 6; - endptr = NumericHelper::FormatUnsigned(microseconds, endptr); + endptr = NumericHelper::FormatUnsigned(microseconds, endptr); while (endptr > micro_buffer) { *--endptr = '0'; } @@ -448,7 +448,7 @@ struct TimeToStringCast { // we write backwards and pad with zeros to the left // now we figure out how many digits we need to include by looking backwards // and checking how many zeros we encounter - length -= FormatMicros(time[3], micro_buffer); + length -= NumericCast(FormatMicros(time[3], micro_buffer)); } return length; } @@ -485,8 +485,8 @@ struct TimeToStringCast { struct IntervalToStringCast { static void FormatSignedNumber(int64_t value, char buffer[], idx_t &length) { int sign = -(value < 0); - uint64_t unsigned_value = (value ^ sign) - sign; - length += NumericHelper::UnsignedLength(unsigned_value) - sign; + auto unsigned_value = NumericCast((value ^ sign) - sign); + length += NumericCast(NumericHelper::UnsignedLength(unsigned_value) - sign); auto endptr = buffer + length; endptr = NumericHelper::FormatUnsigned(unsigned_value, endptr); if (sign) { @@ -567,9 +567,8 @@ struct IntervalToStringCast { FormatTwoDigits(sec, buffer, length); if (micros != 0) { buffer[length++] = '.'; - auto trailing_zeros = - TimeToStringCast::FormatMicros(UnsafeNumericCast(micros), buffer + length); - length += 6 - trailing_zeros; + auto trailing_zeros = TimeToStringCast::FormatMicros(NumericCast(micros), buffer + length); + length += NumericCast(6 - trailing_zeros); } } else if (length == 0) { // empty interval: default to 00:00:00 diff --git a/src/include/duckdb/common/types/datetime.hpp b/src/include/duckdb/common/types/datetime.hpp index 4a06e1b762fd..a289a9a46971 100644 --- a/src/include/duckdb/common/types/datetime.hpp +++ 
b/src/include/duckdb/common/types/datetime.hpp @@ -57,10 +57,10 @@ struct dtime_t { // NOLINT return dtime_t(this->micros - micros); }; inline dtime_t operator*(const idx_t &copies) const { - return dtime_t(this->micros * copies); + return dtime_t(this->micros * UnsafeNumericCast(copies)); }; inline dtime_t operator/(const idx_t &copies) const { - return dtime_t(this->micros / copies); + return dtime_t(this->micros / UnsafeNumericCast(copies)); }; inline int64_t operator-(const dtime_t &other) const { return this->micros - other.micros; @@ -149,7 +149,7 @@ template <> struct hash { std::size_t operator()(const duckdb::dtime_tz_t &k) const { using std::hash; - return hash()(k.bits); + return hash()(k.bits); } }; } // namespace std diff --git a/src/include/duckdb/common/types/hash.hpp b/src/include/duckdb/common/types/hash.hpp index eeb849857f69..f43f75473e75 100644 --- a/src/include/duckdb/common/types/hash.hpp +++ b/src/include/duckdb/common/types/hash.hpp @@ -35,7 +35,7 @@ inline hash_t murmurhash32(uint32_t x) { template hash_t Hash(T value) { - return murmurhash32(value); + return murmurhash32(static_cast(value)); } //! Combine two hashes by XORing them diff --git a/src/include/duckdb/common/vector.hpp b/src/include/duckdb/common/vector.hpp index 66a6bc73153a..b634a8583ce3 100644 --- a/src/include/duckdb/common/vector.hpp +++ b/src/include/duckdb/common/vector.hpp @@ -100,6 +100,13 @@ class vector : public std::vector<_Tp, std::allocator<_Tp>> { } return get(original::size() - 1); } + + void erase_at(idx_t idx) { + if (MemorySafety::enabled && idx > original::size()) { + throw InternalException("Can't remove offset %d from vector of size %d", idx, original::size()); + } + original::erase(original::begin() + static_cast(idx)); + } }; template diff --git a/src/include/duckdb/function/scalar/regexp.hpp b/src/include/duckdb/function/scalar/regexp.hpp index 208e033e0904..612c4cae24ff 100644 --- a/src/include/duckdb/function/scalar/regexp.hpp +++ b/src/include/duckdb/function/scalar/regexp.hpp @@ -140,7 +140,7 @@ struct RegexLocalState : public FunctionLocalState { if (extract_all) { auto group_count_p = constant_pattern.NumberOfCapturingGroups(); if (group_count_p != -1) { - group_buffer.Init(group_count_p); + group_buffer.Init(NumericCast(group_count_p)); } } D_ASSERT(info.constant_pattern); diff --git a/src/include/duckdb/optimizer/matcher/set_matcher.hpp b/src/include/duckdb/optimizer/matcher/set_matcher.hpp index a709a1b5c3fa..11a991d2f683 100644 --- a/src/include/duckdb/optimizer/matcher/set_matcher.hpp +++ b/src/include/duckdb/optimizer/matcher/set_matcher.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/common.hpp" +#include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/unordered_set.hpp" namespace duckdb { @@ -59,7 +60,7 @@ class SetMatcher { return true; } else { // we did not find a match! 
remove any bindings we added in the call to Match() - bindings.erase(bindings.begin() + previous_binding_count, bindings.end()); + bindings.erase(bindings.begin() + NumericCast(previous_binding_count), bindings.end()); } } } diff --git a/src/include/duckdb/storage/buffer/block_handle.hpp b/src/include/duckdb/storage/buffer/block_handle.hpp index 196fb27c7654..6c8f554897e6 100644 --- a/src/include/duckdb/storage/buffer/block_handle.hpp +++ b/src/include/duckdb/storage/buffer/block_handle.hpp @@ -76,7 +76,7 @@ class BlockHandle { D_ASSERT(buffer); // resize and adjust current memory buffer->Resize(block_size); - memory_usage += memory_delta; + memory_usage = NumericCast(NumericCast(memory_usage) + memory_delta); D_ASSERT(memory_usage == buffer->AllocSize()); } diff --git a/src/optimizer/common_aggregate_optimizer.cpp b/src/optimizer/common_aggregate_optimizer.cpp index da13092c6e14..435b94cd35aa 100644 --- a/src/optimizer/common_aggregate_optimizer.cpp +++ b/src/optimizer/common_aggregate_optimizer.cpp @@ -47,7 +47,7 @@ void CommonAggregateOptimizer::ExtractCommonAggregates(LogicalAggregate &aggr) { } else { // aggregate already exists! we can remove this entry total_erased++; - aggr.expressions.erase(aggr.expressions.begin() + i); + aggr.expressions.erase_at(i); i--; // we need to remap any references to this aggregate so they point to the other aggregate ColumnBinding original_binding(aggr.aggregate_index, original_index); diff --git a/src/optimizer/filter_combiner.cpp b/src/optimizer/filter_combiner.cpp index cd6f3319134d..97cabfba613a 100644 --- a/src/optimizer/filter_combiner.cpp +++ b/src/optimizer/filter_combiner.cpp @@ -67,7 +67,7 @@ FilterResult FilterCombiner::AddConstantComparison(vector &column_id table_filters.PushFilter(column_index, std::move(upper_bound)); table_filters.PushFilter(column_index, make_uniq()); - remaining_filters.erase(remaining_filters.begin() + rem_fil_idx); + remaining_filters.erase_at(rem_fil_idx); } } @@ -971,7 +971,7 @@ unique_ptr FilterCombiner::FindTransitiveFilter(Expression &expr) { auto &comparison = remaining_filters[i]->Cast(); if (expr.Equals(*comparison.right) && comparison.type != ExpressionType::COMPARE_NOTEQUAL) { auto filter = std::move(remaining_filters[i]); - remaining_filters.erase(remaining_filters.begin() + i); + remaining_filters.erase_at(i); return filter; } } diff --git a/src/optimizer/join_order/cardinality_estimator.cpp b/src/optimizer/join_order/cardinality_estimator.cpp index 18fd5c858386..7a30a8ce573a 100644 --- a/src/optimizer/join_order/cardinality_estimator.cpp +++ b/src/optimizer/join_order/cardinality_estimator.cpp @@ -50,7 +50,7 @@ bool CardinalityEstimator::SingleColumnFilter(FilterInfo &filter_info) { vector CardinalityEstimator::DetermineMatchingEquivalentSets(FilterInfo *filter_info) { vector matching_equivalent_sets; - auto equivalent_relation_index = 0; + idx_t equivalent_relation_index = 0; for (const RelationsToTDom &r2tdom : relations_to_tdoms) { auto &i_set = r2tdom.equivalent_relations; diff --git a/src/optimizer/join_order/plan_enumerator.cpp b/src/optimizer/join_order/plan_enumerator.cpp index a6efb84b8e6e..6bbc13c3c53a 100644 --- a/src/optimizer/join_order/plan_enumerator.cpp +++ b/src/optimizer/join_order/plan_enumerator.cpp @@ -490,8 +490,8 @@ void PlanEnumerator::SolveJoinOrderApproximately() { // important to erase the biggest element first // if we erase the smallest element first the index of the biggest element changes D_ASSERT(best_right > best_left); - join_relations.erase(join_relations.begin() + 
best_right); - join_relations.erase(join_relations.begin() + best_left); + join_relations.erase_at(best_right); + join_relations.erase_at(best_left); join_relations.push_back(best_connection->set); } } diff --git a/src/optimizer/join_order/query_graph_manager.cpp b/src/optimizer/join_order/query_graph_manager.cpp index b8835265b843..13b7807b8964 100644 --- a/src/optimizer/join_order/query_graph_manager.cpp +++ b/src/optimizer/join_order/query_graph_manager.cpp @@ -116,7 +116,7 @@ static unique_ptr ExtractJoinRelation(unique_ptrop) { // found it! take ownership of it from the parent auto result = std::move(children[i]); - children.erase(children.begin() + i); + children.erase_at(i); return result; } } diff --git a/src/optimizer/pushdown/pushdown_aggregate.cpp b/src/optimizer/pushdown/pushdown_aggregate.cpp index 396980d54361..47a2b24e9942 100644 --- a/src/optimizer/pushdown/pushdown_aggregate.cpp +++ b/src/optimizer/pushdown/pushdown_aggregate.cpp @@ -87,7 +87,7 @@ unique_ptr FilterPushdown::PushdownAggregate(unique_ptr(std::move(op)); } // erase the filter from here - filters.erase(filters.begin() + i); + filters.erase_at(i); i--; } child_pushdown.GenerateFilters(); diff --git a/src/optimizer/pushdown/pushdown_left_join.cpp b/src/optimizer/pushdown/pushdown_left_join.cpp index 47cfdfd6b639..b4ebd2ef27e3 100644 --- a/src/optimizer/pushdown/pushdown_left_join.cpp +++ b/src/optimizer/pushdown/pushdown_left_join.cpp @@ -91,7 +91,7 @@ unique_ptr FilterPushdown::PushdownLeftJoin(unique_ptr FilterPushdown::PushdownMarkJoin(unique_ptr FilterPushdown::PushdownMarkJoin(unique_ptr FilterPushdown::PushdownMarkJoin(unique_ptr FilterPushdown::PushdownSingleJoin(unique_ptr &list, idx_t table_id auto entry = column_references.find(current_binding); if (entry == column_references.end()) { // this entry is not referred to, erase it from the set of expressions - list.erase(list.begin() + col_idx); + list.erase_at(col_idx); offset++; col_idx--; } else if (offset > 0 && replace) { diff --git a/src/optimizer/rule/arithmetic_simplification.cpp b/src/optimizer/rule/arithmetic_simplification.cpp index b319c60d5713..bd4e0821e2fa 100644 --- a/src/optimizer/rule/arithmetic_simplification.cpp +++ b/src/optimizer/rule/arithmetic_simplification.cpp @@ -26,7 +26,7 @@ unique_ptr ArithmeticSimplificationRule::Apply(LogicalOperator &op, bool &changes_made, bool is_root) { auto &root = bindings[0].get().Cast(); auto &constant = bindings[1].get().Cast(); - int constant_child = root.children[0].get() == &constant ? 0 : 1; + idx_t constant_child = root.children[0].get() == &constant ? 
0 : 1; D_ASSERT(root.children.size() == 2); (void)root; // any arithmetic operator involving NULL is always NULL diff --git a/src/optimizer/rule/case_simplification.cpp b/src/optimizer/rule/case_simplification.cpp index 61c6ed353829..2dbbecef56a7 100644 --- a/src/optimizer/rule/case_simplification.cpp +++ b/src/optimizer/rule/case_simplification.cpp @@ -25,14 +25,14 @@ unique_ptr CaseSimplificationRule::Apply(LogicalOperator &op, vector auto condition = constant_value.DefaultCastAs(LogicalType::BOOLEAN); if (condition.IsNull() || !BooleanValue::Get(condition)) { // the condition is always false: remove this case check - root.case_checks.erase(root.case_checks.begin() + i); + root.case_checks.erase_at(i); i--; } else { // the condition is always true // move the THEN clause to the ELSE of the case root.else_expr = std::move(case_check.then_expr); // remove this case check and any case checks after this one - root.case_checks.erase(root.case_checks.begin() + i, root.case_checks.end()); + root.case_checks.erase(root.case_checks.begin() + NumericCast(i), root.case_checks.end()); break; } } diff --git a/src/optimizer/rule/conjunction_simplification.cpp b/src/optimizer/rule/conjunction_simplification.cpp index 070237cb724c..646471b9412e 100644 --- a/src/optimizer/rule/conjunction_simplification.cpp +++ b/src/optimizer/rule/conjunction_simplification.cpp @@ -19,7 +19,7 @@ unique_ptr ConjunctionSimplificationRule::RemoveExpression(BoundConj for (idx_t i = 0; i < conj.children.size(); i++) { if (conj.children[i].get() == &expr) { // erase the expression - conj.children.erase(conj.children.begin() + i); + conj.children.erase_at(i); break; } } diff --git a/src/optimizer/rule/distributivity.cpp b/src/optimizer/rule/distributivity.cpp index 509960c03b0d..b6a889af416d 100644 --- a/src/optimizer/rule/distributivity.cpp +++ b/src/optimizer/rule/distributivity.cpp @@ -35,7 +35,7 @@ unique_ptr DistributivityRule::ExtractExpression(BoundConjunctionExp for (idx_t i = 0; i < and_expr.children.size(); i++) { if (and_expr.children[i]->Equals(expr)) { result = std::move(and_expr.children[i]); - and_expr.children.erase(and_expr.children.begin() + i); + and_expr.children.erase_at(i); break; } } diff --git a/src/optimizer/statistics/expression/propagate_conjunction.cpp b/src/optimizer/statistics/expression/propagate_conjunction.cpp index 1fce16c8cf51..4c69f8a15aa9 100644 --- a/src/optimizer/statistics/expression/propagate_conjunction.cpp +++ b/src/optimizer/statistics/expression/propagate_conjunction.cpp @@ -46,7 +46,7 @@ unique_ptr StatisticsPropagator::PropagateExpression(BoundConjun } } if (prune_child) { - expr.children.erase(expr.children.begin() + expr_idx); + expr.children.erase_at(expr_idx); expr_idx--; continue; } diff --git a/src/optimizer/statistics/expression/propagate_operator.cpp b/src/optimizer/statistics/expression/propagate_operator.cpp index 4bb1f8fd7a32..b48593cd8e33 100644 --- a/src/optimizer/statistics/expression/propagate_operator.cpp +++ b/src/optimizer/statistics/expression/propagate_operator.cpp @@ -28,8 +28,8 @@ unique_ptr StatisticsPropagator::PropagateExpression(BoundOperat // this child is always NULL, we can remove it from the coalesce // UNLESS there is only one node remaining if (expr.children.size() > 1) { - expr.children.erase(expr.children.begin() + i); - child_stats.erase(child_stats.begin() + i); + expr.children.erase_at(i); + child_stats.erase_at(i); i--; } } else if (!child_stats[i]->CanHaveNull()) { @@ -37,8 +37,8 @@ unique_ptr 
StatisticsPropagator::PropagateExpression(BoundOperat // this is the last coalesce node that influences the result // we can erase any children after this node if (i + 1 < expr.children.size()) { - expr.children.erase(expr.children.begin() + i + 1, expr.children.end()); - child_stats.erase(child_stats.begin() + i + 1, child_stats.end()); + expr.children.erase(expr.children.begin() + NumericCast(i + 1), expr.children.end()); + child_stats.erase(child_stats.begin() + NumericCast(i + 1), child_stats.end()); } break; } diff --git a/src/optimizer/statistics/operator/propagate_filter.cpp b/src/optimizer/statistics/operator/propagate_filter.cpp index 97609fa62888..cff64ebca12e 100644 --- a/src/optimizer/statistics/operator/propagate_filter.cpp +++ b/src/optimizer/statistics/operator/propagate_filter.cpp @@ -227,7 +227,7 @@ unique_ptr StatisticsPropagator::PropagateStatistics(LogicalFilt node_stats = PropagateStatistics(filter.children[0]); if (filter.children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT) { ReplaceWithEmptyResult(*node_ptr); - return make_uniq(0, 0); + return make_uniq(0U, 0U); } // then propagate to each of the expressions @@ -238,7 +238,7 @@ unique_ptr StatisticsPropagator::PropagateStatistics(LogicalFilt if (ExpressionIsConstant(*condition, Value::BOOLEAN(true))) { // filter is always true; it is useless to execute it // erase this condition - filter.expressions.erase(filter.expressions.begin() + i); + filter.expressions.erase_at(i); i--; if (filter.expressions.empty()) { // just break. The physical filter planner will plan a projection instead @@ -248,7 +248,7 @@ unique_ptr StatisticsPropagator::PropagateStatistics(LogicalFilt ExpressionIsConstantOrNull(*condition, Value::BOOLEAN(false))) { // filter is always false or null; this entire filter should be replaced by an empty result block ReplaceWithEmptyResult(*node_ptr); - return make_uniq(0, 0); + return make_uniq(0U, 0U); } else { // cannot prune this filter: propagate statistics from the filter UpdateFilterStatistics(*condition); diff --git a/src/optimizer/statistics/operator/propagate_get.cpp b/src/optimizer/statistics/operator/propagate_get.cpp index 22979f19450c..48b41c1e0770 100644 --- a/src/optimizer/statistics/operator/propagate_get.cpp +++ b/src/optimizer/statistics/operator/propagate_get.cpp @@ -86,7 +86,7 @@ unique_ptr StatisticsPropagator::PropagateStatistics(LogicalGet case FilterPropagateResult::FILTER_ALWAYS_FALSE: // filter is always false; this entire filter should be replaced by an empty result block ReplaceWithEmptyResult(*node_ptr); - return make_uniq(0, 0); + return make_uniq(0U, 0U); default: // general case: filter can be true or false, update this columns' statistics UpdateFilterStatistics(stats, *filter); diff --git a/src/optimizer/statistics/operator/propagate_join.cpp b/src/optimizer/statistics/operator/propagate_join.cpp index 093a37905270..10fd54a568c2 100644 --- a/src/optimizer/statistics/operator/propagate_join.cpp +++ b/src/optimizer/statistics/operator/propagate_join.cpp @@ -85,7 +85,7 @@ void StatisticsPropagator::PropagateStatistics(LogicalComparisonJoin &join, uniq } if (join.conditions.size() > 1) { // there are multiple conditions: erase this condition - join.conditions.erase(join.conditions.begin() + i); + join.conditions.erase_at(i); // remove the corresponding statistics join.join_stats.clear(); i--; @@ -187,7 +187,8 @@ void StatisticsPropagator::MultiplyCardinalities(unique_ptr &sta return; } stats->estimated_cardinality = MaxValue(stats->estimated_cardinality, 
new_stats.estimated_cardinality); - auto new_max = Hugeint::Multiply(stats->max_cardinality, new_stats.max_cardinality); + auto new_max = Hugeint::Multiply(NumericCast(stats->max_cardinality), + NumericCast(new_stats.max_cardinality)); if (new_max < NumericLimits::Maximum()) { int64_t result; if (!Hugeint::TryCast(new_max, result)) { diff --git a/src/optimizer/statistics/operator/propagate_set_operation.cpp b/src/optimizer/statistics/operator/propagate_set_operation.cpp index 8b90aeb6346a..e3d1b80a2b3d 100644 --- a/src/optimizer/statistics/operator/propagate_set_operation.cpp +++ b/src/optimizer/statistics/operator/propagate_set_operation.cpp @@ -10,7 +10,8 @@ void StatisticsPropagator::AddCardinalities(unique_ptr &stats, N return; } stats->estimated_cardinality += new_stats.estimated_cardinality; - auto new_max = Hugeint::Add(stats->max_cardinality, new_stats.max_cardinality); + auto new_max = + Hugeint::Add(NumericCast(stats->max_cardinality), NumericCast(new_stats.max_cardinality)); if (new_max < NumericLimits::Maximum()) { int64_t result; if (!Hugeint::TryCast(new_max, result)) { diff --git a/src/optimizer/topn_optimizer.cpp b/src/optimizer/topn_optimizer.cpp index ca64e5700704..26f6ca99506e 100644 --- a/src/optimizer/topn_optimizer.cpp +++ b/src/optimizer/topn_optimizer.cpp @@ -32,7 +32,7 @@ unique_ptr TopN::Optimize(unique_ptr op) { auto limit_val = int64_t(limit.limit_val.GetConstantValue()); int64_t offset_val = 0; if (limit.offset_val.Type() == LimitNodeType::CONSTANT_VALUE) { - offset_val = limit.offset_val.GetConstantValue(); + offset_val = NumericCast(limit.offset_val.GetConstantValue()); } auto topn = make_uniq(std::move(order_by.orders), limit_val, offset_val); topn->AddChild(std::move(order_by.children[0])); diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index 8274c521d71f..7dd7a43d9f65 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -181,7 +181,7 @@ void Parser::ParseQuery(const string &query) { } else { parser_error = parser.error_message; if (parser.error_location > 0) { - parser_error_location = parser.error_location - 1; + parser_error_location = NumericCast(parser.error_location - 1); } } } @@ -196,7 +196,7 @@ void Parser::ParseQuery(const string &query) { } else { // split sql string into statements and re-parse using extension auto query_statements = SplitQueryStringIntoStatements(query); - auto stmt_loc = 0; + idx_t stmt_loc = 0; for (auto const &query_statement : query_statements) { ErrorData another_parser_error; // Creating a new scope to allow extensions to use PostgresParser, which is not reentrant @@ -219,7 +219,8 @@ void Parser::ParseQuery(const string &query) { } else { another_parser_error = ErrorData(another_parser.error_message); if (another_parser.error_location > 0) { - another_parser_error.AddQueryLocation(another_parser.error_location - 1); + another_parser_error.AddQueryLocation( + NumericCast(another_parser.error_location - 1)); } } } // LCOV_EXCL_STOP @@ -292,7 +293,7 @@ vector Parser::Tokenize(const string &query) { default: throw InternalException("Unrecognized token category"); } // LCOV_EXCL_STOP - token.start = pg_token.start; + token.start = NumericCast(pg_token.start); result.push_back(token); } return result; diff --git a/src/parser/transform/expression/transform_boolean_test.cpp b/src/parser/transform/expression/transform_boolean_test.cpp index ec9dffb794a5..3c96f4dab16f 100644 --- a/src/parser/transform/expression/transform_boolean_test.cpp +++ b/src/parser/transform/expression/transform_boolean_test.cpp @@ 
-20,9 +20,9 @@ static unique_ptr TransformBooleanTestInternal(unique_ptr TransformBooleanTestIsNull(unique_ptr argument, - ExpressionType operator_type, idx_t query_location) { + ExpressionType operator_type, int query_location) { auto result = make_uniq(operator_type, std::move(argument)); - Transformer::SetQueryLocation(*result, UnsafeNumericCast(query_location)); + Transformer::SetQueryLocation(*result, query_location); return std::move(result); } diff --git a/src/parser/transform/expression/transform_cast.cpp b/src/parser/transform/expression/transform_cast.cpp index c10c81d76404..a4b1dde59bbe 100644 --- a/src/parser/transform/expression/transform_cast.cpp +++ b/src/parser/transform/expression/transform_cast.cpp @@ -18,7 +18,7 @@ unique_ptr Transformer::TransformTypeCast(duckdb_libpgquery::P if (c->val.type == duckdb_libpgquery::T_PGString) { CastParameters parameters; if (root.location >= 0) { - parameters.query_location = root.location; + parameters.query_location = NumericCast(root.location); } auto blob_data = Blob::ToBlob(string(c->val.val.str), parameters); return make_uniq(Value::BLOB_RAW(blob_data)); diff --git a/src/parser/transform/expression/transform_param_ref.cpp b/src/parser/transform/expression/transform_param_ref.cpp index d5d7931fe3e0..2d4ba5155d18 100644 --- a/src/parser/transform/expression/transform_param_ref.cpp +++ b/src/parser/transform/expression/transform_param_ref.cpp @@ -40,7 +40,7 @@ unique_ptr Transformer::TransformParamRef(duckdb_libpgquery::P // We have not seen this parameter before if (node.number != 0) { // Preserve the parameter number - known_param_index = node.number; + known_param_index = NumericCast(node.number); } else { known_param_index = ParamCount() + 1; if (!node.name) { diff --git a/src/parser/transform/expression/transform_positional_reference.cpp b/src/parser/transform/expression/transform_positional_reference.cpp index efe7d3da3321..8d672948db3f 100644 --- a/src/parser/transform/expression/transform_positional_reference.cpp +++ b/src/parser/transform/expression/transform_positional_reference.cpp @@ -8,7 +8,7 @@ unique_ptr Transformer::TransformPositionalReference(duckdb_li if (node.position <= 0) { throw ParserException("Positional reference node needs to be >= 1"); } - auto result = make_uniq(node.position); + auto result = make_uniq(NumericCast(node.position)); SetQueryLocation(*result, node.location); return std::move(result); } diff --git a/src/parser/transform/helpers/transform_typename.cpp b/src/parser/transform/helpers/transform_typename.cpp index ff40554f6084..edea1222fba0 100644 --- a/src/parser/transform/helpers/transform_typename.cpp +++ b/src/parser/transform/helpers/transform_typename.cpp @@ -41,7 +41,7 @@ LogicalType Transformer::TransformTypeName(duckdb_libpgquery::PGTypeName &type_n if (!type_name.typmods || type_name.typmods->length == 0) { throw ParserException("Enum needs a set of entries"); } - Vector enum_vector(LogicalType::VARCHAR, type_name.typmods->length); + Vector enum_vector(LogicalType::VARCHAR, NumericCast(type_name.typmods->length)); auto string_data = FlatVector::GetData(enum_vector); idx_t pos = 0; for (auto node = type_name.typmods->head; node; node = node->next) { @@ -52,7 +52,7 @@ LogicalType Transformer::TransformTypeName(duckdb_libpgquery::PGTypeName &type_n } string_data[pos++] = StringVector::AddString(enum_vector, constant_value->val.val.str); } - return LogicalType::ENUM(enum_vector, type_name.typmods->length); + return LogicalType::ENUM(enum_vector, NumericCast(type_name.typmods->length)); } 
else if (base_type == LogicalTypeId::STRUCT) { if (!type_name.typmods || type_name.typmods->length == 0) { throw ParserException("Struct needs a name and entries"); @@ -154,12 +154,12 @@ LogicalType Transformer::TransformTypeName(duckdb_libpgquery::PGTypeName &type_n throw ParserException("Negative modifier not supported"); } if (modifier_idx == 0) { - width = const_val.val.val.ival; + width = NumericCast(const_val.val.val.ival); if (base_type == LogicalTypeId::BIT && const_val.location != -1) { width = 0; } } else if (modifier_idx == 1) { - scale = const_val.val.val.ival; + scale = NumericCast(const_val.val.val.ival); } else { throw ParserException("A maximum of two modifiers is supported"); } @@ -245,7 +245,7 @@ LogicalType Transformer::TransformTypeName(duckdb_libpgquery::PGTypeName &type_n if (val->type != duckdb_libpgquery::T_PGInteger) { throw ParserException("Expected integer value as array bound"); } - auto array_size = val->val.ival; + auto array_size = NumericCast(val->val.ival); if (array_size < 0) { // -1 if bounds are empty result_type = LogicalType::LIST(result_type); diff --git a/src/parser/transformer.cpp b/src/parser/transformer.cpp index 8c54c46a3a7f..f12e0d17882b 100644 --- a/src/parser/transformer.cpp +++ b/src/parser/transformer.cpp @@ -134,8 +134,8 @@ unique_ptr Transformer::TransformStatementInternal(duckdb_libpgque auto &raw_stmt = PGCast(stmt); auto result = TransformStatement(*raw_stmt.stmt); if (result) { - result->stmt_location = raw_stmt.stmt_location; - result->stmt_length = raw_stmt.stmt_len; + result->stmt_location = NumericCast(raw_stmt.stmt_location); + result->stmt_length = NumericCast(raw_stmt.stmt_len); } return result; } diff --git a/src/planner/binder/statement/bind_copy_database.cpp b/src/planner/binder/statement/bind_copy_database.cpp index 8a0047da3263..bccb584b5625 100644 --- a/src/planner/binder/statement/bind_copy_database.cpp +++ b/src/planner/binder/statement/bind_copy_database.cpp @@ -135,7 +135,7 @@ unique_ptr Binder::BindCopyDatabaseData(CopyDatabaseStatement & if (result) { // use UNION ALL to combine the individual copy statements into a single node auto copy_union = - make_uniq(GenerateTableIndex(), 1, std::move(insert_plan), std::move(result), + make_uniq(GenerateTableIndex(), 1U, std::move(insert_plan), std::move(result), LogicalOperatorType::LOGICAL_UNION, true, false); result = std::move(copy_union); } else { diff --git a/src/planner/binder/statement/bind_export.cpp b/src/planner/binder/statement/bind_export.cpp index f5370a20a915..b22681ed18df 100644 --- a/src/planner/binder/statement/bind_export.cpp +++ b/src/planner/binder/statement/bind_export.cpp @@ -331,7 +331,7 @@ BoundStatement Binder::Bind(ExportStatement &stmt) { if (child_operator) { // use UNION ALL to combine the individual copy statements into a single node - auto copy_union = make_uniq(GenerateTableIndex(), 1, std::move(child_operator), + auto copy_union = make_uniq(GenerateTableIndex(), 1U, std::move(child_operator), std::move(plan), LogicalOperatorType::LOGICAL_UNION, true); child_operator = std::move(copy_union); } else { diff --git a/src/planner/binder/statement/bind_insert.cpp b/src/planner/binder/statement/bind_insert.cpp index d0827fde2ebb..f9e8e6e2b578 100644 --- a/src/planner/binder/statement/bind_insert.cpp +++ b/src/planner/binder/statement/bind_insert.cpp @@ -30,7 +30,7 @@ namespace duckdb { -static void CheckInsertColumnCountMismatch(int64_t expected_columns, int64_t result_columns, bool columns_provided, +static void CheckInsertColumnCountMismatch(idx_t 
expected_columns, idx_t result_columns, bool columns_provided, const char *tname) { if (result_columns != expected_columns) { string msg = StringUtil::Format(!columns_provided ? "table %s has %lld columns but %lld values were supplied" diff --git a/src/planner/binder/tableref/bind_pivot.cpp b/src/planner/binder/tableref/bind_pivot.cpp index 6c89b7b2cd91..4d9cf553097b 100644 --- a/src/planner/binder/tableref/bind_pivot.cpp +++ b/src/planner/binder/tableref/bind_pivot.cpp @@ -89,7 +89,7 @@ static unique_ptr ConstructInitialGrouping(PivotRef &ref, vectorgroups.group_expressions.push_back(make_uniq( - Value::INTEGER(UnsafeNumericCast(subquery->select_list.size() + 1)))); + Value::INTEGER(UnsafeNumericCast(subquery->select_list.size() + 1)))); subquery->select_list.push_back(make_uniq(row)); } } @@ -166,7 +166,7 @@ static unique_ptr PivotInitialAggregate(PivotBindState &bind_state, } auto pivot_alias = pivot_expr->alias; subquery_stage1->groups.group_expressions.push_back(make_uniq( - Value::INTEGER(UnsafeNumericCast(subquery_stage1->select_list.size() + 1)))); + Value::INTEGER(UnsafeNumericCast(subquery_stage1->select_list.size() + 1)))); subquery_stage1->select_list.push_back(std::move(pivot_expr)); pivot_expr = make_uniq(std::move(pivot_alias)); } @@ -203,7 +203,7 @@ static unique_ptr PivotListAggregate(PivotBindState &bind_state, Piv // add all of the groups for (idx_t gr = 0; gr < bind_state.internal_group_names.size(); gr++) { subquery_stage2->groups.group_expressions.push_back(make_uniq( - Value::INTEGER(UnsafeNumericCast(subquery_stage2->select_list.size() + 1)))); + Value::INTEGER(UnsafeNumericCast(subquery_stage2->select_list.size() + 1)))); auto group_reference = make_uniq(bind_state.internal_group_names[gr]); group_reference->alias = bind_state.internal_group_names[gr]; subquery_stage2->select_list.push_back(std::move(group_reference)); diff --git a/src/planner/bound_result_modifier.cpp b/src/planner/bound_result_modifier.cpp index 9a1728ed71d0..a4cb6ed4fd23 100644 --- a/src/planner/bound_result_modifier.cpp +++ b/src/planner/bound_result_modifier.cpp @@ -126,7 +126,7 @@ BoundLimitNode::BoundLimitNode() : type(LimitNodeType::UNSET) { } BoundLimitNode::BoundLimitNode(int64_t constant_value) - : type(LimitNodeType::CONSTANT_VALUE), constant_integer(constant_value) { + : type(LimitNodeType::CONSTANT_VALUE), constant_integer(NumericCast(constant_value)) { } BoundLimitNode::BoundLimitNode(double percentage_value) diff --git a/src/planner/operator/logical_top_n.cpp b/src/planner/operator/logical_top_n.cpp index da1fa493f4ec..adf4019e84b0 100644 --- a/src/planner/operator/logical_top_n.cpp +++ b/src/planner/operator/logical_top_n.cpp @@ -5,7 +5,7 @@ namespace duckdb { idx_t LogicalTopN::EstimateCardinality(ClientContext &context) { auto child_cardinality = LogicalOperator::EstimateCardinality(context); if (limit >= 0 && child_cardinality < idx_t(limit)) { - return limit; + return NumericCast(limit); } return child_cardinality; } diff --git a/src/planner/table_binding.cpp b/src/planner/table_binding.cpp index 1709a4524f72..dff3de5952f5 100644 --- a/src/planner/table_binding.cpp +++ b/src/planner/table_binding.cpp @@ -171,7 +171,7 @@ ColumnBinding TableBinding::GetColumnBinding(column_t column_index) { auto it = std::find_if(column_ids.begin(), column_ids.end(), [&](const column_t &id) -> bool { return id == column_index; }); // Get the index of it - binding.column_index = std::distance(column_ids.begin(), it); + binding.column_index = NumericCast(std::distance(column_ids.begin(), it)); // 
If it wasn't found, add it if (it == column_ids.end()) { column_ids.push_back(column_index); diff --git a/src/storage/data_table.cpp b/src/storage/data_table.cpp index 6541bf193434..977ff7387fd8 100644 --- a/src/storage/data_table.cpp +++ b/src/storage/data_table.cpp @@ -105,7 +105,7 @@ DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t removed_co // erase the column definitions from this DataTable D_ASSERT(removed_column < column_definitions.size()); - column_definitions.erase(column_definitions.begin() + removed_column); + column_definitions.erase_at(removed_column); storage_t storage_idx = 0; for (idx_t i = 0; i < column_definitions.size(); i++) { diff --git a/src/storage/local_storage.cpp b/src/storage/local_storage.cpp index 2c7fb0fe1d79..513a18e2faf1 100644 --- a/src/storage/local_storage.cpp +++ b/src/storage/local_storage.cpp @@ -219,7 +219,7 @@ void LocalTableStorage::FinalizeOptimisticWriter(OptimisticDataWriter &writer) { for (idx_t i = 0; i < optimistic_writers.size(); i++) { if (optimistic_writers[i].get() == &writer) { owned_writer = std::move(optimistic_writers[i]); - optimistic_writers.erase(optimistic_writers.begin() + i); + optimistic_writers.erase_at(i); break; } } diff --git a/src/storage/table/row_group_collection.cpp b/src/storage/table/row_group_collection.cpp index 00e42d5b7b7f..5fae55491020 100644 --- a/src/storage/table/row_group_collection.cpp +++ b/src/storage/table/row_group_collection.cpp @@ -1056,7 +1056,7 @@ shared_ptr RowGroupCollection::AddColumn(ClientContext &cont shared_ptr RowGroupCollection::RemoveColumn(idx_t col_idx) { D_ASSERT(col_idx < types.size()); auto new_types = types; - new_types.erase(new_types.begin() + col_idx); + new_types.erase_at(col_idx); auto result = make_shared(info, block_manager, std::move(new_types), row_start, total_rows.load()); diff --git a/src/storage/table_index_list.cpp b/src/storage/table_index_list.cpp index ef14073f302b..623585baf6e1 100644 --- a/src/storage/table_index_list.cpp +++ b/src/storage/table_index_list.cpp @@ -21,7 +21,7 @@ void TableIndexList::RemoveIndex(const string &name) { for (idx_t index_idx = 0; index_idx < indexes.size(); index_idx++) { auto &index_entry = indexes[index_idx]; if (index_entry->name == name) { - indexes.erase(indexes.begin() + index_idx); + indexes.erase_at(index_idx); break; } } diff --git a/src/transaction/duck_transaction_manager.cpp b/src/transaction/duck_transaction_manager.cpp index 008604c3502c..bf7014c93411 100644 --- a/src/transaction/duck_transaction_manager.cpp +++ b/src/transaction/duck_transaction_manager.cpp @@ -277,7 +277,7 @@ void DuckTransactionManager::RemoveTransaction(DuckTransaction &transaction) noe } } // remove the transaction from the set of currently active transactions - active_transactions.erase(active_transactions.begin() + t_index); + active_transactions.erase_at(t_index); // traverse the recently_committed transactions to see if we can remove any idx_t i = 0; for (; i < recently_committed_transactions.size(); i++) { diff --git a/src/transaction/meta_transaction.cpp b/src/transaction/meta_transaction.cpp index 7cd7fa450b81..6cf12cca5d27 100644 --- a/src/transaction/meta_transaction.cpp +++ b/src/transaction/meta_transaction.cpp @@ -63,7 +63,7 @@ void MetaTransaction::RemoveTransaction(AttachedDatabase &db) { for (idx_t i = 0; i < all_transactions.size(); i++) { auto &db_entry = all_transactions[i]; if (RefersToSameObject(db_entry.get(), db)) { - all_transactions.erase(all_transactions.begin() + i); + all_transactions.erase_at(i); break; 
} } diff --git a/third_party/utf8proc/include/utf8proc.hpp b/third_party/utf8proc/include/utf8proc.hpp index a7cc6eef73eb..336f95e2ed5a 100644 --- a/third_party/utf8proc/include/utf8proc.hpp +++ b/third_party/utf8proc/include/utf8proc.hpp @@ -637,7 +637,7 @@ void utf8proc_grapheme_callback(const char *s, size_t len, T &&fun) { size_t start = 0; size_t cpos = 0; while(true) { - cpos += sz; + cpos += UnsafeNumericCast(sz); if (cpos >= len) { fun(start, cpos); return; From dc16f563f8718be445c228ceb8900472471529c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Fri, 22 Mar 2024 17:10:06 +0100 Subject: [PATCH 030/201] mc --- src/core_functions/aggregate/distributive/approx_count.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core_functions/aggregate/distributive/approx_count.cpp b/src/core_functions/aggregate/distributive/approx_count.cpp index 844e31acfd9d..6b8dd0d7c667 100644 --- a/src/core_functions/aggregate/distributive/approx_count.cpp +++ b/src/core_functions/aggregate/distributive/approx_count.cpp @@ -43,7 +43,7 @@ struct ApproxCountDistinctFunction { template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { if (state.log) { - target = state.log->Count(); + target = UnsafeNumericCast(state.log->Count()); } else { target = 0; } From 5774dc67dad81f4ee5c684730f5fa5b8dee03258 Mon Sep 17 00:00:00 2001 From: Tishj Date: Sat, 23 Mar 2024 20:04:43 +0100 Subject: [PATCH 031/201] delay the available disk space lookup until we have made sure the directory exists --- src/common/file_system.cpp | 7 ++++--- .../duckdb/storage/temporary_file_manager.hpp | 2 +- src/storage/temporary_file_manager.cpp | 12 ++++-------- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/common/file_system.cpp b/src/common/file_system.cpp index c58bb588ee07..11b24c75ae46 100644 --- a/src/common/file_system.cpp +++ b/src/common/file_system.cpp @@ -103,12 +103,13 @@ optional_idx FileSystem::GetAvailableMemory() { optional_idx FileSystem::GetAvailableDiskSpace(const string &path) { struct statvfs vfs; - if (statvfs(path.c_str(), &vfs) == -1) { - optional_idx(); + auto ret = statvfs(path.c_str(), &vfs); + if (ret == -1) { + return optional_idx(); } auto block_size = vfs.f_frsize; // These are the blocks available for creating new files or extending existing ones - auto available_blocks = vfs.f_bavail; + auto available_blocks = vfs.f_bfree; idx_t available_disk_space = DConstants::INVALID_INDEX; if (!TryMultiplyOperator::Operation(static_cast(block_size), static_cast(available_blocks), available_disk_space)) { diff --git a/src/include/duckdb/storage/temporary_file_manager.hpp b/src/include/duckdb/storage/temporary_file_manager.hpp index f6a8a17e81a0..92fcbf8eaba5 100644 --- a/src/include/duckdb/storage/temporary_file_manager.hpp +++ b/src/include/duckdb/storage/temporary_file_manager.hpp @@ -145,7 +145,7 @@ class TemporaryDirectoryHandle { class TemporaryFileManager { public: - TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p, optional_idx max_swap_space); + TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p); ~TemporaryFileManager(); public:
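Two details in the file_system.cpp hunk above are easy to miss: the old error path constructed an optional_idx but never returned it, and the reported byte count is the product of f_frsize and the free block count, which has to be guarded against overflow. A rough standalone equivalent of the probe, assuming POSIX statvfs and using std::optional in place of DuckDB's optional_idx:

#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <sys/statvfs.h>

// Free space (in bytes) of the filesystem containing 'path', or nullopt on failure.
static std::optional<uint64_t> AvailableDiskSpace(const std::string &path) {
	struct statvfs vfs;
	if (statvfs(path.c_str(), &vfs) == -1) {
		return std::nullopt; // the missing 'return' on this path was the bug fixed above
	}
	auto block_size = static_cast<uint64_t>(vfs.f_frsize);
	auto free_blocks = static_cast<uint64_t>(vfs.f_bfree);
	// Refuse to report a value if the multiplication would overflow,
	// mirroring the TryMultiplyOperator check.
	if (free_blocks != 0 && block_size > std::numeric_limits<uint64_t>::max() / free_blocks) {
		return std::nullopt;
	}
	return block_size * free_blocks;
}

Note that the switch from f_bavail to f_bfree widens the estimate: f_bavail counts blocks available to unprivileged processes, while f_bfree also includes blocks reserved for the superuser, even though the surrounding comment still describes the former.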
//===--------------------------------------------------------------------===// TemporaryDirectoryHandle::TemporaryDirectoryHandle(DatabaseInstance &db, string path_p, optional_idx max_swap_space) - : db(db), temp_directory(std::move(path_p)), - temp_file(make_uniq(db, temp_directory, max_swap_space)) { + : db(db), temp_directory(std::move(path_p)), temp_file(make_uniq(db, temp_directory)) { auto &fs = FileSystem::GetFileSystem(db); if (!temp_directory.empty()) { if (!fs.DirectoryExists(temp_directory)) { @@ -197,6 +196,7 @@ TemporaryDirectoryHandle::TemporaryDirectoryHandle(DatabaseInstance &db, string created_directory = true; } } + temp_file->SetMaxSwapSpace(max_swap_space); } TemporaryDirectoryHandle::~TemporaryDirectoryHandle() { @@ -263,12 +263,8 @@ static idx_t GetDefaultMax(const string &path) { return default_value; } -TemporaryFileManager::TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p, - optional_idx max_swap_space) - : db(db), temp_directory(temp_directory_p), size_on_disk(0), max_swap_space(GetDefaultMax(temp_directory_p)) { - if (max_swap_space.IsValid()) { - this->max_swap_space = max_swap_space.GetIndex(); - } +TemporaryFileManager::TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p) + : db(db), temp_directory(temp_directory_p), size_on_disk(0), max_swap_space(0) { } TemporaryFileManager::~TemporaryFileManager() { From 235d4c0c4f6dd7400594b19add0f735105f2e246 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 26 Mar 2024 15:19:12 +0100 Subject: [PATCH 032/201] more implicit conversions --- .../aggregate/distributive/bitagg.cpp | 13 ++++++++----- .../aggregate/distributive/bitstring_agg.cpp | 4 ++-- src/core_functions/aggregate/distributive/sum.cpp | 4 ++-- src/core_functions/aggregate/holistic/quantile.cpp | 6 ++++-- .../core_functions/aggregate/sum_helpers.hpp | 4 ++-- src/include/duckdb/execution/merge_sort_tree.hpp | 14 +++++++++----- 6 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/core_functions/aggregate/distributive/bitagg.cpp b/src/core_functions/aggregate/distributive/bitagg.cpp index 3943707fd8db..2d57a4f548cc 100644 --- a/src/core_functions/aggregate/distributive/bitagg.cpp +++ b/src/core_functions/aggregate/distributive/bitagg.cpp @@ -10,6 +10,7 @@ namespace duckdb { template struct BitState { + using TYPE = T; bool is_set; T value; }; @@ -67,7 +68,7 @@ struct BitwiseOperation { template static void Assign(STATE &state, INPUT_TYPE input) { - state.value = input; + state.value = typename STATE::TYPE(input); } template @@ -90,7 +91,7 @@ struct BitwiseOperation { if (!state.is_set) { finalize_data.ReturnNull(); } else { - target = state.value; + target = T(state.value); } } @@ -102,21 +103,23 @@ struct BitwiseOperation { struct BitAndOperation : public BitwiseOperation { template static void Execute(STATE &state, INPUT_TYPE input) { - state.value &= input; + state.value &= typename STATE::TYPE(input); + ; } }; struct BitOrOperation : public BitwiseOperation { template static void Execute(STATE &state, INPUT_TYPE input) { - state.value |= input; + state.value |= typename STATE::TYPE(input); + ; } }; struct BitXorOperation : public BitwiseOperation { template static void Execute(STATE &state, INPUT_TYPE input) { - state.value ^= input; + state.value ^= typename STATE::TYPE(input); } template diff --git a/src/core_functions/aggregate/distributive/bitstring_agg.cpp b/src/core_functions/aggregate/distributive/bitstring_agg.cpp index 700b0cce28f8..36920a476439 100644 --- 
a/src/core_functions/aggregate/distributive/bitstring_agg.cpp +++ b/src/core_functions/aggregate/distributive/bitstring_agg.cpp @@ -107,7 +107,7 @@ struct BitStringAggOperation { if (!TrySubtractOperator::Operation(max, min, result)) { return NumericLimits::Maximum(); } - idx_t val(result); + auto val = NumericCast(result); if (val == NumericLimits::Maximum()) { return val; } @@ -116,7 +116,7 @@ struct BitStringAggOperation { template static void Execute(STATE &state, INPUT_TYPE input, INPUT_TYPE min) { - Bit::SetBit(state.value, input - min, 1); + Bit::SetBit(state.value, UnsafeNumericCast(input - min), 1); } template diff --git a/src/core_functions/aggregate/distributive/sum.cpp b/src/core_functions/aggregate/distributive/sum.cpp index 9f243869ad16..2858d5552009 100644 --- a/src/core_functions/aggregate/distributive/sum.cpp +++ b/src/core_functions/aggregate/distributive/sum.cpp @@ -112,8 +112,8 @@ unique_ptr SumPropagateStats(ClientContext &context, BoundAggreg default: throw InternalException("Unsupported type for propagate sum stats"); } - auto max_sum_negative = max_negative * hugeint_t(input.node_stats->max_cardinality); - auto max_sum_positive = max_positive * hugeint_t(input.node_stats->max_cardinality); + auto max_sum_negative = max_negative * Hugeint::Convert(input.node_stats->max_cardinality); + auto max_sum_positive = max_positive * Hugeint::Convert(input.node_stats->max_cardinality); if (max_sum_positive >= NumericLimits::Maximum() || max_sum_negative <= NumericLimits::Minimum()) { // sum can potentially exceed int64_t bounds: use hugeint sum diff --git a/src/core_functions/aggregate/holistic/quantile.cpp b/src/core_functions/aggregate/holistic/quantile.cpp index 7a7693a6aa27..932d30f626a6 100644 --- a/src/core_functions/aggregate/holistic/quantile.cpp +++ b/src/core_functions/aggregate/holistic/quantile.cpp @@ -356,8 +356,10 @@ struct Interpolator { // Integer arithmetic for accuracy const auto integral = q.integral; const auto scaling = q.scaling; - const auto scaled_q = DecimalMultiplyOverflowCheck::Operation(n, integral); - const auto scaled_n = DecimalMultiplyOverflowCheck::Operation(n, scaling); + const auto scaled_q = + DecimalMultiplyOverflowCheck::Operation(Hugeint::Convert(n), integral); + const auto scaled_n = + DecimalMultiplyOverflowCheck::Operation(Hugeint::Convert(n), scaling); floored = Cast::Operation((scaled_n - scaled_q) / scaling); break; } diff --git a/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp b/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp index 45f533a7f8c4..fceb25635871 100644 --- a/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp +++ b/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp @@ -61,7 +61,7 @@ struct RegularAdd { template static void AddConstant(STATE &state, T input, idx_t count) { - state.value += input * count; + state.value += input * int64_t(count); } }; @@ -123,7 +123,7 @@ struct HugeintAdd { AddValue(state.value, uint64_t(input), input >= 0); } } else { - hugeint_t addition = hugeint_t(input) * count; + hugeint_t addition = hugeint_t(input) * Hugeint::Convert(count); state.value += addition; } } diff --git a/src/include/duckdb/execution/merge_sort_tree.hpp b/src/include/duckdb/execution/merge_sort_tree.hpp index b01d7087f058..e2b19b433d24 100644 --- a/src/include/duckdb/execution/merge_sort_tree.hpp +++ b/src/include/duckdb/execution/merge_sort_tree.hpp @@ -349,7 +349,7 @@ idx_t MergeSortTree::SelectNth(const SubFrames &frames, idx_t n // The first level contains a single run, // so 
the only thing we need is any cascading pointers auto level_no = tree.size() - 2; - auto level_width = 1; + idx_t level_width = 1; for (idx_t i = 0; i < level_no; ++i) { level_width *= FANOUT; } @@ -367,9 +367,11 @@ idx_t MergeSortTree::SelectNth(const SubFrames &frames, idx_t n for (idx_t f = 0; f < frames.size(); ++f) { const auto &frame = frames[f]; auto &cascade_idx = cascades[f]; - const auto lower_idx = std::lower_bound(level.begin(), level.end(), frame.start) - level.begin(); + const auto lower_idx = + UnsafeNumericCast(std::lower_bound(level.begin(), level.end(), frame.start) - level.begin()); cascade_idx.first = lower_idx / CASCADING * FANOUT; - const auto upper_idx = std::lower_bound(level.begin(), level.end(), frame.end) - level.begin(); + const auto upper_idx = + UnsafeNumericCast(std::lower_bound(level.begin(), level.end(), frame.end) - level.begin()); cascade_idx.second = upper_idx / CASCADING * FANOUT; } @@ -390,11 +392,13 @@ idx_t MergeSortTree::SelectNth(const SubFrames &frames, idx_t n const auto lower_begin = level_data + level_cascades[cascade_idx.first]; const auto lower_end = level_data + level_cascades[cascade_idx.first + FANOUT]; - match.first = std::lower_bound(lower_begin, lower_end, frame.start) - level_data; + match.first = + UnsafeNumericCast(std::lower_bound(lower_begin, lower_end, frame.start) - level_data); const auto upper_begin = level_data + level_cascades[cascade_idx.second]; const auto upper_end = level_data + level_cascades[cascade_idx.second + FANOUT]; - match.second = std::lower_bound(upper_begin, upper_end, frame.end) - level_data; + match.second = + UnsafeNumericCast(std::lower_bound(upper_begin, upper_end, frame.end) - level_data); matched += idx_t(match.second - match.first); } From affb50bd080931f5d64e8fdb971e490a2d362e93 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 28 Mar 2024 19:36:41 +0100 Subject: [PATCH 033/201] fix merge conflicts --- src/include/duckdb/storage/standard_buffer_manager.hpp | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/include/duckdb/storage/standard_buffer_manager.hpp b/src/include/duckdb/storage/standard_buffer_manager.hpp index 7aeeee733411..85d710ab8f33 100644 --- a/src/include/duckdb/storage/standard_buffer_manager.hpp +++ b/src/include/duckdb/storage/standard_buffer_manager.hpp @@ -48,15 +48,10 @@ class StandardBufferManager : public BufferManager { //! Unpin and pin are nops on this block of memory shared_ptr RegisterSmallMemory(idx_t block_size) final; -<<<<<<< HEAD idx_t GetUsedMemory() const final override; idx_t GetMaxMemory() const final override; idx_t GetUsedSwap() final override; optional_idx GetMaxSwap() final override; -======= - idx_t GetUsedMemory() const final; - idx_t GetMaxMemory() const final; ->>>>>>> upstream/main //! Allocate an in-memory buffer with a single pin. //! The allocated memory is released when the buffer handle is destroyed. @@ -80,13 +75,8 @@ class StandardBufferManager : public BufferManager { //! 
Returns a list of all temporary files vector GetTemporaryFiles() final; -<<<<<<< HEAD const string &GetTemporaryDirectory() final override { return temporary_directory.path; -======= - const string &GetTemporaryDirectory() final { - return temp_directory; ->>>>>>> upstream/main } void SetTemporaryDirectory(const string &new_dir) final; From f30f491fa934864e0fcee0c01e1309ac2a4d6522 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 2 Apr 2024 09:52:22 +0200 Subject: [PATCH 034/201] more casts --- src/include/duckdb/common/vector.hpp | 2 +- src/include/duckdb/execution/merge_sort_tree.hpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/include/duckdb/common/vector.hpp b/src/include/duckdb/common/vector.hpp index a2b591902e72..95f76fe4d94e 100644 --- a/src/include/duckdb/common/vector.hpp +++ b/src/include/duckdb/common/vector.hpp @@ -102,7 +102,7 @@ class vector : public std::vector> { // NOL } void erase_at(idx_t idx) { - if (MemorySafety::enabled && idx > original::size()) { + if (MemorySafety::ENABLED && idx > original::size()) { throw InternalException("Can't remove offset %d from vector of size %d", idx, original::size()); } original::erase(original::begin() + static_cast(idx)); diff --git a/src/include/duckdb/execution/merge_sort_tree.hpp b/src/include/duckdb/execution/merge_sort_tree.hpp index 1c750ce54aad..4f752a75410b 100644 --- a/src/include/duckdb/execution/merge_sort_tree.hpp +++ b/src/include/duckdb/execution/merge_sort_tree.hpp @@ -430,8 +430,8 @@ idx_t MergeSortTree::SelectNth(const SubFrames &frames, idx_t n // Continue with the uncascaded levels (except the first) for (; level_no > 0; --level_no) { const auto &level = tree[level_no].first; - auto range_begin = level.begin() + result * level_width; - auto range_end = range_begin + level_width; + auto range_begin = level.begin() + UnsafeNumericCast(result * level_width); + auto range_end = range_begin + UnsafeNumericCast(level_width); while (range_end < level.end()) { idx_t matched = 0; for (idx_t f = 0; f < frames.size(); ++f) { @@ -447,7 +447,7 @@ idx_t MergeSortTree::SelectNth(const SubFrames &frames, idx_t n } // Not enough in this child, so move right range_begin = range_end; - range_end += level_width; + range_end += UnsafeNumericCast(level_width); ++result; n -= matched; } From 815c155eb232e7f12a7d5cdaa1ee4f2beac535c0 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 2 Apr 2024 11:48:21 +0200 Subject: [PATCH 035/201] remove dead code, fix tidy issue --- .../duckdb/storage/temporary_file_manager.hpp | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/include/duckdb/storage/temporary_file_manager.hpp b/src/include/duckdb/storage/temporary_file_manager.hpp index 92fcbf8eaba5..e5587547ec59 100644 --- a/src/include/duckdb/storage/temporary_file_manager.hpp +++ b/src/include/duckdb/storage/temporary_file_manager.hpp @@ -25,21 +25,9 @@ namespace duckdb { class TemporaryFileManager; -struct FileSizeMonitor { -public: - static constexpr idx_t TEMPFILE_BLOCK_SIZE = Storage::BLOCK_ALLOC_SIZE; - -public: - FileSizeMonitor(TemporaryFileManager &manager); - -public: -private: - TemporaryFileManager &manager; -}; - struct BlockIndexManager { public: - BlockIndexManager(TemporaryFileManager &manager); + explicit BlockIndexManager(TemporaryFileManager &manager); BlockIndexManager(); public: From 72ef808241a0b46f85508634dbdafbc46f0efbff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 2 Apr 2024 15:41:20 +0200 Subject: 
[PATCH 036/201] more casts --- .../aggregate/holistic/mode.cpp | 4 +-- .../aggregate/holistic/reservoir_quantile.cpp | 10 +++---- src/core_functions/scalar/bit/bitstring.cpp | 27 ++++++++----------- .../scalar/blob/create_sort_key.cpp | 5 ++-- src/core_functions/scalar/date/date_part.cpp | 4 +-- src/include/duckdb/common/radix.hpp | 16 ++++++++--- third_party/tdigest/t_digest.hpp | 12 ++++----- 7 files changed, 41 insertions(+), 37 deletions(-) diff --git a/src/core_functions/aggregate/holistic/mode.cpp b/src/core_functions/aggregate/holistic/mode.cpp index 029b10c31fb7..f33ccc415e99 100644 --- a/src/core_functions/aggregate/holistic/mode.cpp +++ b/src/core_functions/aggregate/holistic/mode.cpp @@ -27,7 +27,7 @@ struct hash { template <> struct hash { inline size_t operator()(const duckdb::hugeint_t &val) const { - return hash {}(val.upper) ^ hash {}(val.lower); + return hash {}(val.upper) ^ hash {}(val.lower); } }; @@ -102,7 +102,7 @@ struct ModeState { void ModeRm(const KEY_TYPE &key, idx_t frame) { auto &attr = (*frequency_map)[key]; auto old_count = attr.count; - nonzero -= int(old_count == 1); + nonzero -= size_t(old_count == 1); attr.count -= 1; if (count == old_count && key == *mode) { diff --git a/src/core_functions/aggregate/holistic/reservoir_quantile.cpp b/src/core_functions/aggregate/holistic/reservoir_quantile.cpp index 7da2cdbeedbf..9e7c24c59522 100644 --- a/src/core_functions/aggregate/holistic/reservoir_quantile.cpp +++ b/src/core_functions/aggregate/holistic/reservoir_quantile.cpp @@ -52,11 +52,11 @@ struct ReservoirQuantileState { struct ReservoirQuantileBindData : public FunctionData { ReservoirQuantileBindData() { } - ReservoirQuantileBindData(double quantile_p, int32_t sample_size_p) + ReservoirQuantileBindData(double quantile_p, idx_t sample_size_p) : quantiles(1, quantile_p), sample_size(sample_size_p) { } - ReservoirQuantileBindData(vector quantiles_p, int32_t sample_size_p) + ReservoirQuantileBindData(vector quantiles_p, idx_t sample_size_p) : quantiles(std::move(quantiles_p)), sample_size(sample_size_p) { } @@ -84,7 +84,7 @@ struct ReservoirQuantileBindData : public FunctionData { } vector quantiles; - int32_t sample_size; + idx_t sample_size; }; struct ReservoirQuantileOperation { @@ -330,7 +330,7 @@ unique_ptr BindReservoirQuantile(ClientContext &context, Aggregate } else { arguments.pop_back(); } - return make_uniq(quantiles, 8192); + return make_uniq(quantiles, 8192ULL); } if (!arguments[2]->IsFoldable()) { throw BinderException("RESERVOIR_QUANTILE can only take constant sample size parameters"); @@ -348,7 +348,7 @@ unique_ptr BindReservoirQuantile(ClientContext &context, Aggregate // remove the quantile argument so we can use the unary aggregate Function::EraseArgument(function, arguments, arguments.size() - 1); Function::EraseArgument(function, arguments, arguments.size() - 1); - return make_uniq(quantiles, sample_size); + return make_uniq(quantiles, NumericCast(sample_size)); } unique_ptr BindReservoirQuantileDecimal(ClientContext &context, AggregateFunction &function, diff --git a/src/core_functions/scalar/bit/bitstring.cpp b/src/core_functions/scalar/bit/bitstring.cpp index babfadfe01e7..8a3074250139 100644 --- a/src/core_functions/scalar/bit/bitstring.cpp +++ b/src/core_functions/scalar/bit/bitstring.cpp @@ -8,17 +8,13 @@ namespace duckdb { // BitStringFunction //===--------------------------------------------------------------------===// static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - 
args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t n) { - if (n < 0) { - throw InvalidInputException("The bitstring length cannot be negative"); - } - if (idx_t(n) < input.GetSize()) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t input, idx_t n) { + if (n < input.GetSize()) { throw InvalidInputException("Length must be equal or larger than input string"); } idx_t len; Bit::TryGetBitStringSize(input, len, nullptr); // string verification - len = Bit::ComputeBitstringLen(n); string_t target = StringVector::EmptyString(result, len); Bit::BitString(input, n, target); @@ -28,7 +24,7 @@ static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &r } ScalarFunction BitStringFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction); + return ScalarFunction({LogicalType::VARCHAR, LogicalType::UBIGINT}, LogicalType::BIT, BitStringFunction); } //===--------------------------------------------------------------------===// @@ -37,7 +33,7 @@ ScalarFunction BitStringFun::GetFunction() { struct GetBitOperator { template static inline TR Operation(TA input, TB n) { - if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { + if (n > Bit::BitLength(input) - 1) { throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), NumericHelper::ToString(Bit::BitLength(input) - 1)); } @@ -46,21 +42,20 @@ struct GetBitOperator { }; ScalarFunction GetBitFun::GetFunction() { - return ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::INTEGER, - ScalarFunction::BinaryFunction); + return ScalarFunction({LogicalType::BIT, LogicalType::UBIGINT}, LogicalType::INTEGER, + ScalarFunction::BinaryFunction); } //===--------------------------------------------------------------------===// // set_bit //===--------------------------------------------------------------------===// static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &result) { - TernaryExecutor::Execute( - args.data[0], args.data[1], args.data[2], result, args.size(), - [&](string_t input, int32_t n, int32_t new_value) { + TernaryExecutor::Execute( + args.data[0], args.data[1], args.data[2], result, args.size(), [&](string_t input, idx_t n, idx_t new_value) { if (new_value != 0 && new_value != 1) { throw InvalidInputException("The new bit must be 1 or 0"); } - if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { + if (n > Bit::BitLength(input) - 1) { throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), NumericHelper::ToString(Bit::BitLength(input) - 1)); } @@ -72,7 +67,7 @@ static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &res } ScalarFunction SetBitFun::GetFunction() { - return ScalarFunction({LogicalType::BIT, LogicalType::INTEGER, LogicalType::INTEGER}, LogicalType::BIT, + return ScalarFunction({LogicalType::BIT, LogicalType::UBIGINT, LogicalType::UBIGINT}, LogicalType::BIT, SetBitOperation); } diff --git a/src/core_functions/scalar/blob/create_sort_key.cpp b/src/core_functions/scalar/blob/create_sort_key.cpp index 880acd2c8366..09ee36f5167e 100644 --- a/src/core_functions/scalar/blob/create_sort_key.cpp +++ b/src/core_functions/scalar/blob/create_sort_key.cpp @@ -189,7 +189,7 @@ struct SortKeyVarcharOperator { auto input_data = input.GetDataUnsafe(); auto input_size = input.GetSize(); for (idx_t r = 0; r < input_size; r++) { - result[r] = input_data[r] + 1; + 
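// shifting every byte up by one keeps the string's payload distinct from the
// null-byte delimiter written after this loop; the cast added below only
// makes the narrowing back to a byte explicit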
result[r] = UnsafeNumericCast(input_data[r] + 1); } result[input_size] = SortKeyVectorData::STRING_DELIMITER; // null-byte delimiter return input_size + 1; @@ -519,7 +519,8 @@ void ConstructSortKeyList(SortKeyVectorData &vector_data, SortKeyChunk chunk, So } // write the end-of-list delimiter - result_ptr[offset++] = info.flip_bytes ? ~SortKeyVectorData::LIST_DELIMITER : SortKeyVectorData::LIST_DELIMITER; + result_ptr[offset++] = UnsafeNumericCast(info.flip_bytes ? ~SortKeyVectorData::LIST_DELIMITER + : SortKeyVectorData::LIST_DELIMITER); } } diff --git a/src/core_functions/scalar/date/date_part.cpp b/src/core_functions/scalar/date/date_part.cpp index 1c3e16448e23..ebbe158b51fd 100644 --- a/src/core_functions/scalar/date/date_part.cpp +++ b/src/core_functions/scalar/date/date_part.cpp @@ -1432,8 +1432,8 @@ void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec & // Both define epoch, and the correct value is the sum. // So mask it out and compute it separately. - Operation(bigint_values, double_values, d, idx, mask & ~EPOCH); - Operation(bigint_values, double_values, t, idx, mask & ~EPOCH); + Operation(bigint_values, double_values, d, idx, mask & UnsafeNumericCast(~EPOCH)); + Operation(bigint_values, double_values, t, idx, mask & UnsafeNumericCast(~EPOCH)); if (mask & EPOCH) { auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); diff --git a/src/include/duckdb/common/radix.hpp b/src/include/duckdb/common/radix.hpp index a4d774b9bb6d..4b7c89a3374a 100644 --- a/src/include/duckdb/common/radix.hpp +++ b/src/include/duckdb/common/radix.hpp @@ -117,25 +117,33 @@ inline void Radix::EncodeData(data_ptr_t dataptr, bool value) { template <> inline void Radix::EncodeData(data_ptr_t dataptr, int8_t value) { - Store(value, dataptr); + uint8_t bytes; // dance around signedness conversion check + Store(value, data_ptr_cast(&bytes)); + Store(bytes, dataptr); dataptr[0] = FlipSign(dataptr[0]); } template <> inline void Radix::EncodeData(data_ptr_t dataptr, int16_t value) { - Store(BSwap(value), dataptr); + uint16_t bytes; + Store(value, data_ptr_cast(&bytes)); + Store(BSwap(bytes), dataptr); dataptr[0] = FlipSign(dataptr[0]); } template <> inline void Radix::EncodeData(data_ptr_t dataptr, int32_t value) { - Store(BSwap(value), dataptr); + uint32_t bytes; + Store(value, data_ptr_cast(&bytes)); + Store(BSwap(bytes), dataptr); dataptr[0] = FlipSign(dataptr[0]); } template <> inline void Radix::EncodeData(data_ptr_t dataptr, int64_t value) { - Store(BSwap(value), dataptr); + uint64_t bytes; + Store(value, data_ptr_cast(&bytes)); + Store(BSwap(bytes), dataptr); dataptr[0] = FlipSign(dataptr[0]); } diff --git a/third_party/tdigest/t_digest.hpp b/third_party/tdigest/t_digest.hpp index 5c78567d510e..a25668c08eaf 100644 --- a/third_party/tdigest/t_digest.hpp +++ b/third_party/tdigest/t_digest.hpp @@ -219,7 +219,7 @@ class TDigest { pq.push((*iter)); } std::vector batch; - batch.reserve(size); + batch.reserve(size_t(size)); size_t totalSize = 0; while (!pq.empty()) { @@ -324,7 +324,7 @@ class TDigest { CentroidComparator cc; auto iter = std::upper_bound(processed_.cbegin(), processed_.cend(), Centroid(x, 0), cc); - auto i = std::distance(processed_.cbegin(), iter); + auto i = size_t(std::distance(processed_.cbegin(), iter)); auto z1 = x - (iter - 1)->mean(); auto z2 = (iter)->mean() - x; return weightedAverage(cumulative_[i - 1], z2, cumulative_[i], z1) / processedWeight_; @@ -369,7 +369,7 @@ class TDigest { auto iter = std::lower_bound(cumulative_.cbegin(), 
cumulative_.cend(), index); if (iter + 1 != cumulative_.cend()) { - auto i = std::distance(cumulative_.cbegin(), iter); + auto i = size_t(std::distance(cumulative_.cbegin(), iter)); auto z1 = index - *(iter - 1); auto z2 = *(iter)-index; // LOG(INFO) << "z2 " << z2 << " index " << index << " z1 " << z1; @@ -406,9 +406,9 @@ class TDigest { inline void add(std::vector::const_iterator iter, std::vector::const_iterator end) { while (iter != end) { - const size_t diff = std::distance(iter, end); + const size_t diff = size_t(std::distance(iter, end)); const size_t room = maxUnprocessed_ - unprocessed_.size(); - auto mid = iter + std::min(diff, room); + auto mid = iter + int64_t(std::min(diff, room)); while (iter != mid) { unprocessed_.push_back(*(iter++)); } @@ -538,7 +538,7 @@ class TDigest { std::sort(unprocessed_.begin(), unprocessed_.end(), cc); auto count = unprocessed_.size(); unprocessed_.insert(unprocessed_.end(), processed_.cbegin(), processed_.cend()); - std::inplace_merge(unprocessed_.begin(), unprocessed_.begin() + count, unprocessed_.end(), cc); + std::inplace_merge(unprocessed_.begin(), unprocessed_.begin() + int64_t(count), unprocessed_.end(), cc); processedWeight_ += unprocessedWeight_; unprocessedWeight_ = 0; From 0c2495faf3b8e332281ecb25a4376e4b1a9c0cd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 2 Apr 2024 15:50:18 +0200 Subject: [PATCH 037/201] more casts or type changes --- src/core_functions/scalar/date/strftime.cpp | 2 +- src/core_functions/scalar/generic/system_functions.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/core_functions/scalar/date/strftime.cpp b/src/core_functions/scalar/date/strftime.cpp index 72da80bb5d52..01c907a5d893 100644 --- a/src/core_functions/scalar/date/strftime.cpp +++ b/src/core_functions/scalar/date/strftime.cpp @@ -33,7 +33,7 @@ struct StrfTimeBindData : public FunctionData { template static unique_ptr StrfTimeBindFunction(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { - auto format_idx = REVERSED ? 0 : 1; + auto format_idx = REVERSED ? 
0U : 1U; auto &format_arg = arguments[format_idx]; if (format_arg->HasParameter()) { throw ParameterNotResolvedException(); diff --git a/src/core_functions/scalar/generic/system_functions.cpp b/src/core_functions/scalar/generic/system_functions.cpp index 757b819a73ec..52c581e071af 100644 --- a/src/core_functions/scalar/generic/system_functions.cpp +++ b/src/core_functions/scalar/generic/system_functions.cpp @@ -89,7 +89,7 @@ static void TransactionIdCurrent(DataChunk &input, ExpressionState &state, Vecto auto &context = state.GetContext(); auto &catalog = Catalog::GetCatalog(context, DatabaseManager::GetDefaultDatabase(context)); auto &transaction = DuckTransaction::Get(context, catalog); - auto val = Value::BIGINT(transaction.start_time); + auto val = Value::UBIGINT(transaction.start_time); result.Reference(val); } @@ -133,7 +133,7 @@ ScalarFunction InSearchPathFun::GetFunction() { } ScalarFunction CurrentTransactionIdFun::GetFunction() { - ScalarFunction txid_current({}, LogicalType::BIGINT, TransactionIdCurrent); + ScalarFunction txid_current({}, LogicalType::UBIGINT, TransactionIdCurrent); txid_current.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; return txid_current; } From 2f76a7b72dd4427bce5ec0872c58269a667c7c75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Wed, 3 Apr 2024 12:41:41 +0200 Subject: [PATCH 038/201] arrays and jaro --- .../scalar/list/array_slice.cpp | 22 ++++++------ .../scalar/list/list_reduce.cpp | 2 +- src/core_functions/scalar/list/list_sort.cpp | 6 ++-- src/core_functions/scalar/list/list_value.cpp | 2 +- src/core_functions/scalar/list/range.cpp | 4 +-- src/core_functions/scalar/map/map_extract.cpp | 2 +- .../scalar/operators/bitwise.cpp | 4 +-- src/core_functions/scalar/string/bar.cpp | 2 +- src/core_functions/scalar/string/chr.cpp | 2 +- src/core_functions/scalar/string/hex.cpp | 29 ++++++++------- src/core_functions/scalar/string/instr.cpp | 2 +- .../function/scalar/string_functions.hpp | 2 +- third_party/jaro_winkler/details/common.hpp | 18 +++++----- .../jaro_winkler/details/intrinsics.hpp | 2 +- .../jaro_winkler/details/jaro_impl.hpp | 36 +++++++++---------- 15 files changed, 69 insertions(+), 66 deletions(-) diff --git a/src/core_functions/scalar/list/array_slice.cpp b/src/core_functions/scalar/list/array_slice.cpp index 7651f410d023..c91247ae9f69 100644 --- a/src/core_functions/scalar/list/array_slice.cpp +++ b/src/core_functions/scalar/list/array_slice.cpp @@ -40,7 +40,7 @@ unique_ptr ListSliceBindData::Copy() const { } template -static idx_t CalculateSliceLength(idx_t begin, idx_t end, INDEX_TYPE step, bool svalid) { +static idx_t CalculateSliceLength(INDEX_TYPE begin, INDEX_TYPE end, INDEX_TYPE step, bool svalid) { if (step < 0) { step = abs(step); } @@ -48,14 +48,14 @@ static idx_t CalculateSliceLength(idx_t begin, idx_t end, INDEX_TYPE step, bool throw InvalidInputException("Slice step cannot be zero"); } if (step == 1) { - return NumericCast(end - begin); - } else if (static_cast(step) >= (end - begin)) { + return UnsafeNumericCast(end - begin); + } else if (step >= (end - begin)) { return 1; } if ((end - begin) % step != 0) { - return (end - begin) / step + 1; + return UnsafeNumericCast((end - begin) / step + 1); } - return (end - begin) / step; + return UnsafeNumericCast((end - begin) / step); } template @@ -64,7 +64,7 @@ INDEX_TYPE ValueLength(const INPUT_TYPE &value) { } template <> -int64_t ValueLength(const list_entry_t &value) { +idx_t ValueLength(const list_entry_t &value) { return value.length; } @@ 
-119,8 +119,8 @@ INPUT_TYPE SliceValue(Vector &result, INPUT_TYPE input, INDEX_TYPE begin, INDEX_ template <> list_entry_t SliceValue(Vector &result, list_entry_t input, int64_t begin, int64_t end) { - input.offset += begin; - input.length = end - begin; + input.offset = UnsafeNumericCast(UnsafeNumericCast(input.offset) + begin); + input.length = UnsafeNumericCast(end - begin); return input; } @@ -145,14 +145,14 @@ list_entry_t SliceValueWithSteps(Vector &result, SelectionVector &sel, list_entr return input; } input.length = CalculateSliceLength(begin, end, step, true); - idx_t child_idx = input.offset + begin; + auto child_idx = UnsafeNumericCast(UnsafeNumericCast(input.offset) + begin); if (step < 0) { - child_idx = input.offset + end - 1; + child_idx = UnsafeNumericCast(UnsafeNumericCast(input.offset) + end - 1); } input.offset = sel_idx; for (idx_t i = 0; i < input.length; i++) { sel.set_index(sel_idx, child_idx); - child_idx += step; + child_idx = UnsafeNumericCast(UnsafeNumericCast(child_idx) + step); sel_idx++; } return input; diff --git a/src/core_functions/scalar/list/list_reduce.cpp b/src/core_functions/scalar/list/list_reduce.cpp index 1f619780af2f..b58a741294f5 100644 --- a/src/core_functions/scalar/list/list_reduce.cpp +++ b/src/core_functions/scalar/list/list_reduce.cpp @@ -99,7 +99,7 @@ static bool ExecuteReduce(idx_t loops, ReduceExecuteInfo &execute_info, LambdaFu } // create the index vector - Vector index_vector(Value::BIGINT(loops + 1)); + Vector index_vector(Value::BIGINT(UnsafeNumericCast(loops + 1))); // slice the left and right slice execute_info.left_slice.Slice(execute_info.left_slice, execute_info.left_sel, reduced_row_idx); diff --git a/src/core_functions/scalar/list/list_sort.cpp b/src/core_functions/scalar/list/list_sort.cpp index 1226430c3721..64733ec4ec29 100644 --- a/src/core_functions/scalar/list/list_sort.cpp +++ b/src/core_functions/scalar/list/list_sort.cpp @@ -52,8 +52,8 @@ ListSortBindData::ListSortBindData(OrderType order_type_p, OrderByNullType null_ payload_layout.Initialize(payload_types); // get the BoundOrderByNode - auto idx_col_expr = make_uniq_base(LogicalType::USMALLINT, 0); - auto lists_col_expr = make_uniq_base(child_type, 1); + auto idx_col_expr = make_uniq_base(LogicalType::USMALLINT, 0U); + auto lists_col_expr = make_uniq_base(child_type, 1U); orders.emplace_back(OrderType::ASCENDING, OrderByNullType::ORDER_DEFAULT, std::move(idx_col_expr)); orders.emplace_back(order_type, null_order, std::move(lists_col_expr)); } @@ -241,7 +241,7 @@ static void ListSortFunction(DataChunk &args, ExpressionState &state, Vector &re for (idx_t i = 0; i < count; i++) { for (idx_t j = result_data[i].offset; j < result_data[i].offset + result_data[i].length; j++) { auto b = sel_sorted.get_index(j) - result_data[i].offset; - result_entry.SetValue(j, Value::BIGINT(b + 1)); + result_entry.SetValue(j, Value::BIGINT(UnsafeNumericCast(b + 1))); } } } else { diff --git a/src/core_functions/scalar/list/list_value.cpp b/src/core_functions/scalar/list/list_value.cpp index 9aae09e2c8a4..e2ae537fb723 100644 --- a/src/core_functions/scalar/list/list_value.cpp +++ b/src/core_functions/scalar/list/list_value.cpp @@ -59,7 +59,7 @@ static unique_ptr ListValueBind(ClientContext &context, ScalarFunc StringUtil::Format("Cannot unpivot columns of types %s and %s - an explicit cast is required", child_type.ToString(), arg_type.ToString()); throw BinderException(arguments[i]->query_location, - QueryErrorContext::Format(list_arguments, error, int(error_index), false)); + 
QueryErrorContext::Format(list_arguments, error, error_index, false)); } else { throw BinderException(arguments[i]->query_location, "Cannot create a list of types %s and %s - an explicit cast is required", diff --git a/src/core_functions/scalar/list/range.cpp b/src/core_functions/scalar/list/range.cpp index 4bb8685388c4..d965eb30d06f 100644 --- a/src/core_functions/scalar/list/range.cpp +++ b/src/core_functions/scalar/list/range.cpp @@ -80,7 +80,7 @@ struct TimestampRangeInfo { if (start_value < end_value && is_negative) { return 0; } - int64_t total_values = 0; + uint64_t total_values = 0; if (is_negative) { // negative interval, start_value is going down while (inclusive_bound ? start_value >= end_value : start_value > end_value) { @@ -203,7 +203,7 @@ static void ListRangeFunction(DataChunk &args, ExpressionState &state, Vector &r } auto list_data = FlatVector::GetData(result); auto &result_validity = FlatVector::Validity(result); - int64_t total_size = 0; + uint64_t total_size = 0; for (idx_t i = 0; i < args_size; i++) { if (!info.RowIsValid(i)) { result_validity.SetInvalid(i); diff --git a/src/core_functions/scalar/map/map_extract.cpp b/src/core_functions/scalar/map/map_extract.cpp index 2986a7f6b48a..8f6c4a8b03f5 100644 --- a/src/core_functions/scalar/map/map_extract.cpp +++ b/src/core_functions/scalar/map/map_extract.cpp @@ -58,7 +58,7 @@ void FillResult(Vector &map, Vector &offsets, Vector &result, idx_t count) { auto &value_list_entry = UnifiedVectorFormat::GetData(map_data)[value_index]; // Add the values to the result - idx_t list_offset = value_list_entry.offset + offset; + idx_t list_offset = value_list_entry.offset + UnsafeNumericCast(offset); // All keys are unique, only one will ever match idx_t length = 1; ListVector::Append(result, values_entries, length + list_offset, list_offset); diff --git a/src/core_functions/scalar/operators/bitwise.cpp b/src/core_functions/scalar/operators/bitwise.cpp index 73f968bb5377..d9bb58950272 100644 --- a/src/core_functions/scalar/operators/bitwise.cpp +++ b/src/core_functions/scalar/operators/bitwise.cpp @@ -251,7 +251,7 @@ static void BitwiseShiftLeftOperation(DataChunk &args, ExpressionState &state, V Bit::SetEmptyBitString(target, input); return target; } - Bit::LeftShift(input, shift, target); + Bit::LeftShift(input, UnsafeNumericCast(shift), target); return target; }); } @@ -294,7 +294,7 @@ static void BitwiseShiftRightOperation(DataChunk &args, ExpressionState &state, Bit::SetEmptyBitString(target, input); return target; } - Bit::RightShift(input, shift, target); + Bit::RightShift(input, UnsafeNumericCast(shift), target); return target; }); } diff --git a/src/core_functions/scalar/string/bar.cpp b/src/core_functions/scalar/string/bar.cpp index 291553a3a2d0..e9cd400cc426 100644 --- a/src/core_functions/scalar/string/bar.cpp +++ b/src/core_functions/scalar/string/bar.cpp @@ -40,7 +40,7 @@ static string_t BarScalarFunction(double x, double min, double max, double max_w result.clear(); - int32_t width_as_int = static_cast(width * PARTIAL_BLOCKS_COUNT); + auto width_as_int = NumericCast(width * PARTIAL_BLOCKS_COUNT); idx_t full_blocks_count = (width_as_int / PARTIAL_BLOCKS_COUNT); for (idx_t i = 0; i < full_blocks_count; i++) { result += FULL_BLOCK; diff --git a/src/core_functions/scalar/string/chr.cpp b/src/core_functions/scalar/string/chr.cpp index e7bb62e1db57..20947cd2f816 100644 --- a/src/core_functions/scalar/string/chr.cpp +++ b/src/core_functions/scalar/string/chr.cpp @@ -16,7 +16,7 @@ struct ChrOperator { char c[5] = {'\0', '\0', 
'\0', '\0', '\0'}; int utf8_bytes; GetCodepoint(input, c, utf8_bytes); - return string_t(&c[0], utf8_bytes); + return string_t(&c[0], UnsafeNumericCast(utf8_bytes)); } }; diff --git a/src/core_functions/scalar/string/hex.cpp b/src/core_functions/scalar/string/hex.cpp index dffbae70d030..6afaeadb56a4 100644 --- a/src/core_functions/scalar/string/hex.cpp +++ b/src/core_functions/scalar/string/hex.cpp @@ -90,7 +90,8 @@ struct HexIntegralOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + auto num_leading_zero = + UnsafeNumericCast(CountZeros::Leading(UnsafeNumericCast(input))); idx_t num_bits_to_check = 64 - num_leading_zero; D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); @@ -109,7 +110,7 @@ struct HexIntegralOperator { auto target = StringVector::EmptyString(result, buffer_size); auto output = target.GetDataWriteable(); - WriteHexBytes(input, output, buffer_size); + WriteHexBytes(UnsafeNumericCast(input), output, buffer_size); target.Finalize(); return target; @@ -120,7 +121,7 @@ struct HexHugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + auto num_leading_zero = UnsafeNumericCast(CountZeros::Leading(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); // Special case: All bits are zero @@ -147,7 +148,7 @@ struct HexUhugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + auto num_leading_zero = UnsafeNumericCast(CountZeros::Leading(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); // Special case: All bits are zero @@ -189,7 +190,7 @@ struct BinaryStrOperator { auto output = target.GetDataWriteable(); for (idx_t i = 0; i < size; ++i) { - uint8_t byte = data[i]; + auto byte = UnsafeNumericCast(data[i]); for (idx_t i = 8; i >= 1; --i) { *output = ((byte >> (i - 1)) & 0x01) + '0'; output++; @@ -205,7 +206,8 @@ struct BinaryIntegralOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + auto num_leading_zero = + UnsafeNumericCast(CountZeros::Leading(UnsafeNumericCast(input))); idx_t num_bits_to_check = 64 - num_leading_zero; D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); @@ -224,7 +226,7 @@ struct BinaryIntegralOperator { auto target = StringVector::EmptyString(result, buffer_size); auto output = target.GetDataWriteable(); - WriteBinBytes(input, output, buffer_size); + WriteBinBytes(UnsafeNumericCast(input), output, buffer_size); target.Finalize(); return target; @@ -234,7 +236,7 @@ struct BinaryIntegralOperator { struct BinaryHugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + auto num_leading_zero = UnsafeNumericCast(CountZeros::Leading(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; // Special case: All bits are zero @@ -259,7 +261,7 @@ struct BinaryHugeIntOperator { struct BinaryUhugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + auto num_leading_zero = UnsafeNumericCast(CountZeros::Leading(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; // Special case: All bits are zero @@ -301,7 +303,7 @@ struct FromHexOperator { // 
Treated as a single byte idx_t i = 0; if (size % 2 != 0) { - *output = StringUtil::GetHexValue(data[i]); + *output = UnsafeNumericCast(StringUtil::GetHexValue(data[i])); i++; output++; } @@ -309,7 +311,7 @@ struct FromHexOperator { for (; i < size; i += 2) { uint8_t major = StringUtil::GetHexValue(data[i]); uint8_t minor = StringUtil::GetHexValue(data[i + 1]); - *output = UnsafeNumericCast((major << 4) | minor); + *output = UnsafeNumericCast((major << 4) | minor); output++; } @@ -343,7 +345,7 @@ struct FromBinaryOperator { byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); i++; } - *output = byte; + *output = UnsafeNumericCast(byte); output++; } @@ -353,7 +355,8 @@ struct FromBinaryOperator { byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); i++; } - *output = byte; + *output = UnsafeNumericCast(byte); + ; output++; } diff --git a/src/core_functions/scalar/string/instr.cpp b/src/core_functions/scalar/string/instr.cpp index becbbd4896f3..66608db60820 100644 --- a/src/core_functions/scalar/string/instr.cpp +++ b/src/core_functions/scalar/string/instr.cpp @@ -33,7 +33,7 @@ struct InstrAsciiOperator { template static inline TR Operation(TA haystack, TB needle) { auto location = ContainsFun::Find(haystack, needle); - return location == DConstants::INVALID_INDEX ? 0 : location + 1; + return UnsafeNumericCast(location == DConstants::INVALID_INDEX ? 0U : location + 1U); } }; diff --git a/src/include/duckdb/function/scalar/string_functions.hpp b/src/include/duckdb/function/scalar/string_functions.hpp index 3fc6cf62978d..c96c2b3c33c3 100644 --- a/src/include/duckdb/function/scalar/string_functions.hpp +++ b/src/include/duckdb/function/scalar/string_functions.hpp @@ -78,7 +78,7 @@ struct LengthFun { return length; } } - return input_length; + return UnsafeNumericCast(input_length); } }; diff --git a/third_party/jaro_winkler/details/common.hpp b/third_party/jaro_winkler/details/common.hpp index 0a3193fb5b65..6855ce60af01 100644 --- a/third_party/jaro_winkler/details/common.hpp +++ b/third_party/jaro_winkler/details/common.hpp @@ -91,7 +91,7 @@ struct BitvectorHashmap { void insert_mask(CharT key, uint64_t mask) { uint64_t i = lookup(static_cast(key)); - m_map[i].key = key; + m_map[i].key = static_cast(key); m_map[i].value |= mask; } @@ -150,7 +150,7 @@ struct PatternMatchVector { for (int64_t i = 0; i < std::distance(first, last); ++i) { auto key = first[i]; if (key >= 0 && key <= 255) { - m_extendedAscii[key] |= mask; + m_extendedAscii[static_cast(key)] |= mask; } else { m_map.insert_mask(key, mask); @@ -175,7 +175,7 @@ struct PatternMatchVector { uint64_t get(CharT key) const { if (key >= 0 && key <= 255) { - return m_extendedAscii[key]; + return m_extendedAscii[static_cast(key)]; } else { return m_map.get(key); @@ -215,10 +215,10 @@ struct BlockPatternMatchVector { assert(block < m_block_count); if (key >= 0 && key <= 255) { - m_extendedAscii[key * m_block_count + block] |= mask; + m_extendedAscii[static_cast(key * m_block_count + block)] |= mask; } else { - m_map[block].insert_mask(key, mask); + m_map[static_cast(block)].insert_mask(key, mask); } } @@ -227,8 +227,8 @@ struct BlockPatternMatchVector { { int64_t len = std::distance(first, last); m_block_count = ceildiv(len, 64); - m_map.resize(m_block_count); - m_extendedAscii.resize(m_block_count * 256); + m_map.resize(static_cast(m_block_count)); + m_extendedAscii.resize(static_cast(m_block_count * 256)); for (int64_t i = 0; i < len; ++i) { int64_t block = i / 64; @@ -251,10 +251,10 @@ struct BlockPatternMatchVector { { assert(block < 
m_block_count); if (key >= 0 && key <= 255) { - return m_extendedAscii[key * m_block_count + block]; + return m_extendedAscii[static_cast(key * m_block_count + block)]; } else { - return m_map[block].get(key); + return m_map[static_cast(block)].get(key); } } diff --git a/third_party/jaro_winkler/details/intrinsics.hpp b/third_party/jaro_winkler/details/intrinsics.hpp index 174bdccd841b..45f07c407f16 100644 --- a/third_party/jaro_winkler/details/intrinsics.hpp +++ b/third_party/jaro_winkler/details/intrinsics.hpp @@ -15,7 +15,7 @@ namespace intrinsics { template T bit_mask_lsb(int n) { - T mask = -1; + T mask = static_cast(-1); if (n < static_cast(sizeof(T) * 8)) { mask += static_cast(1) << n; } diff --git a/third_party/jaro_winkler/details/jaro_impl.hpp b/third_party/jaro_winkler/details/jaro_impl.hpp index 7b8f28b893ca..47e6fa506c1c 100644 --- a/third_party/jaro_winkler/details/jaro_impl.hpp +++ b/third_party/jaro_winkler/details/jaro_impl.hpp @@ -147,39 +147,39 @@ static inline void flag_similar_characters_step(const common::BlockPatternMatchV if (BoundMask.words == 1) { uint64_t PM_j = PM.get(word, T_j) & BoundMask.last_mask & BoundMask.first_mask & - (~flagged.P_flag[word]); + (~flagged.P_flag[static_cast(word)]); - flagged.P_flag[word] |= blsi(PM_j); - flagged.T_flag[j_word] |= static_cast(PM_j != 0) << j_pos; + flagged.P_flag[static_cast(word)] |= blsi(PM_j); + flagged.T_flag[static_cast(j_word)] |= static_cast(PM_j != 0) << j_pos; return; } if (BoundMask.first_mask) { - uint64_t PM_j = PM.get(word, T_j) & BoundMask.first_mask & (~flagged.P_flag[word]); + uint64_t PM_j = PM.get(word, T_j) & BoundMask.first_mask & (~flagged.P_flag[static_cast(word)]); if (PM_j) { - flagged.P_flag[word] |= blsi(PM_j); - flagged.T_flag[j_word] |= 1ull << j_pos; + flagged.P_flag[static_cast(word)] |= blsi(PM_j); + flagged.T_flag[static_cast(j_word)] |= 1ull << j_pos; return; } word++; } for (; word < last_word - 1; ++word) { - uint64_t PM_j = PM.get(word, T_j) & (~flagged.P_flag[word]); + uint64_t PM_j = PM.get(word, T_j) & (~flagged.P_flag[static_cast(word)]); if (PM_j) { - flagged.P_flag[word] |= blsi(PM_j); - flagged.T_flag[j_word] |= 1ull << j_pos; + flagged.P_flag[static_cast(word)] |= blsi(PM_j); + flagged.T_flag[static_cast(j_word)] |= 1ull << j_pos; return; } } if (BoundMask.last_mask) { - uint64_t PM_j = PM.get(word, T_j) & BoundMask.last_mask & (~flagged.P_flag[word]); + uint64_t PM_j = PM.get(word, T_j) & BoundMask.last_mask & (~flagged.P_flag[static_cast(word)]); - flagged.P_flag[word] |= blsi(PM_j); - flagged.T_flag[j_word] |= static_cast(PM_j != 0) << j_pos; + flagged.P_flag[static_cast(word)] |= blsi(PM_j); + flagged.T_flag[static_cast(j_word)] |= static_cast(PM_j != 0) << j_pos; } } @@ -199,8 +199,8 @@ flag_similar_characters_block(const common::BlockPatternMatchVector& PM, InputIt int64_t PatternWords = common::ceildiv(P_len, 64); FlaggedCharsMultiword flagged; - flagged.T_flag.resize(TextWords); - flagged.P_flag.resize(PatternWords); + flagged.T_flag.resize(static_cast(TextWords)); + flagged.P_flag.resize(static_cast(PatternWords)); SearchBoundMask BoundMask; int64_t start_range = std::min(Bound + 1, P_len); @@ -262,21 +262,21 @@ count_transpositions_block(const common::BlockPatternMatchVector& PM, InputIt1 T using namespace intrinsics; int64_t TextWord = 0; int64_t PatternWord = 0; - uint64_t T_flag = flagged.T_flag[TextWord]; - uint64_t P_flag = flagged.P_flag[PatternWord]; + uint64_t T_flag = flagged.T_flag[static_cast(TextWord)]; + uint64_t P_flag = 
flagged.P_flag[static_cast(PatternWord)]; int64_t Transpositions = 0; while (FlaggedChars) { while (!T_flag) { TextWord++; T_first += 64; - T_flag = flagged.T_flag[TextWord]; + T_flag = flagged.T_flag[static_cast(TextWord)]; } while (T_flag) { while (!P_flag) { PatternWord++; - P_flag = flagged.P_flag[PatternWord]; + P_flag = flagged.P_flag[static_cast(PatternWord)]; } uint64_t PatternFlagMask = blsi(P_flag); From 4a8819fe1388431cbac1f290f135b2332ef023eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Wed, 3 Apr 2024 13:18:41 +0200 Subject: [PATCH 039/201] still stuck in scalars --- src/core_functions/scalar/string/pad.cpp | 16 ++--- .../scalar/string/string_split.cpp | 2 +- src/core_functions/scalar/string/to_base.cpp | 2 +- .../scalar/string/translate.cpp | 8 +-- src/core_functions/scalar/string/trim.cpp | 22 +++--- src/core_functions/scalar/string/unicode.cpp | 2 +- src/function/aggregate/distributive/first.cpp | 2 +- src/function/pragma/pragma_queries.cpp | 4 +- .../compress_string.cpp | 2 +- src/function/scalar/list/list_extract.cpp | 4 +- src/function/scalar/list/list_resize.cpp | 4 +- src/function/scalar/list/list_select.cpp | 3 +- src/function/scalar/strftime_format.cpp | 67 ++++++++++--------- src/function/scalar/string/caseconvert.cpp | 8 +-- src/include/duckdb/common/types/bit.hpp | 2 +- 15 files changed, 78 insertions(+), 70 deletions(-) diff --git a/src/core_functions/scalar/string/pad.cpp b/src/core_functions/scalar/string/pad.cpp index 859142293153..856544ea7d93 100644 --- a/src/core_functions/scalar/string/pad.cpp +++ b/src/core_functions/scalar/string/pad.cpp @@ -17,9 +17,9 @@ static pair PadCountChars(const idx_t len, const char *data, const idx_t nchars = 0; for (; nchars < len && nbytes < size; ++nchars) { utf8proc_int32_t codepoint; - auto bytes = utf8proc_iterate(str + nbytes, size - nbytes, &codepoint); + auto bytes = utf8proc_iterate(str + nbytes, UnsafeNumericCast(size - nbytes), &codepoint); D_ASSERT(bytes > 0); - nbytes += bytes; + nbytes += UnsafeNumericCast(bytes); } return pair(nbytes, nchars); @@ -47,9 +47,9 @@ static bool InsertPadding(const idx_t len, const string_t &pad, vector &re // Write the next character utf8proc_int32_t codepoint; - auto bytes = utf8proc_iterate(str + nbytes, size - nbytes, &codepoint); + auto bytes = utf8proc_iterate(str + nbytes, UnsafeNumericCast(size - nbytes), &codepoint); D_ASSERT(bytes > 0); - nbytes += bytes; + nbytes += UnsafeNumericCast(bytes); } // Flush the remaining pad @@ -67,10 +67,10 @@ static string_t LeftPadFunction(const string_t &str, const int32_t len, const st auto size_str = str.GetSize(); // Count how much of str will fit in the output - auto written = PadCountChars(len, data_str, size_str); + auto written = PadCountChars(UnsafeNumericCast(len), data_str, size_str); // Left pad by the number of characters still needed - if (!InsertPadding(len - written.second, pad, result)) { + if (!InsertPadding(UnsafeNumericCast(len) - written.second, pad, result)) { throw InvalidInputException("Insufficient padding in LPAD."); } @@ -96,13 +96,13 @@ static string_t RightPadFunction(const string_t &str, const int32_t len, const s auto size_str = str.GetSize(); // Count how much of str will fit in the output - auto written = PadCountChars(len, data_str, size_str); + auto written = PadCountChars(UnsafeNumericCast(len), data_str, size_str); // Append as much of the original string as fits result.insert(result.end(), data_str, data_str + written.first); // Right pad by the number of characters still 
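// as in LeftPadFunction above, the int32_t target length is widened through
// an explicit UnsafeNumericCast before the unsigned subtraction, replacing a
// conversion the compiler previously performed implicitly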
needed - if (!InsertPadding(len - written.second, pad, result)) { + if (!InsertPadding(UnsafeNumericCast(len) - written.second, pad, result)) { throw InvalidInputException("Insufficient padding in RPAD."); }; diff --git a/src/core_functions/scalar/string/string_split.cpp b/src/core_functions/scalar/string/string_split.cpp index 9d3ea23be52c..c62cacd753e5 100644 --- a/src/core_functions/scalar/string/string_split.cpp +++ b/src/core_functions/scalar/string/string_split.cpp @@ -51,7 +51,7 @@ struct ConstantRegexpStringSplit { return DConstants::INVALID_INDEX; } match_size = match.size(); - return match.data() - input_data; + return UnsafeNumericCast(match.data() - input_data); } }; diff --git a/src/core_functions/scalar/string/to_base.cpp b/src/core_functions/scalar/string/to_base.cpp index ad5e1088ef24..963f4f562f42 100644 --- a/src/core_functions/scalar/string/to_base.cpp +++ b/src/core_functions/scalar/string/to_base.cpp @@ -48,7 +48,7 @@ static void ToBaseFunction(DataChunk &args, ExpressionState &state, Vector &resu length++; } - return StringVector::AddString(result, ptr, end - ptr); + return StringVector::AddString(result, ptr, UnsafeNumericCast(end - ptr)); }); } diff --git a/src/core_functions/scalar/string/translate.cpp b/src/core_functions/scalar/string/translate.cpp index 192438697b74..c01ec2a25093 100644 --- a/src/core_functions/scalar/string/translate.cpp +++ b/src/core_functions/scalar/string/translate.cpp @@ -37,10 +37,10 @@ static string_t TranslateScalarFunction(const string_t &haystack, const string_t while (i < size_needle && j < size_thread) { auto codepoint_needle = Utf8Proc::UTF8ToCodepoint(input_needle, sz); input_needle += sz; - i += sz; + i += UnsafeNumericCast(sz); auto codepoint_thread = Utf8Proc::UTF8ToCodepoint(input_thread, sz); input_thread += sz; - j += sz; + j += UnsafeNumericCast(sz); // Ignore unicode character that is existed in to_replace if (to_replace.count(codepoint_needle) == 0) { to_replace[codepoint_needle] = codepoint_thread; @@ -52,7 +52,7 @@ static string_t TranslateScalarFunction(const string_t &haystack, const string_t while (i < size_needle) { auto codepoint_needle = Utf8Proc::UTF8ToCodepoint(input_needle, sz); input_needle += sz; - i += sz; + i += UnsafeNumericCast(sz); // Add unicode character that will be deleted if (to_replace.count(codepoint_needle) == 0) { to_delete.insert(codepoint_needle); @@ -60,7 +60,7 @@ static string_t TranslateScalarFunction(const string_t &haystack, const string_t } char c[5] = {'\0', '\0', '\0', '\0', '\0'}; - for (i = 0; i < size_haystack; i += sz) { + for (i = 0; i < size_haystack; i += UnsafeNumericCast(sz)) { auto codepoint_haystack = Utf8Proc::UTF8ToCodepoint(input_haystack, sz); if (to_replace.count(codepoint_haystack) != 0) { Utf8Proc::CodepointToUtf8(to_replace[codepoint_haystack], c_sz, c); diff --git a/src/core_functions/scalar/string/trim.cpp b/src/core_functions/scalar/string/trim.cpp index 91e3b5dd1970..d89ebbaffb6d 100644 --- a/src/core_functions/scalar/string/trim.cpp +++ b/src/core_functions/scalar/string/trim.cpp @@ -23,12 +23,13 @@ struct TrimOperator { idx_t begin = 0; if (LTRIM) { while (begin < size) { - auto bytes = utf8proc_iterate(str + begin, size - begin, &codepoint); + auto bytes = + utf8proc_iterate(str + begin, UnsafeNumericCast(size - begin), &codepoint); D_ASSERT(bytes > 0); if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { break; } - begin += bytes; + begin += UnsafeNumericCast(bytes); } } @@ -37,9 +38,9 @@ struct TrimOperator { if (RTRIM) { end = begin; for (auto next = 
begin; next < size;) { - auto bytes = utf8proc_iterate(str + next, size - next, &codepoint); + auto bytes = utf8proc_iterate(str + next, UnsafeNumericCast(size - next), &codepoint); D_ASSERT(bytes > 0); - next += bytes; + next += UnsafeNumericCast(bytes); if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { end = next; } @@ -69,7 +70,8 @@ static void GetIgnoredCodepoints(string_t ignored, unordered_set( + utf8proc_iterate(dataptr + pos, UnsafeNumericCast(size - pos), &codepoint)); ignored_codepoints.insert(codepoint); } } @@ -91,11 +93,12 @@ static void BinaryTrimFunction(DataChunk &input, ExpressionState &state, Vector idx_t begin = 0; if (LTRIM) { while (begin < size) { - auto bytes = utf8proc_iterate(str + begin, size - begin, &codepoint); + auto bytes = + utf8proc_iterate(str + begin, UnsafeNumericCast(size - begin), &codepoint); if (ignored_codepoints.find(codepoint) == ignored_codepoints.end()) { break; } - begin += bytes; + begin += UnsafeNumericCast(bytes); } } @@ -104,9 +107,10 @@ static void BinaryTrimFunction(DataChunk &input, ExpressionState &state, Vector if (RTRIM) { end = begin; for (auto next = begin; next < size;) { - auto bytes = utf8proc_iterate(str + next, size - next, &codepoint); + auto bytes = + utf8proc_iterate(str + next, UnsafeNumericCast(size - next), &codepoint); D_ASSERT(bytes > 0); - next += bytes; + next += UnsafeNumericCast(bytes); if (ignored_codepoints.find(codepoint) == ignored_codepoints.end()) { end = next; } diff --git a/src/core_functions/scalar/string/unicode.cpp b/src/core_functions/scalar/string/unicode.cpp index b621c53202f1..b62a129aad61 100644 --- a/src/core_functions/scalar/string/unicode.cpp +++ b/src/core_functions/scalar/string/unicode.cpp @@ -15,7 +15,7 @@ struct UnicodeOperator { auto str = reinterpret_cast(input.GetData()); auto len = input.GetSize(); utf8proc_int32_t codepoint; - (void)utf8proc_iterate(str, len, &codepoint); + (void)utf8proc_iterate(str, UnsafeNumericCast(len), &codepoint); return codepoint; } }; diff --git a/src/function/aggregate/distributive/first.cpp b/src/function/aggregate/distributive/first.cpp index 2d2aa0530395..143cf4317573 100644 --- a/src/function/aggregate/distributive/first.cpp +++ b/src/function/aggregate/distributive/first.cpp @@ -88,7 +88,7 @@ struct FirstFunctionString : public FirstFunctionBase { auto ptr = LAST ? 
 			auto ptr = LAST ? new char[len] : char_ptr_cast(input_data.allocator.Allocate(len));
 			memcpy(ptr, value.GetData(), len);
-			state.value = string_t(ptr, UnsafeNumericCast(len));
+			state.value = string_t(ptr, UnsafeNumericCast(len));
 		}
 	}
 }
diff --git a/src/function/pragma/pragma_queries.cpp b/src/function/pragma/pragma_queries.cpp
index a69d735bc826..d1033bbc55d2 100644
--- a/src/function/pragma/pragma_queries.cpp
+++ b/src/function/pragma/pragma_queries.cpp
@@ -141,9 +141,9 @@ string PragmaImportDatabase(ClientContext &context, const FunctionParameters &pa
 		auto file_path = fs.JoinPath(parameters.values[0].ToString(), file);
 		auto handle = fs.OpenFile(file_path, FileFlags::FILE_FLAGS_READ);
 		auto fsize = fs.GetFileSize(*handle);
-		auto buffer = make_unsafe_uniq_array(fsize);
+		auto buffer = make_unsafe_uniq_array(UnsafeNumericCast(fsize));
 		fs.Read(*handle, buffer.get(), fsize);
-		auto query = string(buffer.get(), fsize);
+		auto query = string(buffer.get(), UnsafeNumericCast(fsize));
 		// Replace the placeholder with the path provided to IMPORT
 		if (file == "load.sql") {
 			Parser parser;
diff --git a/src/function/scalar/compressed_materialization/compress_string.cpp b/src/function/scalar/compressed_materialization/compress_string.cpp
index 1125322fcb6e..a8e046b60bcf 100644
--- a/src/function/scalar/compressed_materialization/compress_string.cpp
+++ b/src/function/scalar/compressed_materialization/compress_string.cpp
@@ -39,7 +39,7 @@ static inline RESULT_TYPE StringCompressInternal(const string_t &input) {
 		ReverseMemCpy(result_ptr + remainder, data_ptr_cast(input.GetPointer()), input.GetSize());
 		memset(result_ptr, '\0', remainder);
 	}
-	result_ptr[0] = UnsafeNumericCast(input.GetSize());
+	result_ptr[0] = UnsafeNumericCast(input.GetSize());
 	return result;
 }
diff --git a/src/function/scalar/list/list_extract.cpp b/src/function/scalar/list/list_extract.cpp
index 822b9079cdc9..e1566641aafb 100644
--- a/src/function/scalar/list/list_extract.cpp
+++ b/src/function/scalar/list/list_extract.cpp
@@ -62,13 +62,13 @@ void ListExtractTemplate(idx_t count, UnifiedVectorFormat &list_data, UnifiedVec
 				result_mask.SetInvalid(i);
 				continue;
 			}
-			child_offset = list_entry.offset + list_entry.length + offsets_entry;
+			child_offset = list_entry.offset + list_entry.length + UnsafeNumericCast(offsets_entry);
 		} else {
 			if ((idx_t)offsets_entry >= list_entry.length) {
 				result_mask.SetInvalid(i);
 				continue;
 			}
-			child_offset = list_entry.offset + offsets_entry;
+			child_offset = list_entry.offset + UnsafeNumericCast(offsets_entry);
 		}
 		auto child_index = child_format.sel->get_index(child_offset);
 		if (child_format.validity.RowIsValid(child_index)) {
diff --git a/src/function/scalar/list/list_resize.cpp b/src/function/scalar/list/list_resize.cpp
index edd6ca29108b..04f494139cd4 100644
--- a/src/function/scalar/list/list_resize.cpp
+++ b/src/function/scalar/list/list_resize.cpp
@@ -38,7 +38,7 @@ void ListResizeFunction(DataChunk &args, ExpressionState &state, Vector &result)
 	for (idx_t i = 0; i < count; i++) {
 		auto index = new_size_data.sel->get_index(i);
 		if (new_size_data.validity.RowIsValid(index)) {
-			new_child_size += new_size_entries[index];
+			new_child_size += UnsafeNumericCast(new_size_entries[index]);
 		}
 	}
@@ -72,7 +72,7 @@ void ListResizeFunction(DataChunk &args, ExpressionState &state, Vector &result)
 		idx_t new_size_entry = 0;
 		if (new_size_data.validity.RowIsValid(new_index)) {
-			new_size_entry = new_size_entries[new_index];
+			new_size_entry = UnsafeNumericCast(new_size_entries[new_index]);
 		}
 		// find the smallest size between lists and new_sizes
diff --git a/src/function/scalar/list/list_select.cpp b/src/function/scalar/list/list_select.cpp
index 1878f808ad30..9d45de5a1ecc 100644
--- a/src/function/scalar/list/list_select.cpp
+++ b/src/function/scalar/list/list_select.cpp
@@ -11,7 +11,8 @@ struct SetSelectionVectorSelect {
 	                               ValidityMask &input_validity, Vector &selection_entry, idx_t child_idx,
 	                               idx_t &target_offset, idx_t selection_offset, idx_t input_offset,
 	                               idx_t target_length) {
-		idx_t sel_idx = selection_entry.GetValue(selection_offset + child_idx).GetValue() - 1;
+		auto sel_idx =
+		    UnsafeNumericCast(selection_entry.GetValue(selection_offset + child_idx).GetValue() - 1);
 		if (sel_idx < target_length) {
 			selection_vector.set_index(target_offset, input_offset + sel_idx);
 			if (!input_validity.RowIsValid(input_offset + sel_idx)) {
diff --git a/src/function/scalar/strftime_format.cpp b/src/function/scalar/strftime_format.cpp
index f50f3dac4b7f..9a948287c3a6 100644
--- a/src/function/scalar/strftime_format.cpp
+++ b/src/function/scalar/strftime_format.cpp
@@ -80,7 +80,7 @@ idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date
 		if (0 <= year && year <= 9999) {
 			return 4;
 		} else {
-			return NumericHelper::SignedLength(year);
+			return UnsafeNumericCast(NumericHelper::SignedLength(year));
 		}
 	}
 	case StrTimeSpecifier::MONTH_DECIMAL: {
@@ -129,11 +129,14 @@ idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date
 		return len;
 	}
 	case StrTimeSpecifier::DAY_OF_MONTH:
-		return NumericHelper::UnsignedLength(Date::ExtractDay(date));
+		return UnsafeNumericCast(
+		    NumericHelper::UnsignedLength(UnsafeNumericCast(Date::ExtractDay(date))));
 	case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL:
-		return NumericHelper::UnsignedLength(Date::ExtractDayOfTheYear(date));
+		return UnsafeNumericCast(
+		    NumericHelper::UnsignedLength(UnsafeNumericCast(Date::ExtractDayOfTheYear(date))));
 	case StrTimeSpecifier::YEAR_WITHOUT_CENTURY:
-		return NumericHelper::UnsignedLength(AbsValue(Date::ExtractYear(date)) % 100);
+		return UnsafeNumericCast(NumericHelper::UnsignedLength(
+		    AbsValue(UnsafeNumericCast(Date::ExtractYear(date)) % 100)));
 	default:
 		throw InternalException("Unimplemented specifier for GetSpecifierLength");
 	}
@@ -195,13 +198,13 @@ char *StrfTimeFormat::WritePadded(char *target, uint32_t value, size_t padding)
 	D_ASSERT(padding > 1);
 	if (padding % 2) {
 		int decimals = value % 1000;
-		WritePadded3(target + padding - 3, decimals);
+		WritePadded3(target + padding - 3, UnsafeNumericCast(decimals));
 		value /= 1000;
 		padding -= 3;
 	}
 	for (size_t i = 0; i < padding / 2; i++) {
 		int decimals = value % 100;
-		WritePadded2(target + padding - 2 * (i + 1), decimals);
+		WritePadded2(target + padding - 2 * (i + 1), UnsafeNumericCast(decimals));
 		value /= 100;
 	}
 	return target + padding;
}
@@ -245,26 +248,26 @@ char *StrfTimeFormat::WriteDateSpecifier(StrTimeSpecifier specifier, date_t date
 	}
 	case StrTimeSpecifier::DAY_OF_YEAR_PADDED: {
 		int32_t doy = Date::ExtractDayOfTheYear(date);
-		target = WritePadded3(target, doy);
+		target = WritePadded3(target, UnsafeNumericCast(doy));
 		break;
 	}
 	case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST:
-		target = WritePadded2(target, Date::ExtractWeekNumberRegular(date, true));
+		target = WritePadded2(target, UnsafeNumericCast(Date::ExtractWeekNumberRegular(date, true)));
 		break;
 	case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST:
-		target = WritePadded2(target, Date::ExtractWeekNumberRegular(date, false));
+		target = WritePadded2(target, UnsafeNumericCast(Date::ExtractWeekNumberRegular(date, false)));
 		break;
 	case
 StrTimeSpecifier::WEEK_NUMBER_ISO:
-		target = WritePadded2(target, Date::ExtractISOWeekNumber(date));
+		target = WritePadded2(target, UnsafeNumericCast(Date::ExtractISOWeekNumber(date)));
 		break;
 	case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: {
-		uint32_t doy = Date::ExtractDayOfTheYear(date);
+		auto doy = UnsafeNumericCast(Date::ExtractDayOfTheYear(date));
 		target += NumericHelper::UnsignedLength(doy);
 		NumericHelper::FormatUnsigned(doy, target);
 		break;
 	}
 	case StrTimeSpecifier::YEAR_ISO:
-		target = WritePadded(target, Date::ExtractISOYearNumber(date), 4);
+		target = WritePadded(target, UnsafeNumericCast(Date::ExtractISOYearNumber(date)), 4);
 		break;
 	case StrTimeSpecifier::WEEKDAY_ISO:
 		*target = char('0' + uint8_t(Date::ExtractISODayOfTheWeek(date)));
@@ -281,7 +284,7 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t
 	// data contains [0] year, [1] month, [2] day, [3] hour, [4] minute, [5] second, [6] msec, [7] utc
 	switch (specifier) {
 	case StrTimeSpecifier::DAY_OF_MONTH_PADDED:
-		target = WritePadded2(target, data[2]);
+		target = WritePadded2(target, UnsafeNumericCast(data[2]));
 		break;
 	case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: {
 		auto &month_name = Date::MONTH_NAMES_ABBREVIATED[data[1] - 1];
@@ -292,14 +295,14 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t
 		return WriteString(target, month_name);
 	}
 	case StrTimeSpecifier::MONTH_DECIMAL_PADDED:
-		target = WritePadded2(target, data[1]);
+		target = WritePadded2(target, UnsafeNumericCast(data[1]));
 		break;
 	case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED:
-		target = WritePadded2(target, AbsValue(data[0]) % 100);
+		target = WritePadded2(target, UnsafeNumericCast(AbsValue(data[0]) % 100));
 		break;
 	case StrTimeSpecifier::YEAR_DECIMAL:
 		if (data[0] >= 0 && data[0] <= 9999) {
-			target = WritePadded(target, data[0], 4);
+			target = WritePadded(target, UnsafeNumericCast(data[0]), 4);
 		} else {
 			int32_t year = data[0];
 			if (data[0] < 0) {
@@ -307,13 +310,13 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t
 				year = -year;
 				target++;
 			}
-			auto len = NumericHelper::UnsignedLength(year);
+			auto len = NumericHelper::UnsignedLength(UnsafeNumericCast(year));
 			NumericHelper::FormatUnsigned(year, target + len);
 			target += len;
 		}
 		break;
 	case StrTimeSpecifier::HOUR_24_PADDED: {
-		target = WritePadded2(target, data[3]);
+		target = WritePadded2(target, UnsafeNumericCast(data[3]));
 		break;
 	}
 	case StrTimeSpecifier::HOUR_12_PADDED: {
@@ -321,7 +324,7 @@
 		if (hour == 0) {
 			hour = 12;
 		}
-		target = WritePadded2(target, hour);
+		target = WritePadded2(target, UnsafeNumericCast(hour));
 		break;
 	}
 	case StrTimeSpecifier::AM_PM:
@@ -329,20 +332,20 @@
 		*target++ = 'M';
 		break;
 	case StrTimeSpecifier::MINUTE_PADDED: {
-		target = WritePadded2(target, data[4]);
+		target = WritePadded2(target, UnsafeNumericCast(data[4]));
 		break;
 	}
 	case StrTimeSpecifier::SECOND_PADDED:
-		target = WritePadded2(target, data[5]);
+		target = WritePadded2(target, UnsafeNumericCast(data[5]));
 		break;
 	case StrTimeSpecifier::NANOSECOND_PADDED:
-		target = WritePadded(target, data[6] * Interval::NANOS_PER_MICRO, 9);
+		target = WritePadded(target, UnsafeNumericCast(data[6] * Interval::NANOS_PER_MICRO), 9);
 		break;
 	case StrTimeSpecifier::MICROSECOND_PADDED:
-		target = WritePadded(target, data[6], 6);
+		target = WritePadded(target, UnsafeNumericCast(data[6]), 6);
 		break;
 	case StrTimeSpecifier::MILLISECOND_PADDED:
-		target = WritePadded3(target, data[6] / Interval::MICROS_PER_MSEC);
+		target = WritePadded3(target, UnsafeNumericCast(data[6] / Interval::MICROS_PER_MSEC));
 		break;
 	case StrTimeSpecifier::UTC_OFFSET: {
 		*target++ = (data[7] < 0) ? '-' : '+';
@@ -350,10 +353,10 @@
 		auto offset = abs(data[7]);
 		auto offset_hours = offset / Interval::MINS_PER_HOUR;
 		auto offset_minutes = offset % Interval::MINS_PER_HOUR;
-		target = WritePadded2(target, offset_hours);
+		target = WritePadded2(target, UnsafeNumericCast(offset_hours));
 		if (offset_minutes) {
 			*target++ = ':';
-			target = WritePadded2(target, offset_minutes);
+			target = WritePadded2(target, UnsafeNumericCast(offset_minutes));
 		}
 		break;
 	}
@@ -364,7 +367,7 @@
 		}
 		break;
 	case StrTimeSpecifier::DAY_OF_MONTH: {
-		target = Write2(target, data[2] % 100);
+		target = Write2(target, UnsafeNumericCast(data[2] % 100));
 		break;
 	}
 	case StrTimeSpecifier::MONTH_DECIMAL: {
@@ -372,7 +375,7 @@
 		break;
 	}
 	case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: {
-		target = Write2(target, AbsValue(data[0]) % 100);
+		target = Write2(target, UnsafeNumericCast(AbsValue(data[0]) % 100));
 		break;
 	}
 	case StrTimeSpecifier::HOUR_24_DECIMAL: {
@@ -845,9 +848,9 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const {
 			// numeric specifier: parse a number
 			uint64_t number = 0;
 			size_t start_pos = pos;
-			size_t end_pos = start_pos + numeric_width[i];
+			size_t end_pos = start_pos + UnsafeNumericCast(numeric_width[i]);
 			while (pos < size && pos < end_pos && StringUtil::CharacterIsDigit(data[pos])) {
-				number = number * 10 + data[pos] - '0';
+				number = number * 10 + UnsafeNumericCast(data[pos]) - '0';
 				pos++;
 			}
 			if (pos == start_pos) {
@@ -1229,7 +1232,7 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const {
 				// But tz must not be empty.
 				if (tz_end == tz_begin) {
 					error_message = "Empty Time Zone name";
-					error_position = tz_begin - data;
+					error_position = UnsafeNumericCast(tz_begin - data);
 					return false;
 				}
 				result.tz.assign(tz_begin, tz_end);
@@ -1288,7 +1291,7 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const {
 		case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: {
 			// Adjust weekday to be 0-based for the week type
 			if (has_weekday) {
-				weekday = (weekday + 7 - int(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST)) % 7;
+				weekday = (weekday + 7 - uint64_t(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST)) % 7;
 			}
 			// Get the start of week 1, move back 7 days and then weekno * 7 + weekday gives the date
 			const auto jan1 = Date::FromDate(result_data[0], 1, 1);
diff --git a/src/function/scalar/string/caseconvert.cpp b/src/function/scalar/string/caseconvert.cpp
index 100ed9765a35..bf54b9ac38e3 100644
--- a/src/function/scalar/string/caseconvert.cpp
+++ b/src/function/scalar/string/caseconvert.cpp
@@ -58,11 +58,11 @@ static idx_t GetResultLength(const char *input_data, idx_t input_length) {
 		if (input_data[i] & 0x80) {
 			// unicode
 			int sz = 0;
-			int codepoint = utf8proc_codepoint(input_data + i, sz);
-			int converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint);
-			int new_sz = utf8proc_codepoint_length(converted_codepoint);
+			auto codepoint = utf8proc_codepoint(input_data + i, sz);
+			auto converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint);
+			auto new_sz = utf8proc_codepoint_length(converted_codepoint);
 			D_ASSERT(new_sz >= 0);
-			output_length += new_sz;
+			output_length += UnsafeNumericCast(new_sz);
 			i += sz;
 		} else {
 			// ascii
diff --git a/src/include/duckdb/common/types/bit.hpp b/src/include/duckdb/common/types/bit.hpp
index 2dd594dd5524..903b3cc75ff9 100644
--- a/src/include/duckdb/common/types/bit.hpp
+++ b/src/include/duckdb/common/types/bit.hpp
@@ -104,7 +104,7 @@ void Bit::NumericToBit(T numeric, string_t &output_str) {
 	*output = 0; // set padding to 0
 	++output;
 	for (idx_t idx = 0; idx < sizeof(T); ++idx) {
-		output[idx] = data[sizeof(T) - idx - 1];
+		output[idx] = UnsafeNumericCast(data[sizeof(T) - idx - 1]);
 	}
 	Bit::Finalize(output_str);
 }

From 4b460b42186f8cea996ce0c7c3e97bf2ff7a8f5c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hannes=20M=C3=BChleisen?=
Date: Wed, 3 Apr 2024 13:44:55 +0200
Subject: [PATCH 040/201] moar

---
 src/catalog/dependency_manager.cpp         | 584 ------------------
 src/catalog/duck_catalog.cpp               | 153 -----
 src/function/scalar/string/caseconvert.cpp |  16 +-
 src/function/scalar/string/length.cpp      |  18 +-
 src/function/scalar/string/substring.cpp   |   2 +-
 .../rule/conjunction_simplification.cpp    |  69 ---
 src/optimizer/rule/constant_folding.cpp    |  42 --
 .../rule/date_part_simplification.cpp      | 103 ---
 src/optimizer/rule/empty_needle_removal.cpp |  53 --
 src/optimizer/rule/enum_comparison.cpp     |  69 ---
 .../rule/in_clause_simplification_rule.cpp |  56 --
 src/optimizer/rule/like_optimizations.cpp  | 161 -----
 src/optimizer/rule/move_constants.cpp      | 164 -----
 .../rule/ordered_aggregate_optimizer.cpp   | 101 ---
 src/planner/CMakeLists.txt                 |  26 -
 15 files changed, 18 insertions(+), 1599 deletions(-)

diff --git a/src/catalog/dependency_manager.cpp b/src/catalog/dependency_manager.cpp
index 918bd3b276b7..8b137891791f 100644
--- a/src/catalog/dependency_manager.cpp
+++ b/src/catalog/dependency_manager.cpp
@@ -1,585 +1 @@
-#include "duckdb/catalog/dependency_manager.hpp"
-#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp"
-#include "duckdb/catalog/duck_catalog.hpp"
-#include "duckdb/catalog/catalog_entry.hpp"
-#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
-#include "duckdb/main/client_context.hpp"
-#include "duckdb/main/database.hpp"
-#include "duckdb/parser/expression/constant_expression.hpp"
-#include "duckdb/catalog/dependency_list.hpp"
-#include "duckdb/common/enums/catalog_type.hpp"
-#include "duckdb/catalog/catalog_entry/dependency/dependency_entry.hpp"
-#include "duckdb/catalog/catalog_entry/dependency/dependency_subject_entry.hpp"
-#include "duckdb/catalog/catalog_entry/dependency/dependency_dependent_entry.hpp"
-#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp"
-#include "duckdb/catalog/dependency_catalog_set.hpp"
-namespace duckdb {
-
-static void AssertMangledName(const string &mangled_name, idx_t expected_null_bytes) {
-#ifdef DEBUG
-	idx_t nullbyte_count = 0;
-	for (auto &ch : mangled_name) {
-		nullbyte_count += ch == '\0';
-	}
-	D_ASSERT(nullbyte_count == expected_null_bytes);
-#endif
-}
-
-MangledEntryName::MangledEntryName(const CatalogEntryInfo &info) {
-	auto &type = info.type;
-	auto &schema = info.schema;
-	auto &name = info.name;
-
-	this->name = CatalogTypeToString(type) + '\0' + schema + '\0' + 
name; - AssertMangledName(this->name, 2); -} - -MangledDependencyName::MangledDependencyName(const MangledEntryName &from, const MangledEntryName &to) { - this->name = from.name + '\0' + to.name; - AssertMangledName(this->name, 5); -} - -DependencyManager::DependencyManager(DuckCatalog &catalog) : catalog(catalog), subjects(catalog), dependents(catalog) { -} - -string DependencyManager::GetSchema(CatalogEntry &entry) { - if (entry.type == CatalogType::SCHEMA_ENTRY) { - return entry.name; - } - return entry.ParentSchema().name; -} - -MangledEntryName DependencyManager::MangleName(const CatalogEntryInfo &info) { - return MangledEntryName(info); -} - -MangledEntryName DependencyManager::MangleName(CatalogEntry &entry) { - if (entry.type == CatalogType::DEPENDENCY_ENTRY) { - auto &dependency_entry = entry.Cast(); - return dependency_entry.EntryMangledName(); - } - auto type = entry.type; - auto schema = GetSchema(entry); - auto name = entry.name; - CatalogEntryInfo info {type, schema, name}; - - return MangleName(info); -} - -DependencyInfo DependencyInfo::FromSubject(DependencyEntry &dep) { - return DependencyInfo {/*dependent = */ dep.Dependent(), - /*subject = */ dep.Subject()}; -} - -DependencyInfo DependencyInfo::FromDependent(DependencyEntry &dep) { - return DependencyInfo {/*dependent = */ dep.Dependent(), - /*subject = */ dep.Subject()}; -} - -// ----------- DEPENDENCY_MANAGER ----------- - -bool DependencyManager::IsSystemEntry(CatalogEntry &entry) const { - if (entry.internal) { - return true; - } - - switch (entry.type) { - case CatalogType::DEPENDENCY_ENTRY: - case CatalogType::DATABASE_ENTRY: - case CatalogType::RENAMED_ENTRY: - return true; - default: - return false; - } -} - -CatalogSet &DependencyManager::Dependents() { - return dependents; -} - -CatalogSet &DependencyManager::Subjects() { - return subjects; -} - -void DependencyManager::ScanSetInternal(CatalogTransaction transaction, const CatalogEntryInfo &info, - bool scan_subjects, dependency_callback_t &callback) { - catalog_entry_set_t other_entries; - - auto cb = [&](CatalogEntry &other) { - D_ASSERT(other.type == CatalogType::DEPENDENCY_ENTRY); - auto &other_entry = other.Cast(); -#ifdef DEBUG - auto side = other_entry.Side(); - if (scan_subjects) { - D_ASSERT(side == DependencyEntryType::SUBJECT); - } else { - D_ASSERT(side == DependencyEntryType::DEPENDENT); - } - -#endif - - other_entries.insert(other_entry); - callback(other_entry); - }; - - if (scan_subjects) { - DependencyCatalogSet subjects(Subjects(), info); - subjects.Scan(transaction, cb); - } else { - DependencyCatalogSet dependents(Dependents(), info); - dependents.Scan(transaction, cb); - } - -#ifdef DEBUG - // Verify some invariants - // Every dependency should have a matching dependent in the other set - // And vice versa - auto mangled_name = MangleName(info); - - if (scan_subjects) { - for (auto &entry : other_entries) { - auto other_info = GetLookupProperties(entry); - DependencyCatalogSet other_dependents(Dependents(), other_info); - - // Verify that the other half of the dependency also exists - auto dependent = other_dependents.GetEntryDetailed(transaction, mangled_name); - D_ASSERT(dependent.reason != CatalogSet::EntryLookup::FailureReason::NOT_PRESENT); - } - } else { - for (auto &entry : other_entries) { - auto other_info = GetLookupProperties(entry); - DependencyCatalogSet other_subjects(Subjects(), other_info); - - // Verify that the other half of the dependent also exists - auto subject = other_subjects.GetEntryDetailed(transaction, 
mangled_name); - D_ASSERT(subject.reason != CatalogSet::EntryLookup::FailureReason::NOT_PRESENT); - } - } -#endif -} - -void DependencyManager::ScanDependents(CatalogTransaction transaction, const CatalogEntryInfo &info, - dependency_callback_t &callback) { - ScanSetInternal(transaction, info, false, callback); -} - -void DependencyManager::ScanSubjects(CatalogTransaction transaction, const CatalogEntryInfo &info, - dependency_callback_t &callback) { - ScanSetInternal(transaction, info, true, callback); -} - -void DependencyManager::RemoveDependency(CatalogTransaction transaction, const DependencyInfo &info) { - auto &dependent = info.dependent; - auto &subject = info.subject; - - // The dependents of the dependency (target) - DependencyCatalogSet dependents(Dependents(), subject.entry); - // The subjects of the dependencies of the dependent - DependencyCatalogSet subjects(Subjects(), dependent.entry); - - auto dependent_mangled = MangledEntryName(dependent.entry); - auto subject_mangled = MangledEntryName(subject.entry); - - auto dependent_p = dependents.GetEntry(transaction, dependent_mangled); - if (dependent_p) { - // 'dependent' is no longer inhibiting the deletion of 'dependency' - dependents.DropEntry(transaction, dependent_mangled, false); - } - auto subject_p = subjects.GetEntry(transaction, subject_mangled); - if (subject_p) { - // 'dependency' is no longer required by 'dependent' - subjects.DropEntry(transaction, subject_mangled, false); - } -} - -void DependencyManager::CreateSubject(CatalogTransaction transaction, const DependencyInfo &info) { - auto &from = info.dependent.entry; - - DependencyCatalogSet set(Subjects(), from); - auto dep = make_uniq_base(catalog, info); - auto entry_name = dep->EntryMangledName(); - - //! Add to the list of objects that 'dependent' has a dependency on - set.CreateEntry(transaction, entry_name, std::move(dep)); -} - -void DependencyManager::CreateDependent(CatalogTransaction transaction, const DependencyInfo &info) { - auto &from = info.subject.entry; - - DependencyCatalogSet set(Dependents(), from); - auto dep = make_uniq_base(catalog, info); - auto entry_name = dep->EntryMangledName(); - - //! 
Add to the list of object that depend on 'subject' - set.CreateEntry(transaction, entry_name, std::move(dep)); -} - -void DependencyManager::CreateDependency(CatalogTransaction transaction, DependencyInfo &info) { - DependencyCatalogSet subjects(Subjects(), info.dependent.entry); - DependencyCatalogSet dependents(Dependents(), info.subject.entry); - - auto subject_mangled = MangleName(info.subject.entry); - auto dependent_mangled = MangleName(info.dependent.entry); - - auto &dependent_flags = info.dependent.flags; - auto &subject_flags = info.subject.flags; - - auto existing_subject = subjects.GetEntry(transaction, subject_mangled); - auto existing_dependent = dependents.GetEntry(transaction, dependent_mangled); - - // Inherit the existing flags and drop the existing entry if present - if (existing_subject) { - auto &existing = existing_subject->Cast(); - auto existing_flags = existing.Subject().flags; - if (existing_flags != subject_flags) { - subject_flags.Apply(existing_flags); - } - subjects.DropEntry(transaction, subject_mangled, false, false); - } - if (existing_dependent) { - auto &existing = existing_dependent->Cast(); - auto existing_flags = existing.Dependent().flags; - if (existing_flags != dependent_flags) { - dependent_flags.Apply(existing_flags); - } - dependents.DropEntry(transaction, dependent_mangled, false, false); - } - - // Create an entry in the dependents map of the object that is the target of the dependency - CreateDependent(transaction, info); - // Create an entry in the subjects map of the object that is targeting another entry - CreateSubject(transaction, info); -} - -void DependencyManager::AddObject(CatalogTransaction transaction, CatalogEntry &object, - const DependencyList &dependencies) { - if (IsSystemEntry(object)) { - // Don't do anything for this - return; - } - - // check for each object in the sources if they were not deleted yet - for (auto &dep : dependencies.set) { - auto &dependency = dep.get(); - if (&dependency.ParentCatalog() != &object.ParentCatalog()) { - throw DependencyException( - "Error adding dependency for object \"%s\" - dependency \"%s\" is in catalog " - "\"%s\", which does not match the catalog \"%s\".\nCross catalog dependencies are not supported.", - object.name, dependency.name, dependency.ParentCatalog().GetName(), object.ParentCatalog().GetName()); - } - if (!dependency.set) { - throw InternalException("Dependency has no set"); - } - auto catalog_entry = dependency.set->GetEntry(transaction, dependency.name); - if (!catalog_entry) { - throw InternalException("Dependency has already been deleted?"); - } - } - - // indexes do not require CASCADE to be dropped, they are simply always dropped along with the table - DependencyDependentFlags dependency_flags; - if (object.type != CatalogType::INDEX_ENTRY) { - // indexes do not require CASCADE to be dropped, they are simply always dropped along with the table - dependency_flags.SetBlocking(); - } - - // add the object to the dependents_map of each object that it depends on - for (auto &dependency : dependencies.set) { - DependencyInfo info { - /*dependent = */ DependencyDependent {GetLookupProperties(object), dependency_flags}, - /*subject = */ DependencySubject {GetLookupProperties(dependency), DependencySubjectFlags()}}; - CreateDependency(transaction, info); - } -} - -static bool CascadeDrop(bool cascade, const DependencyDependentFlags &flags) { - if (cascade) { - return true; - } - if (flags.IsOwnedBy()) { - // We are owned by this object, while it exists we can not be dropped 
without cascade. - return false; - } - return !flags.IsBlocking(); -} - -CatalogEntryInfo DependencyManager::GetLookupProperties(CatalogEntry &entry) { - if (entry.type == CatalogType::DEPENDENCY_ENTRY) { - auto &dependency_entry = entry.Cast(); - return dependency_entry.EntryInfo(); - } else { - auto schema = DependencyManager::GetSchema(entry); - auto &name = entry.name; - auto &type = entry.type; - return CatalogEntryInfo {type, schema, name}; - } -} - -optional_ptr DependencyManager::LookupEntry(CatalogTransaction transaction, CatalogEntry &dependency) { - auto info = GetLookupProperties(dependency); - - auto &type = info.type; - auto &schema = info.schema; - auto &name = info.name; - - // Lookup the schema - auto schema_entry = catalog.GetSchema(transaction, schema, OnEntryNotFound::RETURN_NULL); - if (type == CatalogType::SCHEMA_ENTRY || !schema_entry) { - // This is a schema entry, perform the callback only providing the schema - return reinterpret_cast(schema_entry.get()); - } - auto entry = schema_entry->GetEntry(transaction, type, name); - return entry; -} - -void DependencyManager::CleanupDependencies(CatalogTransaction transaction, CatalogEntry &object) { - // Collect the dependencies - vector to_remove; - - auto info = GetLookupProperties(object); - ScanSubjects(transaction, info, - [&](DependencyEntry &dep) { to_remove.push_back(DependencyInfo::FromSubject(dep)); }); - ScanDependents(transaction, info, - [&](DependencyEntry &dep) { to_remove.push_back(DependencyInfo::FromDependent(dep)); }); - - // Remove the dependency entries - for (auto &dep : to_remove) { - RemoveDependency(transaction, dep); - } -} - -void DependencyManager::DropObject(CatalogTransaction transaction, CatalogEntry &object, bool cascade) { - if (IsSystemEntry(object)) { - // Don't do anything for this - return; - } - - auto info = GetLookupProperties(object); - // Check if there are any entries that block the DROP because they still depend on the object - catalog_entry_set_t to_drop; - ScanDependents(transaction, info, [&](DependencyEntry &dep) { - // It makes no sense to have a schema depend on anything - D_ASSERT(dep.EntryInfo().type != CatalogType::SCHEMA_ENTRY); - auto entry = LookupEntry(transaction, dep); - if (!entry) { - return; - } - - if (!CascadeDrop(cascade, dep.Dependent().flags)) { - // no cascade and there are objects that depend on this object: throw error - throw DependencyException("Cannot drop entry \"%s\" because there are entries that " - "depend on it. 
Use DROP...CASCADE to drop all dependents.", - object.name); - } - to_drop.insert(*entry); - }); - ScanSubjects(transaction, info, [&](DependencyEntry &dep) { - auto flags = dep.Subject().flags; - if (flags.IsOwnership()) { - // We own this object, it should be dropped along with the table - auto entry = LookupEntry(transaction, dep); - to_drop.insert(*entry); - } - }); - - CleanupDependencies(transaction, object); - - for (auto &entry : to_drop) { - auto set = entry.get().set; - D_ASSERT(set); - set->DropEntry(transaction, entry.get().name, cascade); - } -} - -void DependencyManager::AlterObject(CatalogTransaction transaction, CatalogEntry &old_obj, CatalogEntry &new_obj) { - if (IsSystemEntry(new_obj)) { - D_ASSERT(IsSystemEntry(old_obj)); - // Don't do anything for this - return; - } - - auto info = GetLookupProperties(old_obj); - dependency_set_t owned_objects; - ScanDependents(transaction, info, [&](DependencyEntry &dep) { - // It makes no sense to have a schema depend on anything - D_ASSERT(dep.EntryInfo().type != CatalogType::SCHEMA_ENTRY); - - auto entry = LookupEntry(transaction, dep); - if (!entry) { - return; - } - // conflict: attempting to alter this object but the dependent object still exists - // no cascade and there are objects that depend on this object: throw error - throw DependencyException("Cannot alter entry \"%s\" because there are entries that " - "depend on it.", - old_obj.name); - }); - - // Keep old dependencies - dependency_set_t dependents; - ScanSubjects(transaction, info, [&](DependencyEntry &dep) { - auto entry = LookupEntry(transaction, dep); - if (!entry) { - return; - } - if (dep.Subject().flags.IsOwnership()) { - owned_objects.insert(Dependency(*entry, dep.Dependent().flags)); - return; - } - dependents.insert(Dependency(*entry, dep.Dependent().flags)); - }); - - // FIXME: we should update dependencies in the future - // some alters could cause dependencies to change (imagine types of table columns) - // or DEFAULT depending on a sequence - if (StringUtil::CIEquals(old_obj.name, new_obj.name)) { - // The name was not changed, we do not need to recreate the dependency links - return; - } - CleanupDependencies(transaction, old_obj); - - for (auto &dep : dependents) { - auto &other = dep.entry.get(); - DependencyInfo info {/*dependent = */ DependencyDependent {GetLookupProperties(new_obj), dep.flags}, - /*subject = */ DependencySubject {GetLookupProperties(other), DependencySubjectFlags()}}; - CreateDependency(transaction, info); - } - - // For all the objects we own, re establish the dependency of the owner on the object - for (auto &object : owned_objects) { - auto &entry = object.entry.get(); - { - DependencyInfo info { - /*dependent = */ DependencyDependent {GetLookupProperties(new_obj), - DependencyDependentFlags().SetOwnedBy()}, - /*subject = */ DependencySubject {GetLookupProperties(entry), DependencySubjectFlags().SetOwnership()}}; - CreateDependency(transaction, info); - } - } -} - -void DependencyManager::Scan( - ClientContext &context, - const std::function &callback) { - lock_guard write_lock(catalog.GetWriteLock()); - auto transaction = catalog.GetCatalogTransaction(context); - - // All the objects registered in the dependency manager - catalog_entry_set_t entries; - dependents.Scan(transaction, [&](CatalogEntry &set) { - auto entry = LookupEntry(transaction, set); - entries.insert(*entry); - }); - - // For every registered entry, get the dependents - for (auto &entry : entries) { - auto entry_info = GetLookupProperties(entry); - // Scan all 
the dependents of the entry - ScanDependents(transaction, entry_info, [&](DependencyEntry &dependent) { - auto dep = LookupEntry(transaction, dependent); - if (!dep) { - return; - } - auto &dependent_entry = *dep; - callback(entry, dependent_entry, dependent.Dependent().flags); - }); - } -} - -void DependencyManager::AddOwnership(CatalogTransaction transaction, CatalogEntry &owner, CatalogEntry &entry) { - if (IsSystemEntry(entry) || IsSystemEntry(owner)) { - return; - } - - // If the owner is already owned by something else, throw an error - auto owner_info = GetLookupProperties(owner); - ScanDependents(transaction, owner_info, [&](DependencyEntry &dep) { - if (dep.Dependent().flags.IsOwnedBy()) { - throw DependencyException("%s can not become the owner, it is already owned by %s", owner.name, - dep.EntryInfo().name); - } - }); - - // If the entry is the owner of another entry, throw an error - auto entry_info = GetLookupProperties(entry); - ScanSubjects(transaction, entry_info, [&](DependencyEntry &other) { - auto dependent_entry = LookupEntry(transaction, other); - if (!dependent_entry) { - return; - } - auto &dep = *dependent_entry; - - auto flags = other.Dependent().flags; - if (!flags.IsOwnedBy()) { - return; - } - throw DependencyException("%s already owns %s. Cannot have circular dependencies", entry.name, dep.name); - }); - - // If the entry is already owned, throw an error - ScanDependents(transaction, entry_info, [&](DependencyEntry &other) { - auto dependent_entry = LookupEntry(transaction, other); - if (!dependent_entry) { - return; - } - - auto &dep = *dependent_entry; - auto flags = other.Subject().flags; - if (!flags.IsOwnership()) { - return; - } - if (&dep != &owner) { - throw DependencyException("%s is already owned by %s", entry.name, dep.name); - } - }); - - DependencyInfo info { - /*dependent = */ DependencyDependent {GetLookupProperties(owner), DependencyDependentFlags().SetOwnedBy()}, - /*subject = */ DependencySubject {GetLookupProperties(entry), DependencySubjectFlags().SetOwnership()}}; - CreateDependency(transaction, info); -} - -static string FormatString(const MangledEntryName &mangled) { - auto input = mangled.name; - for (size_t i = 0; i < input.size(); i++) { - if (input[i] == '\0') { - input[i] = '_'; - } - } - return input; -} - -void DependencyManager::PrintSubjects(CatalogTransaction transaction, const CatalogEntryInfo &info) { - auto name = MangleName(info); - Printer::Print(StringUtil::Format("Subjects of %s", FormatString(name))); - auto subjects = DependencyCatalogSet(Subjects(), info); - subjects.Scan(transaction, [&](CatalogEntry &dependency) { - auto &dep = dependency.Cast(); - auto &entry_info = dep.EntryInfo(); - auto type = entry_info.type; - auto schema = entry_info.schema; - auto name = entry_info.name; - Printer::Print(StringUtil::Format("Schema: %s | Name: %s | Type: %s | Dependent type: %s | Subject type: %s", - schema, name, CatalogTypeToString(type), dep.Dependent().flags.ToString(), - dep.Subject().flags.ToString())); - }); -} - -void DependencyManager::PrintDependents(CatalogTransaction transaction, const CatalogEntryInfo &info) { - auto name = MangleName(info); - Printer::Print(StringUtil::Format("Dependents of %s", FormatString(name))); - auto dependents = DependencyCatalogSet(Dependents(), info); - dependents.Scan(transaction, [&](CatalogEntry &dependent) { - auto &dep = dependent.Cast(); - auto &entry_info = dep.EntryInfo(); - auto type = entry_info.type; - auto schema = entry_info.schema; - auto name = entry_info.name; - 
Printer::Print(StringUtil::Format("Schema: %s | Name: %s | Type: %s | Dependent type: %s | Subject type: %s", - schema, name, CatalogTypeToString(type), dep.Dependent().flags.ToString(), - dep.Subject().flags.ToString())); - }); -} - -} // namespace duckdb diff --git a/src/catalog/duck_catalog.cpp b/src/catalog/duck_catalog.cpp index 8f712b843ed3..8b137891791f 100644 --- a/src/catalog/duck_catalog.cpp +++ b/src/catalog/duck_catalog.cpp @@ -1,154 +1 @@ -#include "duckdb/catalog/duck_catalog.hpp" -#include "duckdb/catalog/dependency_manager.hpp" -#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp" -#include "duckdb/storage/storage_manager.hpp" -#include "duckdb/parser/parsed_data/drop_info.hpp" -#include "duckdb/parser/parsed_data/create_schema_info.hpp" -#include "duckdb/catalog/default/default_schemas.hpp" -#include "duckdb/function/built_in_functions.hpp" -#include "duckdb/main/attached_database.hpp" -#ifndef DISABLE_CORE_FUNCTIONS_EXTENSION -#include "duckdb/core_functions/core_functions.hpp" -#endif -namespace duckdb { - -DuckCatalog::DuckCatalog(AttachedDatabase &db) - : Catalog(db), dependency_manager(make_uniq(*this)), - schemas(make_uniq(*this, make_uniq(*this))) { -} - -DuckCatalog::~DuckCatalog() { -} - -void DuckCatalog::Initialize(bool load_builtin) { - // first initialize the base system catalogs - // these are never written to the WAL - // we start these at 1 because deleted entries default to 0 - auto data = CatalogTransaction::GetSystemTransaction(GetDatabase()); - - // create the default schema - CreateSchemaInfo info; - info.schema = DEFAULT_SCHEMA; - info.internal = true; - CreateSchema(data, info); - - if (load_builtin) { - // initialize default functions - BuiltinFunctions builtin(data, *this); - builtin.Initialize(); - -#ifndef DISABLE_CORE_FUNCTIONS_EXTENSION - CoreFunctions::RegisterFunctions(*this, data); -#endif - } - - Verify(); -} - -bool DuckCatalog::IsDuckCatalog() { - return true; -} - -//===--------------------------------------------------------------------===// -// Schema -//===--------------------------------------------------------------------===// -optional_ptr DuckCatalog::CreateSchemaInternal(CatalogTransaction transaction, CreateSchemaInfo &info) { - DependencyList dependencies; - auto entry = make_uniq(*this, info); - auto result = entry.get(); - if (!schemas->CreateEntry(transaction, info.schema, std::move(entry), dependencies)) { - return nullptr; - } - return (CatalogEntry *)result; -} - -optional_ptr DuckCatalog::CreateSchema(CatalogTransaction transaction, CreateSchemaInfo &info) { - D_ASSERT(!info.schema.empty()); - auto result = CreateSchemaInternal(transaction, info); - if (!result) { - switch (info.on_conflict) { - case OnCreateConflict::ERROR_ON_CONFLICT: - throw CatalogException::EntryAlreadyExists(CatalogType::SCHEMA_ENTRY, info.schema); - case OnCreateConflict::REPLACE_ON_CONFLICT: { - DropInfo drop_info; - drop_info.type = CatalogType::SCHEMA_ENTRY; - drop_info.catalog = info.catalog; - drop_info.name = info.schema; - DropSchema(transaction, drop_info); - result = CreateSchemaInternal(transaction, info); - if (!result) { - throw InternalException("Failed to create schema entry in CREATE_OR_REPLACE"); - } - break; - } - case OnCreateConflict::IGNORE_ON_CONFLICT: - break; - default: - throw InternalException("Unsupported OnCreateConflict for CreateSchema"); - } - return nullptr; - } - return result; -} - -void DuckCatalog::DropSchema(CatalogTransaction transaction, DropInfo &info) { - D_ASSERT(!info.name.empty()); - 
ModifyCatalog(); - if (!schemas->DropEntry(transaction, info.name, info.cascade)) { - if (info.if_not_found == OnEntryNotFound::THROW_EXCEPTION) { - throw CatalogException::MissingEntry(CatalogType::SCHEMA_ENTRY, info.name, string()); - } - } -} - -void DuckCatalog::DropSchema(ClientContext &context, DropInfo &info) { - DropSchema(GetCatalogTransaction(context), info); -} - -void DuckCatalog::ScanSchemas(ClientContext &context, std::function callback) { - schemas->Scan(GetCatalogTransaction(context), - [&](CatalogEntry &entry) { callback(entry.Cast()); }); -} - -void DuckCatalog::ScanSchemas(std::function callback) { - schemas->Scan([&](CatalogEntry &entry) { callback(entry.Cast()); }); -} - -optional_ptr DuckCatalog::GetSchema(CatalogTransaction transaction, const string &schema_name, - OnEntryNotFound if_not_found, QueryErrorContext error_context) { - D_ASSERT(!schema_name.empty()); - auto entry = schemas->GetEntry(transaction, schema_name); - if (!entry) { - if (if_not_found == OnEntryNotFound::THROW_EXCEPTION) { - throw CatalogException(error_context, "Schema with name %s does not exist!", schema_name); - } - return nullptr; - } - return &entry->Cast(); -} - -DatabaseSize DuckCatalog::GetDatabaseSize(ClientContext &context) { - return db.GetStorageManager().GetDatabaseSize(); -} - -vector DuckCatalog::GetMetadataInfo(ClientContext &context) { - return db.GetStorageManager().GetMetadataInfo(); -} - -bool DuckCatalog::InMemory() { - return db.GetStorageManager().InMemory(); -} - -string DuckCatalog::GetDBPath() { - return db.GetStorageManager().GetDBPath(); -} - -void DuckCatalog::Verify() { -#ifdef DEBUG - Catalog::Verify(); - schemas->Verify(*this); -#endif -} - -} // namespace duckdb diff --git a/src/function/scalar/string/caseconvert.cpp b/src/function/scalar/string/caseconvert.cpp index bf54b9ac38e3..fa5b612f1860 100644 --- a/src/function/scalar/string/caseconvert.cpp +++ b/src/function/scalar/string/caseconvert.cpp @@ -44,8 +44,8 @@ static string_t ASCIICaseConvert(Vector &result, const char *input_data, idx_t i auto result_str = StringVector::EmptyString(result, output_length); auto result_data = result_str.GetDataWriteable(); for (idx_t i = 0; i < input_length; i++) { - result_data[i] = IS_UPPER ? UpperFun::ASCII_TO_UPPER_MAP[uint8_t(input_data[i])] - : LowerFun::ASCII_TO_LOWER_MAP[uint8_t(input_data[i])]; + result_data[i] = UnsafeNumericCast(IS_UPPER ? UpperFun::ASCII_TO_UPPER_MAP[uint8_t(input_data[i])] + : LowerFun::ASCII_TO_LOWER_MAP[uint8_t(input_data[i])]); } result_str.Finalize(); return result_str; @@ -63,7 +63,7 @@ static idx_t GetResultLength(const char *input_data, idx_t input_length) { auto new_sz = utf8proc_codepoint_length(converted_codepoint); D_ASSERT(new_sz >= 0); output_length += UnsafeNumericCast(new_sz); - i += sz; + i += UnsafeNumericCast(sz); } else { // ascii output_length++; @@ -79,17 +79,17 @@ static void CaseConvert(const char *input_data, idx_t input_length, char *result if (input_data[i] & 0x80) { // non-ascii character int sz = 0, new_sz = 0; - int codepoint = utf8proc_codepoint(input_data + i, sz); - int converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint); + auto codepoint = utf8proc_codepoint(input_data + i, sz); + auto converted_codepoint = IS_UPPER ? 
utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint); auto success = utf8proc_codepoint_to_utf8(converted_codepoint, new_sz, result_data); D_ASSERT(success); (void)success; result_data += new_sz; - i += sz; + i += UnsafeNumericCast(sz); } else { // ascii - *result_data = IS_UPPER ? UpperFun::ASCII_TO_UPPER_MAP[uint8_t(input_data[i])] - : LowerFun::ASCII_TO_LOWER_MAP[uint8_t(input_data[i])]; + *result_data = UnsafeNumericCast(IS_UPPER ? UpperFun::ASCII_TO_UPPER_MAP[uint8_t(input_data[i])] + : LowerFun::ASCII_TO_LOWER_MAP[uint8_t(input_data[i])]); result_data++; i++; } diff --git a/src/function/scalar/string/length.cpp b/src/function/scalar/string/length.cpp index 218e4e84626d..129105bee80a 100644 --- a/src/function/scalar/string/length.cpp +++ b/src/function/scalar/string/length.cpp @@ -29,14 +29,14 @@ struct GraphemeCountOperator { struct StrLenOperator { template static inline TR Operation(TA input) { - return input.GetSize(); + return UnsafeNumericCast(input.GetSize()); } }; struct OctetLenOperator { template static inline TR Operation(TA input) { - return Bit::OctetLength(input); + return UnsafeNumericCast(Bit::OctetLength(input)); } }; @@ -44,7 +44,7 @@ struct OctetLenOperator { struct BitLenOperator { template static inline TR Operation(TA input) { - return 8 * input.GetSize(); + return UnsafeNumericCast(8 * input.GetSize()); } }; @@ -52,7 +52,7 @@ struct BitLenOperator { struct BitStringLenOperator { template static inline TR Operation(TA input) { - return Bit::BitLength(input); + return UnsafeNumericCast(Bit::BitLength(input)); } }; @@ -74,8 +74,8 @@ static unique_ptr LengthPropagateStats(ClientContext &context, F static void ListLengthFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &input = args.data[0]; D_ASSERT(input.GetType().id() == LogicalTypeId::LIST); - UnaryExecutor::Execute(input, result, args.size(), - [](list_entry_t input) { return input.length; }); + UnaryExecutor::Execute( + input, result, args.size(), [](list_entry_t input) { return UnsafeNumericCast(input.length); }); if (args.AllConstant()) { result.SetVectorType(VectorType::CONSTANT_VECTOR); } @@ -117,7 +117,7 @@ static void ListLengthBinaryFunction(DataChunk &args, ExpressionState &, Vector if (dimension != 1) { throw NotImplementedException("array_length for lists with dimensions other than 1 not implemented"); } - return input.length; + return UnsafeNumericCast(input.length); }); if (args.AllConstant()) { result.SetVectorType(VectorType::CONSTANT_VECTOR); @@ -153,7 +153,7 @@ static void ArrayLengthBinaryFunction(DataChunk &args, ExpressionState &state, V throw OutOfRangeException(StringUtil::Format( "array_length dimension '%lld' out of range (min: '1', max: '%lld')", dimension, max_dimension)); } - return dimensions[dimension - 1]; + return dimensions[UnsafeNumericCast(dimension - 1)]; }); if (args.AllConstant()) { @@ -175,7 +175,7 @@ static unique_ptr ArrayOrListLengthBinaryBind(ClientContext &conte vector dimensions; while (true) { if (type.id() == LogicalTypeId::ARRAY) { - dimensions.push_back(ArrayType::GetSize(type)); + dimensions.push_back(UnsafeNumericCast(ArrayType::GetSize(type))); type = ArrayType::GetChildType(type); } else { break; diff --git a/src/function/scalar/string/substring.cpp b/src/function/scalar/string/substring.cpp index b0b2d8161e00..a6d582c4ecd8 100644 --- a/src/function/scalar/string/substring.cpp +++ b/src/function/scalar/string/substring.cpp @@ -40,7 +40,7 @@ string_t SubstringEmptyString(Vector &result) { } string_t SubstringSlice(Vector &result, 
const char *input_data, int64_t offset, int64_t length) { - auto result_string = StringVector::EmptyString(result, length); + auto result_string = StringVector::EmptyString(result, UnsafeNumericCast(length)); auto result_data = result_string.GetDataWriteable(); memcpy(result_data, input_data + offset, length); result_string.Finalize(); diff --git a/src/optimizer/rule/conjunction_simplification.cpp b/src/optimizer/rule/conjunction_simplification.cpp index 646471b9412e..8b137891791f 100644 --- a/src/optimizer/rule/conjunction_simplification.cpp +++ b/src/optimizer/rule/conjunction_simplification.cpp @@ -1,70 +1 @@ -#include "duckdb/optimizer/rule/conjunction_simplification.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_conjunction_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" - -namespace duckdb { - -ConjunctionSimplificationRule::ConjunctionSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on a ComparisonExpression that has a ConstantExpression as a check - auto op = make_uniq(); - op->matchers.push_back(make_uniq()); - op->policy = SetMatcher::Policy::SOME; - root = std::move(op); -} - -unique_ptr ConjunctionSimplificationRule::RemoveExpression(BoundConjunctionExpression &conj, - const Expression &expr) { - for (idx_t i = 0; i < conj.children.size(); i++) { - if (conj.children[i].get() == &expr) { - // erase the expression - conj.children.erase_at(i); - break; - } - } - if (conj.children.size() == 1) { - // one expression remaining: simply return that expression and erase the conjunction - return std::move(conj.children[0]); - } - return nullptr; -} - -unique_ptr ConjunctionSimplificationRule::Apply(LogicalOperator &op, - vector> &bindings, bool &changes_made, - bool is_root) { - auto &conjunction = bindings[0].get().Cast(); - auto &constant_expr = bindings[1].get(); - // the constant_expr is a scalar expression that we have to fold - // use an ExpressionExecutor to execute the expression - D_ASSERT(constant_expr.IsFoldable()); - Value constant_value; - if (!ExpressionExecutor::TryEvaluateScalar(GetContext(), constant_expr, constant_value)) { - return nullptr; - } - constant_value = constant_value.DefaultCastAs(LogicalType::BOOLEAN); - if (constant_value.IsNull()) { - // we can't simplify conjunctions with a constant NULL - return nullptr; - } - if (conjunction.type == ExpressionType::CONJUNCTION_AND) { - if (!BooleanValue::Get(constant_value)) { - // FALSE in AND, result of expression is false - return make_uniq(Value::BOOLEAN(false)); - } else { - // TRUE in AND, remove the expression from the set - return RemoveExpression(conjunction, constant_expr); - } - } else { - D_ASSERT(conjunction.type == ExpressionType::CONJUNCTION_OR); - if (!BooleanValue::Get(constant_value)) { - // FALSE in OR, remove the expression from the set - return RemoveExpression(conjunction, constant_expr); - } else { - // TRUE in OR, result of expression is true - return make_uniq(Value::BOOLEAN(true)); - } - } -} - -} // namespace duckdb diff --git a/src/optimizer/rule/constant_folding.cpp b/src/optimizer/rule/constant_folding.cpp index 6b7d20c46aa5..8b137891791f 100644 --- a/src/optimizer/rule/constant_folding.cpp +++ b/src/optimizer/rule/constant_folding.cpp @@ -1,43 +1 @@ -#include "duckdb/optimizer/rule/constant_folding.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/optimizer/expression_rewriter.hpp" -#include 
"duckdb/planner/expression/bound_constant_expression.hpp" - -namespace duckdb { - -//! The ConstantFoldingExpressionMatcher matches on any scalar expression (i.e. Expression::IsFoldable is true) -class ConstantFoldingExpressionMatcher : public FoldableConstantMatcher { -public: - bool Match(Expression &expr, vector> &bindings) override { - // we also do not match on ConstantExpressions, because we cannot fold those any further - if (expr.type == ExpressionType::VALUE_CONSTANT) { - return false; - } - return FoldableConstantMatcher::Match(expr, bindings); - } -}; - -ConstantFoldingRule::ConstantFoldingRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - auto op = make_uniq(); - root = std::move(op); -} - -unique_ptr ConstantFoldingRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &root = bindings[0].get(); - // the root is a scalar expression that we have to fold - D_ASSERT(root.IsFoldable() && root.type != ExpressionType::VALUE_CONSTANT); - - // use an ExpressionExecutor to execute the expression - Value result_value; - if (!ExpressionExecutor::TryEvaluateScalar(GetContext(), root, result_value)) { - return nullptr; - } - D_ASSERT(result_value.type().InternalType() == root.return_type.InternalType()); - // now get the value from the result vector and insert it back into the plan as a constant expression - return make_uniq(result_value); -} - -} // namespace duckdb diff --git a/src/optimizer/rule/date_part_simplification.cpp b/src/optimizer/rule/date_part_simplification.cpp index 6737e576c7dd..8b137891791f 100644 --- a/src/optimizer/rule/date_part_simplification.cpp +++ b/src/optimizer/rule/date_part_simplification.cpp @@ -1,104 +1 @@ -#include "duckdb/optimizer/rule/date_part_simplification.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/optimizer/matcher/expression_matcher.hpp" -#include "duckdb/optimizer/expression_rewriter.hpp" -#include "duckdb/common/enums/date_part_specifier.hpp" -#include "duckdb/function/function.hpp" -#include "duckdb/function/function_binder.hpp" - -namespace duckdb { - -DatePartSimplificationRule::DatePartSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - auto func = make_uniq(); - func->function = make_uniq("date_part"); - func->matchers.push_back(make_uniq()); - func->matchers.push_back(make_uniq()); - func->policy = SetMatcher::Policy::ORDERED; - root = std::move(func); -} - -unique_ptr DatePartSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &date_part = bindings[0].get().Cast(); - auto &constant_expr = bindings[1].get().Cast(); - auto &constant = constant_expr.value; - - if (constant.IsNull()) { - // NULL specifier: return constant NULL - return make_uniq(Value(date_part.return_type)); - } - // otherwise check the specifier - auto specifier = GetDatePartSpecifier(StringValue::Get(constant)); - string new_function_name; - switch (specifier) { - case DatePartSpecifier::YEAR: - new_function_name = "year"; - break; - case DatePartSpecifier::MONTH: - new_function_name = "month"; - break; - case DatePartSpecifier::DAY: - new_function_name = "day"; - break; - case DatePartSpecifier::DECADE: - new_function_name = "decade"; - break; - case DatePartSpecifier::CENTURY: - new_function_name = "century"; - break; - case DatePartSpecifier::MILLENNIUM: - new_function_name = "millennium"; - break; 
- case DatePartSpecifier::QUARTER: - new_function_name = "quarter"; - break; - case DatePartSpecifier::WEEK: - new_function_name = "week"; - break; - case DatePartSpecifier::YEARWEEK: - new_function_name = "yearweek"; - break; - case DatePartSpecifier::DOW: - new_function_name = "dayofweek"; - break; - case DatePartSpecifier::ISODOW: - new_function_name = "isodow"; - break; - case DatePartSpecifier::DOY: - new_function_name = "dayofyear"; - break; - case DatePartSpecifier::MICROSECONDS: - new_function_name = "microsecond"; - break; - case DatePartSpecifier::MILLISECONDS: - new_function_name = "millisecond"; - break; - case DatePartSpecifier::SECOND: - new_function_name = "second"; - break; - case DatePartSpecifier::MINUTE: - new_function_name = "minute"; - break; - case DatePartSpecifier::HOUR: - new_function_name = "hour"; - break; - default: - return nullptr; - } - // found a replacement function: bind it - vector> children; - children.push_back(std::move(date_part.children[1])); - - ErrorData error; - FunctionBinder binder(rewriter.context); - auto function = binder.BindScalarFunction(DEFAULT_SCHEMA, new_function_name, std::move(children), error, false); - if (!function) { - error.Throw(); - } - return function; -} - -} // namespace duckdb diff --git a/src/optimizer/rule/empty_needle_removal.cpp b/src/optimizer/rule/empty_needle_removal.cpp index 500d639a16df..8b137891791f 100644 --- a/src/optimizer/rule/empty_needle_removal.cpp +++ b/src/optimizer/rule/empty_needle_removal.cpp @@ -1,54 +1 @@ -#include "duckdb/optimizer/rule/empty_needle_removal.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_operator_expression.hpp" -#include "duckdb/planner/expression/bound_case_expression.hpp" -#include "duckdb/optimizer/expression_rewriter.hpp" - -namespace duckdb { - -EmptyNeedleRemovalRule::EmptyNeedleRemovalRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on a FunctionExpression that has a foldable ConstantExpression - auto func = make_uniq(); - func->matchers.push_back(make_uniq()); - func->matchers.push_back(make_uniq()); - func->policy = SetMatcher::Policy::SOME; - - unordered_set functions = {"prefix", "contains", "suffix"}; - func->function = make_uniq(functions); - root = std::move(func); -} - -unique_ptr EmptyNeedleRemovalRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &root = bindings[0].get().Cast(); - D_ASSERT(root.children.size() == 2); - auto &prefix_expr = bindings[2].get(); - - // the constant_expr is a scalar expression that we have to fold - if (!prefix_expr.IsFoldable()) { - return nullptr; - } - D_ASSERT(root.return_type.id() == LogicalTypeId::BOOLEAN); - - auto prefix_value = ExpressionExecutor::EvaluateScalar(GetContext(), prefix_expr); - - if (prefix_value.IsNull()) { - return make_uniq(Value(LogicalType::BOOLEAN)); - } - - D_ASSERT(prefix_value.type() == prefix_expr.return_type); - auto &needle_string = StringValue::Get(prefix_value); - - // PREFIX('xyz', '') is TRUE - // PREFIX(NULL, '') is NULL - // so rewrite PREFIX(x, '') to TRUE_OR_NULL(x) - if (needle_string.empty()) { - return ExpressionRewriter::ConstantOrNull(std::move(root.children[0]), Value::BOOLEAN(true)); - } - return nullptr; -} - -} // namespace duckdb diff --git a/src/optimizer/rule/enum_comparison.cpp b/src/optimizer/rule/enum_comparison.cpp index 
8b0525789974..8b137891791f 100644 --- a/src/optimizer/rule/enum_comparison.cpp +++ b/src/optimizer/rule/enum_comparison.cpp @@ -1,70 +1 @@ -#include "duckdb/optimizer/rule/enum_comparison.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/optimizer/matcher/type_matcher_id.hpp" -#include "duckdb/optimizer/expression_rewriter.hpp" -#include "duckdb/common/types.hpp" - -namespace duckdb { - -EnumComparisonRule::EnumComparisonRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on a ComparisonExpression that is an Equality and has a VARCHAR and ENUM as its children - auto op = make_uniq(); - // Enum requires expression to be root - op->expr_type = make_uniq(ExpressionType::COMPARE_EQUAL); - for (idx_t i = 0; i < 2; i++) { - auto child = make_uniq(); - child->type = make_uniq(LogicalTypeId::VARCHAR); - child->matcher = make_uniq(); - child->matcher->type = make_uniq(LogicalTypeId::ENUM); - op->matchers.push_back(std::move(child)); - } - root = std::move(op); -} - -bool AreMatchesPossible(LogicalType &left, LogicalType &right) { - LogicalType *small_enum, *big_enum; - if (EnumType::GetSize(left) < EnumType::GetSize(right)) { - small_enum = &left; - big_enum = &right; - } else { - small_enum = &right; - big_enum = &left; - } - auto &string_vec = EnumType::GetValuesInsertOrder(*small_enum); - auto string_vec_ptr = FlatVector::GetData(string_vec); - auto size = EnumType::GetSize(*small_enum); - for (idx_t i = 0; i < size; i++) { - auto key = string_vec_ptr[i].GetString(); - if (EnumType::GetPos(*big_enum, key) != -1) { - return true; - } - } - return false; -} -unique_ptr EnumComparisonRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - - auto &root = bindings[0].get().Cast(); - auto &left_child = bindings[1].get().Cast(); - auto &right_child = bindings[3].get().Cast(); - - if (!AreMatchesPossible(left_child.child->return_type, right_child.child->return_type)) { - vector> children; - children.push_back(std::move(root.left)); - children.push_back(std::move(root.right)); - return ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(false)); - } - - if (!is_root || op.type != LogicalOperatorType::LOGICAL_FILTER) { - return nullptr; - } - - auto cast_left_to_right = - BoundCastExpression::AddDefaultCastToType(std::move(left_child.child), right_child.child->return_type, true); - return make_uniq(root.type, std::move(cast_left_to_right), std::move(right_child.child)); -} - -} // namespace duckdb diff --git a/src/optimizer/rule/in_clause_simplification_rule.cpp b/src/optimizer/rule/in_clause_simplification_rule.cpp index e1ad4fd9e78e..8b137891791f 100644 --- a/src/optimizer/rule/in_clause_simplification_rule.cpp +++ b/src/optimizer/rule/in_clause_simplification_rule.cpp @@ -1,57 +1 @@ -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/optimizer/rule/in_clause_simplification.hpp" -#include "duckdb/planner/expression/list.hpp" -#include "duckdb/planner/expression/bound_operator_expression.hpp" -namespace duckdb { - -InClauseSimplificationRule::InClauseSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on InClauseExpression that has a ConstantExpression as a check - auto op = make_uniq(); - op->policy = SetMatcher::Policy::SOME; - root = std::move(op); -} - -unique_ptr InClauseSimplificationRule::Apply(LogicalOperator &op, vector> 
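The enum comparison rule deleted above first checks whether the two enum domains can overlap at all; if no key is shared, the equality folds to FALSE-or-NULL, and otherwise a cast is injected so both sides compare on the enum representation. A sketch of the fold case, with illustrative type and table names and a stock build assumed:

#include "duckdb.hpp"

// Illustrative only: 'mood', 'shade' and the table are placeholders. Two
// enums whose domains share no entries can never compare equal, so the
// predicate can fold to FALSE (or NULL on NULL input) without materialising
// the VARCHAR form of either side.
int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	con.Query("CREATE TYPE mood AS ENUM ('sad', 'happy')");
	con.Query("CREATE TYPE shade AS ENUM ('red', 'blue')");
	con.Query("CREATE TABLE people (m mood, s shade)");
	con.Query("EXPLAIN SELECT * FROM people WHERE m::VARCHAR = s::VARCHAR")->Print();
	return 0;
}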
&bindings, - bool &changes_made, bool is_root) { - auto &expr = bindings[0].get().Cast(); - if (expr.children[0]->expression_class != ExpressionClass::BOUND_CAST) { - return nullptr; - } - auto &cast_expression = expr.children[0]->Cast(); - if (cast_expression.child->expression_class != ExpressionClass::BOUND_COLUMN_REF) { - return nullptr; - } - //! Here we check if we can apply the expression on the constant side - auto target_type = cast_expression.source_type(); - if (!BoundCastExpression::CastIsInvertible(cast_expression.return_type, target_type)) { - return nullptr; - } - vector> cast_list; - //! First check if we can cast all children - for (size_t i = 1; i < expr.children.size(); i++) { - if (expr.children[i]->expression_class != ExpressionClass::BOUND_CONSTANT) { - return nullptr; - } - D_ASSERT(expr.children[i]->IsFoldable()); - auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), *expr.children[i]); - auto new_constant = constant_value.DefaultTryCastAs(target_type); - if (!new_constant) { - return nullptr; - } else { - auto new_constant_expr = make_uniq(constant_value); - cast_list.push_back(std::move(new_constant_expr)); - } - } - //! We can cast, so we move the new constant - for (size_t i = 1; i < expr.children.size(); i++) { - expr.children[i] = std::move(cast_list[i - 1]); - - // expr->children[i] = std::move(new_constant_expr); - } - //! We can cast the full list, so we move the column - expr.children[0] = std::move(cast_expression.child); - return nullptr; -} - -} // namespace duckdb diff --git a/src/optimizer/rule/like_optimizations.cpp b/src/optimizer/rule/like_optimizations.cpp index 96f7b1501e8a..8b137891791f 100644 --- a/src/optimizer/rule/like_optimizations.cpp +++ b/src/optimizer/rule/like_optimizations.cpp @@ -1,162 +1 @@ -#include "duckdb/optimizer/rule/like_optimizations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_operator_expression.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" - -namespace duckdb { - -LikeOptimizationRule::LikeOptimizationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - // match on a FunctionExpression that has a foldable ConstantExpression - auto func = make_uniq(); - func->matchers.push_back(make_uniq()); - func->matchers.push_back(make_uniq()); - func->policy = SetMatcher::Policy::ORDERED; - // we match on LIKE ("~~") and NOT LIKE ("!~~") - func->function = make_uniq(unordered_set {"!~~", "~~"}); - root = std::move(func); -} - -static bool PatternIsConstant(const string &pattern) { - for (idx_t i = 0; i < pattern.size(); i++) { - if (pattern[i] == '%' || pattern[i] == '_') { - return false; - } - } - return true; -} - -static bool PatternIsPrefix(const string &pattern) { - idx_t i; - for (i = pattern.size(); i > 0; i--) { - if (pattern[i - 1] != '%') { - break; - } - } - if (i == pattern.size()) { - // no trailing % - // cannot be a prefix - return false; - } - // continue to look in the string - // if there is a % or _ in the string (besides at the very end) this is not a prefix match - for (; i > 0; i--) { - if (pattern[i - 1] == '%' || pattern[i - 1] == '_') { - return false; - } - } - return true; -} - -static bool PatternIsSuffix(const string &pattern) { - idx_t i; - for (i = 0; i < pattern.size(); i++) { - if (pattern[i] != '%') { - break; - } - } - if (i == 0) { - // no leading % - // cannot 
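The IN-clause simplification deleted above strips an invertible cast from the probe column and instead casts each constant in the list to the column's type, so the cast runs once per constant rather than once per row. Sketch (stock build assumed, names illustrative):

#include "duckdb.hpp"

// Illustrative only: when every list element can be cast back to the column
// type, CAST(i AS VARCHAR) IN ('1', '2') can be evaluated as i IN (1, 2),
// avoiding a per-row cast of the probe side.
int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	con.Query("CREATE TABLE t (i INTEGER)");
	con.Query("EXPLAIN SELECT * FROM t WHERE i::VARCHAR IN ('1', '2')")->Print();
	return 0;
}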
be a suffix - return false; - } - // continue to look in the string - // if there is a % or _ in the string (besides at the beginning) this is not a suffix match - for (; i < pattern.size(); i++) { - if (pattern[i] == '%' || pattern[i] == '_') { - return false; - } - } - return true; -} - -static bool PatternIsContains(const string &pattern) { - idx_t start; - idx_t end; - for (start = 0; start < pattern.size(); start++) { - if (pattern[start] != '%') { - break; - } - } - for (end = pattern.size(); end > 0; end--) { - if (pattern[end - 1] != '%') { - break; - } - } - if (start == 0 || end == pattern.size()) { - // contains requires both a leading AND a trailing % - return false; - } - // check if there are any other special characters in the string - // if there is a % or _ in the string (besides at the beginning/end) this is not a contains match - for (idx_t i = start; i < end; i++) { - if (pattern[i] == '%' || pattern[i] == '_') { - return false; - } - } - return true; -} - -unique_ptr LikeOptimizationRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &root = bindings[0].get().Cast(); - auto &constant_expr = bindings[2].get().Cast(); - D_ASSERT(root.children.size() == 2); - - if (constant_expr.value.IsNull()) { - return make_uniq(Value(root.return_type)); - } - - // the constant_expr is a scalar expression that we have to fold - if (!constant_expr.IsFoldable()) { - return nullptr; - } - - auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), constant_expr); - D_ASSERT(constant_value.type() == constant_expr.return_type); - auto &patt_str = StringValue::Get(constant_value); - - bool is_not_like = root.function.name == "!~~"; - if (PatternIsConstant(patt_str)) { - // Pattern is constant - return make_uniq(is_not_like ? 
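As a reading aid, the prefix case of the classification above can be restated as a single standalone function. This is a simplification written for this note, not code from the patch; unlike the real helper it does not classify a bare "%" as a prefix pattern:

#include <iostream>
#include <string>

// Simplified restatement of PatternIsPrefix: a literal body with no '%' or
// '_' wildcards, followed by one or more trailing '%' characters.
static bool IsPrefixPattern(const std::string &p) {
	auto first_special = p.find_first_of("%_");
	return first_special != std::string::npos && p[first_special] == '%' && first_special > 0 &&
	       p.find_first_not_of('%', first_special) == std::string::npos;
}

int main() {
	for (const char *s : {"abc%", "%abc", "%abc%", "a_c%", "abc"}) {
		std::cout << s << " -> " << (IsPrefixPattern(s) ? "prefix" : "other") << "\n";
	}
	return 0;
}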
ExpressionType::COMPARE_NOTEQUAL - : ExpressionType::COMPARE_EQUAL, - std::move(root.children[0]), std::move(root.children[1])); - } else if (PatternIsPrefix(patt_str)) { - // Prefix LIKE pattern : [^%_]*[%]+, ignoring underscore - return ApplyRule(root, PrefixFun::GetFunction(), patt_str, is_not_like); - } else if (PatternIsSuffix(patt_str)) { - // Suffix LIKE pattern: [%]+[^%_]*, ignoring underscore - return ApplyRule(root, SuffixFun::GetFunction(), patt_str, is_not_like); - } else if (PatternIsContains(patt_str)) { - // Contains LIKE pattern: [%]+[^%_]*[%]+, ignoring underscore - return ApplyRule(root, ContainsFun::GetFunction(), patt_str, is_not_like); - } - return nullptr; -} - -unique_ptr LikeOptimizationRule::ApplyRule(BoundFunctionExpression &expr, ScalarFunction function, - string pattern, bool is_not_like) { - // replace LIKE by an optimized function - unique_ptr result; - auto new_function = - make_uniq(expr.return_type, std::move(function), std::move(expr.children), nullptr); - - // removing "%" from the pattern - pattern.erase(std::remove(pattern.begin(), pattern.end(), '%'), pattern.end()); - - new_function->children[1] = make_uniq(Value(std::move(pattern))); - - result = std::move(new_function); - if (is_not_like) { - auto negation = make_uniq(ExpressionType::OPERATOR_NOT, LogicalType::BOOLEAN); - negation->children.push_back(std::move(result)); - result = std::move(negation); - } - - return result; -} - -} // namespace duckdb diff --git a/src/optimizer/rule/move_constants.cpp b/src/optimizer/rule/move_constants.cpp index 636265ff9131..8b137891791f 100644 --- a/src/optimizer/rule/move_constants.cpp +++ b/src/optimizer/rule/move_constants.cpp @@ -1,165 +1 @@ -#include "duckdb/optimizer/rule/move_constants.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/value_operations/value_operations.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/optimizer/expression_rewriter.hpp" - -namespace duckdb { - -MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) { - auto op = make_uniq(); - op->matchers.push_back(make_uniq()); - op->policy = SetMatcher::Policy::UNORDERED; - - auto arithmetic = make_uniq(); - // we handle multiplication, addition and subtraction because those are "easy" - // integer division makes the division case difficult - // e.g. 
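Taken together, the rule deleted above turns constant LIKE patterns into cheaper dedicated calls: a pattern with no wildcards becomes an equality, 'abc%' becomes prefix(), '%abc' becomes suffix(), '%abc%' becomes contains(), and NOT LIKE wraps the result in a NOT. A sketch of the observable rewrites (stock build assumed, queries illustrative):

#include "duckdb.hpp"

// Illustrative only: EXPLAIN should show the rewritten prefix()/contains()
// calls rather than the generic LIKE operator.
int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	con.Query("CREATE TABLE t (s VARCHAR)");
	con.Query("EXPLAIN SELECT * FROM t WHERE s LIKE 'abc%'")->Print();
	con.Query("EXPLAIN SELECT * FROM t WHERE s NOT LIKE '%abc%'")->Print();
	return 0;
}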
[x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules - arithmetic->function = make_uniq(unordered_set {"+", "-", "*"}); - // we match only on integral numeric types - arithmetic->type = make_uniq(); - auto child_constant_matcher = make_uniq(); - auto child_expression_matcher = make_uniq(); - child_constant_matcher->type = make_uniq(); - child_expression_matcher->type = make_uniq(); - arithmetic->matchers.push_back(std::move(child_constant_matcher)); - arithmetic->matchers.push_back(std::move(child_expression_matcher)); - arithmetic->policy = SetMatcher::Policy::SOME; - op->matchers.push_back(std::move(arithmetic)); - root = std::move(op); -} - -unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &comparison = bindings[0].get().Cast(); - auto &outer_constant = bindings[1].get().Cast(); - auto &arithmetic = bindings[2].get().Cast(); - auto &inner_constant = bindings[3].get().Cast(); - D_ASSERT(arithmetic.return_type.IsIntegral()); - D_ASSERT(arithmetic.children[0]->return_type.IsIntegral()); - if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) { - return make_uniq(Value(comparison.return_type)); - } - auto &constant_type = outer_constant.return_type; - hugeint_t outer_value = IntegralValue::Get(outer_constant.value); - hugeint_t inner_value = IntegralValue::Get(inner_constant.value); - - idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 1 : 0; - auto &op_type = arithmetic.function.name; - if (op_type == "+") { - // [x + 1 COMP 10] OR [1 + x COMP 10] - // order does not matter in addition: - // simply change right side to 10-1 (outer_constant - inner_constant) - if (!Hugeint::TrySubtractInPlace(outer_value, inner_value)) { - return nullptr; - } - auto result_value = Value::HUGEINT(outer_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - if (comparison.type != ExpressionType::COMPARE_EQUAL) { - return nullptr; - } - // if the cast is not possible then the comparison is not possible - // for example, if we have x + 5 = 3, where x is an unsigned number, we will get x = -2 - // since this is not possible we can remove the entire branch here - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - } else if (op_type == "-") { - // [x - 1 COMP 10] O R [1 - x COMP 10] - // order matters in subtraction: - if (arithmetic_child_index == 0) { - // [x - 1 COMP 10] - // change right side to 10+1 (outer_constant + inner_constant) - if (!Hugeint::TryAddInPlace(outer_value, inner_value)) { - return nullptr; - } - auto result_value = Value::HUGEINT(outer_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - // if the cast is not possible then an equality comparison is not possible - if (comparison.type != ExpressionType::COMPARE_EQUAL) { - return nullptr; - } - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - } else { - // [1 - x COMP 10] - // change right side to 1-10=-9 - if (!Hugeint::TrySubtractInPlace(inner_value, outer_value)) { - return nullptr; - } - auto result_value = Value::HUGEINT(inner_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - // if the cast is not possible then an equality comparison is not possible - if (comparison.type != 
ExpressionType::COMPARE_EQUAL) { - return nullptr; - } - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - // in this case, we should also flip the comparison - // e.g. if we have [4 - x < 2] then we should have [x > 2] - comparison.type = FlipComparisonExpression(comparison.type); - } - } else { - D_ASSERT(op_type == "*"); - // [x * 2 COMP 10] OR [2 * x COMP 10] - // order does not matter in multiplication: - // change right side to 10/2 (outer_constant / inner_constant) - // but ONLY if outer_constant is cleanly divisible by the inner_constant - if (inner_value == 0) { - // x * 0, the result is either 0 or NULL - // we let the arithmetic_simplification rule take care of simplifying this first - return nullptr; - } - // check out of range for HUGEINT or not cleanly divisible - // HUGEINT is not cleanly divisible when outer_value == minimum and inner value == -1. (modulo overflow) - if ((outer_value == NumericLimits::Minimum() && inner_value == -1) || - outer_value % inner_value != 0) { - bool is_equality = comparison.type == ExpressionType::COMPARE_EQUAL; - bool is_inequality = comparison.type == ExpressionType::COMPARE_NOTEQUAL; - if (is_equality || is_inequality) { - // we know the values are not equal - // the result will be either FALSE or NULL (if COMPARE_EQUAL) - // or TRUE or NULL (if COMPARE_NOTEQUAL) - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(is_inequality)); - } else { - // not cleanly divisible and we are doing > >= < <=, skip the simplification for now - return nullptr; - } - } - if (inner_value < 0) { - // multiply by negative value, need to flip expression - comparison.type = FlipComparisonExpression(comparison.type); - } - // else divide the RHS by the LHS - // we need to do a range check on the cast even though we do a division - // because e.g. 
-128 / -1 = 128, which is out of range - auto result_value = Value::HUGEINT(outer_value / inner_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - } - // replace left side with x - // first extract x from the arithmetic expression - auto arithmetic_child = std::move(arithmetic.children[arithmetic_child_index]); - // then place in the comparison - if (comparison.left.get() == &outer_constant) { - comparison.right = std::move(arithmetic_child); - } else { - comparison.left = std::move(arithmetic_child); - } - changes_made = true; - return nullptr; -} - -} // namespace duckdb diff --git a/src/optimizer/rule/ordered_aggregate_optimizer.cpp b/src/optimizer/rule/ordered_aggregate_optimizer.cpp index 553c0e30f450..8b137891791f 100644 --- a/src/optimizer/rule/ordered_aggregate_optimizer.cpp +++ b/src/optimizer/rule/ordered_aggregate_optimizer.cpp @@ -1,102 +1 @@ -#include "duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp" -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/function/function_binder.hpp" -#include "duckdb/optimizer/matcher/expression_matcher.hpp" -#include "duckdb/optimizer/expression_rewriter.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/planner/operator/logical_aggregate.hpp" - -namespace duckdb { - -OrderedAggregateOptimizer::OrderedAggregateOptimizer(ExpressionRewriter &rewriter) : Rule(rewriter) { - // we match on an OR expression within a LogicalFilter node - root = make_uniq(); - root->expr_class = ExpressionClass::BOUND_AGGREGATE; -} - -unique_ptr OrderedAggregateOptimizer::Apply(ClientContext &context, BoundAggregateExpression &aggr, - vector> &groups, bool &changes_made) { - if (!aggr.order_bys) { - // no ORDER BYs defined - return nullptr; - } - if (aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT) { - // not an order dependent aggregate but we have an ORDER BY clause - remove it - aggr.order_bys.reset(); - changes_made = true; - return nullptr; - } - - // Remove unnecessary ORDER BY clauses and return if nothing remains - if (aggr.order_bys->Simplify(groups)) { - aggr.order_bys.reset(); - changes_made = true; - return nullptr; - } - - // Rewrite first/last/arbitrary/any_value to use arg_xxx[_null] and create_sort_key - const auto &aggr_name = aggr.function.name; - string arg_xxx_name; - if (aggr_name == "last") { - arg_xxx_name = "arg_max_null"; - } else if (aggr_name == "first" || aggr_name == "arbitrary") { - arg_xxx_name = "arg_min_null"; - } else if (aggr_name == "any_value") { - arg_xxx_name = "arg_min"; - } else { - return nullptr; - } - - FunctionBinder binder(context); - vector> sort_children; - for (auto &order : aggr.order_bys->orders) { - sort_children.emplace_back(std::move(order.expression)); - - string modifier; - modifier += (order.type == OrderType::ASCENDING) ? "ASC" : "DESC"; - modifier += " NULLS"; - modifier += (order.null_order == OrderByNullType::NULLS_FIRST) ? 
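The move-constants rule deleted above normalises comparisons so the column stands alone on one side, flipping the comparison where subtraction or a negative multiplier requires it, and folding away predicates no integer can satisfy. A sketch (stock build assumed, table and values illustrative):

#include "duckdb.hpp"

// Illustrative only: x + 1 = 10 should become x = 9, the subtraction case
// 4 - x < 2 should flip to x > 2, and x * 2 = 9 (9 is not divisible by 2)
// should fold to FALSE-or-NULL on the column.
int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	con.Query("CREATE TABLE t (x INTEGER)");
	con.Query("EXPLAIN SELECT * FROM t WHERE x + 1 = 10")->Print();
	con.Query("EXPLAIN SELECT * FROM t WHERE 4 - x < 2")->Print();
	con.Query("EXPLAIN SELECT * FROM t WHERE x * 2 = 9")->Print();
	return 0;
}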
" FIRST" : " LAST"; - sort_children.emplace_back(make_uniq(Value(modifier))); - } - aggr.order_bys.reset(); - - ErrorData error; - auto sort_key = binder.BindScalarFunction(DEFAULT_SCHEMA, "create_sort_key", std::move(sort_children), error); - if (!sort_key) { - error.Throw(); - } - - auto &children = aggr.children; - children.emplace_back(std::move(sort_key)); - - // Look up the arg_xxx_name function in the catalog - QueryErrorContext error_context; - auto &func = Catalog::GetEntry(context, SYSTEM_CATALOG, DEFAULT_SCHEMA, arg_xxx_name, - error_context); - D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); - - // bind the aggregate - vector types; - for (const auto &child : children) { - types.emplace_back(child->return_type); - } - auto best_function = binder.BindFunction(func.name, func.functions, types, error); - if (best_function == DConstants::INVALID_INDEX) { - error.Throw(); - } - // found a matching function! - auto bound_function = func.functions.GetFunctionByOffset(best_function); - return binder.BindAggregateFunction(bound_function, std::move(children), std::move(aggr.filter), - aggr.IsDistinct() ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT); -} - -unique_ptr OrderedAggregateOptimizer::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &aggr = bindings[0].get().Cast(); - return Apply(rewriter.context, aggr, op.Cast().groups, changes_made); -} - -} // namespace duckdb diff --git a/src/planner/CMakeLists.txt b/src/planner/CMakeLists.txt index 19f4c28a0758..8b137891791f 100644 --- a/src/planner/CMakeLists.txt +++ b/src/planner/CMakeLists.txt @@ -1,27 +1 @@ -add_subdirectory(expression) -add_subdirectory(binder) -add_subdirectory(expression_binder) -add_subdirectory(filter) -add_subdirectory(operator) -add_subdirectory(subquery) -add_library_unity( - duckdb_planner - OBJECT - bound_result_modifier.cpp - bound_parameter_map.cpp - expression_iterator.cpp - expression.cpp - table_binding.cpp - expression_binder.cpp - joinside.cpp - logical_operator.cpp - binder.cpp - bind_context.cpp - planner.cpp - pragma_handler.cpp - logical_operator_visitor.cpp - table_filter.cpp) -set(ALL_OBJECT_FILES - ${ALL_OBJECT_FILES} $ - PARENT_SCOPE) From b0966a06757f6935c87661869a4186b4268cb520 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Wed, 3 Apr 2024 15:11:03 +0200 Subject: [PATCH 041/201] moar --- src/function/scalar/string/contains.cpp | 4 +- .../string/regexp/regexp_extract_all.cpp | 4 +- src/function/scalar/string/strip_accents.cpp | 3 +- src/function/scalar/string/substring.cpp | 18 ++- src/function/scalar/struct/struct_extract.cpp | 4 +- src/function/table/arrow.cpp | 7 +- src/function/table/arrow_conversion.cpp | 149 ++++++++++-------- 7 files changed, 109 insertions(+), 80 deletions(-) diff --git a/src/function/scalar/string/contains.cpp b/src/function/scalar/string/contains.cpp index fb68e00e8dd8..3e24ed6a71b4 100644 --- a/src/function/scalar/string/contains.cpp +++ b/src/function/scalar/string/contains.cpp @@ -22,7 +22,7 @@ static idx_t ContainsUnaligned(const unsigned char *haystack, idx_t haystack_siz UNSIGNED haystack_entry = 0; const UNSIGNED start = (sizeof(UNSIGNED) * 8) - 8; const UNSIGNED shift = (sizeof(UNSIGNED) - NEEDLE_SIZE) * 8; - for (int i = 0; i < NEEDLE_SIZE; i++) { + for (idx_t i = 0; i < NEEDLE_SIZE; i++) { needle_entry |= UNSIGNED(needle[i]) << UNSIGNED(start - i * 8); haystack_entry |= UNSIGNED(haystack[i]) << UNSIGNED(start - i * 8); } @@ -106,7 +106,7 @@ idx_t 
ContainsFun::Find(const unsigned char *haystack, idx_t haystack_size, cons if (location == nullptr) { return DConstants::INVALID_INDEX; } - idx_t base_offset = const_uchar_ptr_cast(location) - haystack; + idx_t base_offset = UnsafeNumericCast(const_uchar_ptr_cast(location) - haystack); haystack_size -= base_offset; haystack = const_uchar_ptr_cast(location); // switch algorithm depending on needle size diff --git a/src/function/scalar/string/regexp/regexp_extract_all.cpp b/src/function/scalar/string/regexp/regexp_extract_all.cpp index f8a5736b4374..0e6cfa7dd5c5 100644 --- a/src/function/scalar/string/regexp/regexp_extract_all.cpp +++ b/src/function/scalar/string/regexp/regexp_extract_all.cpp @@ -96,7 +96,7 @@ void ExtractSingleTuple(const string_t &string, duckdb_re2::RE2 &pattern, int32_ // Every group is a substring of the original, we can find out the offset using the pointer // the 'match_group' address is guaranteed to be bigger than that of the source D_ASSERT(const_char_ptr_cast(match_group.begin()) >= string.GetData()); - idx_t offset = match_group.begin() - string.GetData(); + auto offset = UnsafeNumericCast(match_group.begin() - string.GetData()); list_content[child_idx] = string_t(string.GetData() + offset, UnsafeNumericCast(match_group.size())); } @@ -199,7 +199,7 @@ void RegexpExtractAll::Execute(DataChunk &args, ExpressionState &state, Vector & if (group_count_p == -1) { throw InvalidInputException("Pattern failed to parse, error: '%s'", stored_re->error()); } - non_const_args->SetSize(group_count_p); + non_const_args->SetSize(UnsafeNumericCast(group_count_p)); } } diff --git a/src/function/scalar/string/strip_accents.cpp b/src/function/scalar/string/strip_accents.cpp index 758c72646ec7..1883c60f0f73 100644 --- a/src/function/scalar/string/strip_accents.cpp +++ b/src/function/scalar/string/strip_accents.cpp @@ -22,7 +22,8 @@ struct StripAccentsOperator { } // non-ascii, perform collation - auto stripped = utf8proc_remove_accents((const utf8proc_uint8_t *)input.GetData(), input.GetSize()); + auto stripped = utf8proc_remove_accents((const utf8proc_uint8_t *)input.GetData(), + UnsafeNumericCast(input.GetSize())); auto result_str = StringVector::AddString(result, const_char_ptr_cast(stripped)); free(stripped); return result_str; diff --git a/src/function/scalar/string/substring.cpp b/src/function/scalar/string/substring.cpp index a6d582c4ecd8..76f4859f2734 100644 --- a/src/function/scalar/string/substring.cpp +++ b/src/function/scalar/string/substring.cpp @@ -42,7 +42,7 @@ string_t SubstringEmptyString(Vector &result) { string_t SubstringSlice(Vector &result, const char *input_data, int64_t offset, int64_t length) { auto result_string = StringVector::EmptyString(result, UnsafeNumericCast(length)); auto result_data = result_string.GetDataWriteable(); - memcpy(result_data, input_data + offset, length); + memcpy(result_data, input_data + offset, UnsafeNumericCast(length)); result_string.Finalize(); return result_string; } @@ -88,10 +88,10 @@ string_t SubstringASCII(Vector &result, string_t input, int64_t offset, int64_t AssertInSupportedRange(input_size, offset, length); int64_t start, end; - if (!SubstringStartEnd(input_size, offset, length, start, end)) { + if (!SubstringStartEnd(UnsafeNumericCast(input_size), offset, length, start, end)) { return SubstringEmptyString(result); } - return SubstringSlice(result, input_data, start, end - start); + return SubstringSlice(result, input_data, start, UnsafeNumericCast(end - start)); } string_t SubstringFun::SubstringUnicode(Vector &result, 
string_t input, int64_t offset, int64_t length) { @@ -186,7 +186,8 @@ string_t SubstringFun::SubstringUnicode(Vector &result, string_t input, int64_t } D_ASSERT(end_pos >= start_pos); // after we have found these, we can slice the substring - return SubstringSlice(result, input_data, start_pos, end_pos - start_pos); + return SubstringSlice(result, input_data, UnsafeNumericCast(start_pos), + UnsafeNumericCast(end_pos - start_pos)); } string_t SubstringFun::SubstringGrapheme(Vector &result, string_t input, int64_t offset, int64_t length) { @@ -198,14 +199,14 @@ string_t SubstringFun::SubstringGrapheme(Vector &result, string_t input, int64_t // we don't know yet if the substring is ascii, but we assume it is (for now) // first get the start and end as if this was an ascii string int64_t start, end; - if (!SubstringStartEnd(input_size, offset, length, start, end)) { + if (!SubstringStartEnd(UnsafeNumericCast(input_size), offset, length, start, end)) { return SubstringEmptyString(result); } // now check if all the characters between 0 and end are ascii characters // note that we scan one further to check for a potential combining diacritics (e.g. i + diacritic is ï) bool is_ascii = true; - idx_t ascii_end = MinValue(end + 1, input_size); + idx_t ascii_end = MinValue(UnsafeNumericCast(end + 1), input_size); for (idx_t i = 0; i < ascii_end; i++) { if (input_data[i] & 0x80) { // found a non-ascii character: eek @@ -229,7 +230,7 @@ string_t SubstringFun::SubstringGrapheme(Vector &result, string_t input, int64_t return true; }); // now call substring start and end again, but with the number of unicode characters this time - SubstringStartEnd(num_characters, offset, length, start, end); + SubstringStartEnd(UnsafeNumericCast(num_characters), offset, length, start, end); } // now scan the graphemes of the string to find the positions of the start and end characters @@ -249,7 +250,8 @@ string_t SubstringFun::SubstringGrapheme(Vector &result, string_t input, int64_t return SubstringEmptyString(result); } // after we have found these, we can slice the substring - return SubstringSlice(result, input_data, start_pos, end_pos - start_pos); + return SubstringSlice(result, input_data, UnsafeNumericCast(start_pos), + UnsafeNumericCast(end_pos - start_pos)); } struct SubstringUnicodeOp { diff --git a/src/function/scalar/struct/struct_extract.cpp b/src/function/scalar/struct/struct_extract.cpp index 50e033b77079..2572cb51bc30 100644 --- a/src/function/scalar/struct/struct_extract.cpp +++ b/src/function/scalar/struct/struct_extract.cpp @@ -133,8 +133,8 @@ static unique_ptr StructExtractBindIndex(ClientContext &context, S throw BinderException("Key index %lld for struct_extract out of range - expected an index between 1 and %llu", index, struct_children.size()); } - bound_function.return_type = struct_children[index - 1].second; - return make_uniq(index - 1); + bound_function.return_type = struct_children[NumericCast(index - 1)].second; + return make_uniq(NumericCast(index - 1)); } static unique_ptr PropagateStructExtractStats(ClientContext &context, FunctionStatisticsInput &input) { diff --git a/src/function/table/arrow.cpp b/src/function/table/arrow.cpp index a65257a8363b..a38576170b31 100644 --- a/src/function/table/arrow.cpp +++ b/src/function/table/arrow.cpp @@ -117,7 +117,7 @@ static unique_ptr GetArrowLogicalTypeNoDictionary(ArrowSchema &schema return list_type; } else if (format[0] == '+' && format[1] == 'w') { std::string parameters = format.substr(format.find(':') + 1); - idx_t fixed_size = 
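This patch and the next mostly thread NumericCast/UnsafeNumericCast through signed/unsigned boundaries. As a reading aid, here is a simplified stand-in for what such a checked cast does; CheckedNumericCast is a name invented for this note, and the real templates in DuckDB's numeric_utils.hpp differ in detail (UnsafeNumericCast in particular only asserts in debug builds and compiles to a plain cast in release builds):

#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical stand-in: reject values the target type cannot represent
// instead of silently wrapping on conversion.
template <class TO, class FROM>
TO CheckedNumericCast(FROM value) {
	if (value < 0 || static_cast<uint64_t>(value) > std::numeric_limits<TO>::max()) {
		throw std::runtime_error("value out of range for target type");
	}
	return static_cast<TO>(value);
}

int main() {
	int64_t signed_offset = 42; // e.g. an Arrow array offset, signed per spec
	auto idx = CheckedNumericCast<uint64_t>(signed_offset); // fine
	(void)idx;
	// CheckedNumericCast<uint64_t>(int64_t(-1)); // would throw
	return 0;
}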
std::stoi(parameters); + auto fixed_size = NumericCast(std::stoi(parameters)); auto child_type = ArrowTableFunction::GetArrowLogicalType(*schema.children[0]); auto list_type = make_uniq(LogicalType::ARRAY(child_type->GetDuckType(), fixed_size), fixed_size); list_type->AddChild(std::move(child_type)); @@ -197,7 +197,7 @@ static unique_ptr GetArrowLogicalTypeNoDictionary(ArrowSchema &schema return make_uniq(LogicalType::BLOB, ArrowVariableSizeType::SUPER_SIZE); } else if (format[0] == 'w') { std::string parameters = format.substr(format.find(':') + 1); - idx_t fixed_size = std::stoi(parameters); + auto fixed_size = NumericCast(std::stoi(parameters)); return make_uniq(LogicalType::BLOB, fixed_size); } else if (format[0] == 't' && format[1] == 's') { // Timestamp with Timezone @@ -366,7 +366,8 @@ void ArrowTableFunction::ArrowScanFunction(ClientContext &context, TableFunction return; } } - int64_t output_size = MinValue(STANDARD_VECTOR_SIZE, state.chunk->arrow_array.length - state.chunk_offset); + auto output_size = + MinValue(STANDARD_VECTOR_SIZE, NumericCast(state.chunk->arrow_array.length) - state.chunk_offset); data.lines_read += output_size; if (global_state.CanRemoveFilterColumns()) { state.all_columns.Reset(); diff --git a/src/function/table/arrow_conversion.cpp b/src/function/table/arrow_conversion.cpp index 55058683b24d..c82b6a64ad4d 100644 --- a/src/function/table/arrow_conversion.cpp +++ b/src/function/table/arrow_conversion.cpp @@ -40,12 +40,12 @@ idx_t GetEffectiveOffset(ArrowArray &array, int64_t parent_offset, const ArrowSc if (nested_offset != -1) { // The parent of this array is a list // We just ignore the parent offset, it's already applied to the list - return array.offset + nested_offset; + return UnsafeNumericCast(array.offset + nested_offset); } // Parent offset is set in the case of a struct, it applies to all child arrays // 'chunk_offset' is how much of the chunk we've already scanned, in case the chunk size exceeds // STANDARD_VECTOR_SIZE - return array.offset + parent_offset + state.chunk_offset; + return UnsafeNumericCast(array.offset + parent_offset) + state.chunk_offset; } template @@ -158,7 +158,8 @@ static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowArrayScanS ListVector::Reserve(vector, list_size); ListVector::SetListSize(vector, list_size); auto &child_vector = ListVector::GetEntry(vector); - SetValidityMask(child_vector, *array.children[0], scan_state, list_size, array.offset, start_offset); + SetValidityMask(child_vector, *array.children[0], scan_state, list_size, array.offset, + NumericCast(start_offset)); auto &list_mask = FlatVector::Validity(vector); if (parent_mask) { //! 
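GetEffectiveOffset above is the linchpin of this hunk: it decides which physical offset applies when arrays nest. Restated without the surrounding types (a reading aid, not patch code):

#include <cstdint>

// A nested (list child) offset already accounts for the parent, so it
// replaces the parent offset; otherwise the array's own offset, the parent
// offset and the scan progress within the current chunk are combined.
static uint64_t EffectiveOffset(int64_t array_offset, int64_t parent_offset, uint64_t chunk_offset,
                                int64_t nested_offset) {
	if (nested_offset != -1) {
		return static_cast<uint64_t>(array_offset + nested_offset);
	}
	return static_cast<uint64_t>(array_offset + parent_offset) + chunk_offset;
}

int main() {
	// e.g. array.offset = 8, parent struct offset = 4, 2048 rows already read
	return EffectiveOffset(8, 4, 2048, -1) == 2060 ? 0 : 1;
}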
Since this List is owned by a struct we must guarantee their validity map matches on Null @@ -183,13 +184,16 @@ static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowArrayScanS switch (array_physical_type) { case ArrowArrayPhysicalType::DICTIONARY_ENCODED: // TODO: add support for offsets - ColumnArrowToDuckDBDictionary(child_vector, child_array, child_state, list_size, child_type, start_offset); + ColumnArrowToDuckDBDictionary(child_vector, child_array, child_state, list_size, child_type, + NumericCast(start_offset)); break; case ArrowArrayPhysicalType::RUN_END_ENCODED: - ColumnArrowToDuckDBRunEndEncoded(child_vector, child_array, child_state, list_size, child_type, start_offset); + ColumnArrowToDuckDBRunEndEncoded(child_vector, child_array, child_state, list_size, child_type, + NumericCast(start_offset)); break; case ArrowArrayPhysicalType::DEFAULT: - ColumnArrowToDuckDB(child_vector, child_array, child_state, list_size, child_type, start_offset); + ColumnArrowToDuckDB(child_vector, child_array, child_state, list_size, child_type, + NumericCast(start_offset)); break; default: throw NotImplementedException("ArrowArrayPhysicalType not recognized"); @@ -209,7 +213,8 @@ static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, ArrowArrayScan SetValidityMask(vector, array, scan_state, size, parent_offset, nested_offset); auto &child_vector = ArrayVector::GetEntry(vector); - SetValidityMask(child_vector, *array.children[0], scan_state, child_count, array.offset, child_offset); + SetValidityMask(child_vector, *array.children[0], scan_state, child_count, array.offset, + NumericCast(child_offset)); auto &array_mask = FlatVector::Validity(vector); if (parent_mask) { @@ -244,9 +249,10 @@ static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, ArrowArrayScan } else { if (child_array.dictionary) { ColumnArrowToDuckDBDictionary(child_vector, child_array, child_state, child_count, child_type, - child_offset); + NumericCast(child_offset)); } else { - ColumnArrowToDuckDB(child_vector, child_array, child_state, child_count, child_type, child_offset); + ColumnArrowToDuckDB(child_vector, child_array, child_state, child_count, child_type, + NumericCast(child_offset)); } } } @@ -339,8 +345,9 @@ static void SetVectorString(Vector &vector, idx_t size, char *cdata, T *offsets) static void DirectConversion(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, int64_t nested_offset, uint64_t parent_offset) { auto internal_type = GetTypeIdSize(vector.GetType().InternalType()); - auto data_ptr = ArrowBufferData(array, 1) + - internal_type * GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + auto data_ptr = + ArrowBufferData(array, 1) + + internal_type * GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); FlatVector::SetData(vector, data_ptr); } @@ -595,7 +602,7 @@ static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, ArrowArray &array, auto &scan_state = array_state.state; D_ASSERT(run_ends_array.length == values_array.length); - auto compressed_size = run_ends_array.length; + auto compressed_size = NumericCast(run_ends_array.length); // Create a vector for the run ends and the values auto &run_end_encoding = array_state.RunEndEncoding(); if (!run_end_encoding.run_ends) { @@ -606,11 +613,12 @@ static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, ArrowArray &array, ColumnArrowToDuckDB(*run_end_encoding.run_ends, run_ends_array, array_state, compressed_size, run_ends_type); auto &values = 
*run_end_encoding.values; - SetValidityMask(values, values_array, scan_state, compressed_size, parent_offset, nested_offset); + SetValidityMask(values, values_array, scan_state, compressed_size, NumericCast(parent_offset), + nested_offset); ColumnArrowToDuckDB(values, values_array, array_state, compressed_size, values_type); } - idx_t scan_offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + idx_t scan_offset = GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); auto physical_type = run_ends_type.GetDuckType().InternalType(); switch (physical_type) { case PhysicalType::INT16: @@ -641,12 +649,12 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca //! Arrow bit-packs boolean values //! Lets first figure out where we are in the source array auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, parent_offset, scan_state, nested_offset) / 8; + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset) / 8; auto tgt_ptr = (uint8_t *)FlatVector::GetData(vector); int src_pos = 0; idx_t cur_bit = scan_state.chunk_offset % 8; if (nested_offset != -1) { - cur_bit = nested_offset % 8; + cur_bit = NumericCast(nested_offset % 8); } for (idx_t row = 0; row < size; row++) { if ((src_ptr[src_pos] & (1 << cur_bit)) == 0) { @@ -686,11 +694,11 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto cdata = ArrowBufferData(array, 2); if (size_type == ArrowVariableSizeType::SUPER_SIZE) { auto offsets = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); SetVectorString(vector, size, cdata, offsets); } else { auto offsets = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); SetVectorString(vector, size, cdata, offsets); } break; @@ -706,7 +714,7 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca case ArrowDateTimeType::MILLISECONDS: { //! 
convert date from nanoseconds to days auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); auto tgt_ptr = FlatVector::GetData(vector); for (idx_t row = 0; row < size; row++) { tgt_ptr[row] = date_t( @@ -723,21 +731,24 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto precision = arrow_type.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::SECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, parent_offset, size, 1000000); + TimeConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + 1000000); break; } case ArrowDateTimeType::MILLISECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, parent_offset, size, 1000); + TimeConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + 1000); break; } case ArrowDateTimeType::MICROSECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, parent_offset, size, 1); + TimeConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + 1); break; } case ArrowDateTimeType::NANOSECONDS: { auto tgt_ptr = FlatVector::GetData(vector); auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); for (idx_t row = 0; row < size; row++) { tgt_ptr[row].micros = src_ptr[row] / 1000; } @@ -752,11 +763,13 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto precision = arrow_type.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::SECONDS: { - TimestampTZConversion(vector, array, scan_state, nested_offset, parent_offset, size, 1000000); + TimestampTZConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + 1000000); break; } case ArrowDateTimeType::MILLISECONDS: { - TimestampTZConversion(vector, array, scan_state, nested_offset, parent_offset, size, 1000); + TimestampTZConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + 1000); break; } case ArrowDateTimeType::MICROSECONDS: { @@ -766,7 +779,7 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca case ArrowDateTimeType::NANOSECONDS: { auto tgt_ptr = FlatVector::GetData(vector); auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); for (idx_t row = 0; row < size; row++) { tgt_ptr[row].value = src_ptr[row] / 1000; } @@ -781,22 +794,25 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto precision = arrow_type.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::SECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, parent_offset, size, 1000000); + IntervalConversionUs(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + 1000000); break; } case ArrowDateTimeType::DAYS: case ArrowDateTimeType::MILLISECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, parent_offset, size, 1000); + IntervalConversionUs(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + 1000); break; } case ArrowDateTimeType::MICROSECONDS: { - IntervalConversionUs(vector, array, 
scan_state, nested_offset, parent_offset, size, 1); + IntervalConversionUs(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + 1); break; } case ArrowDateTimeType::NANOSECONDS: { auto tgt_ptr = FlatVector::GetData(vector); auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); for (idx_t row = 0; row < size; row++) { tgt_ptr[row].micros = src_ptr[row] / 1000; tgt_ptr[row].days = 0; @@ -805,11 +821,13 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca break; } case ArrowDateTimeType::MONTHS: { - IntervalConversionMonths(vector, array, scan_state, nested_offset, parent_offset, size); + IntervalConversionMonths(vector, array, scan_state, nested_offset, NumericCast(parent_offset), + size); break; } case ArrowDateTimeType::MONTH_DAY_NANO: { - IntervalConversionMonthDayNanos(vector, array, scan_state, nested_offset, parent_offset, size); + IntervalConversionMonthDayNanos(vector, array, scan_state, nested_offset, + NumericCast(parent_offset), size); break; } default: @@ -820,8 +838,8 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca case LogicalTypeId::DECIMAL: { auto val_mask = FlatVector::Validity(vector); //! We have to convert from INT128 - auto src_ptr = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + auto src_ptr = ArrowBufferData(array, 1) + + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); switch (vector.GetType().InternalType()) { case PhysicalType::INT16: { auto tgt_ptr = FlatVector::GetData(vector); @@ -859,7 +877,8 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca case PhysicalType::INT128: { FlatVector::SetData(vector, ArrowBufferData(array, 1) + GetTypeIdSize(vector.GetType().InternalType()) * - GetEffectiveOffset(array, parent_offset, scan_state, nested_offset)); + GetEffectiveOffset(array, NumericCast(parent_offset), + scan_state, nested_offset)); break; } default: @@ -869,19 +888,23 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca break; } case LogicalTypeId::BLOB: { - ArrowToDuckDBBlob(vector, array, scan_state, size, arrow_type, nested_offset, parent_offset); + ArrowToDuckDBBlob(vector, array, scan_state, size, arrow_type, nested_offset, + NumericCast(parent_offset)); break; } case LogicalTypeId::LIST: { - ArrowToDuckDBList(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, parent_offset); + ArrowToDuckDBList(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, + NumericCast(parent_offset)); break; } case LogicalTypeId::ARRAY: { - ArrowToDuckDBArray(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, parent_offset); + ArrowToDuckDBArray(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, + NumericCast(parent_offset)); break; } case LogicalTypeId::MAP: { - ArrowToDuckDBList(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, parent_offset); + ArrowToDuckDBList(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, + NumericCast(parent_offset)); ArrowToDuckDBMapVerify(vector, size); break; } @@ -889,7 +912,7 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca //! 
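All of the date/time branches above normalise Arrow's declared time unit to DuckDB's microsecond representation. The arithmetic, restated standalone as a reading aid (not patch code):

#include <cstdint>
#include <iostream>

// Seconds and milliseconds scale up by a constant factor per row, while
// nanoseconds divide by 1000 with truncation, matching the loops above.
static int64_t ToMicros(int64_t value, int64_t multiplier) {
	return value * multiplier;
}

int main() {
	std::cout << ToMicros(5, 1000000) << "\n";   // 5 s  -> 5000000 us
	std::cout << ToMicros(5, 1000) << "\n";      // 5 ms -> 5000 us
	std::cout << (int64_t(5123) / 1000) << "\n"; // 5123 ns -> 5 us (truncated)
	return 0;
}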
Fill the children auto &child_entries = StructVector::GetEntries(vector); auto &struct_validity_mask = FlatVector::Validity(vector); - for (int64_t child_idx = 0; child_idx < array.n_children; child_idx++) { + for (idx_t child_idx = 0; child_idx < NumericCast(array.n_children); child_idx++) { auto &child_entry = *child_entries[child_idx]; auto &child_array = *array.children[child_idx]; auto &child_type = arrow_type[child_idx]; @@ -909,15 +932,15 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca switch (array_physical_type) { case ArrowArrayPhysicalType::DICTIONARY_ENCODED: ColumnArrowToDuckDBDictionary(child_entry, child_array, child_state, size, child_type, nested_offset, - &struct_validity_mask, array.offset); + &struct_validity_mask, NumericCast(array.offset)); break; case ArrowArrayPhysicalType::RUN_END_ENCODED: ColumnArrowToDuckDBRunEndEncoded(child_entry, child_array, child_state, size, child_type, nested_offset, - &struct_validity_mask, array.offset); + &struct_validity_mask, NumericCast(array.offset)); break; case ArrowArrayPhysicalType::DEFAULT: ColumnArrowToDuckDB(child_entry, child_array, child_state, size, child_type, nested_offset, - &struct_validity_mask, array.offset); + &struct_validity_mask, NumericCast(array.offset)); break; default: throw NotImplementedException("ArrowArrayPhysicalType not recognized"); @@ -933,13 +956,13 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto &validity_mask = FlatVector::Validity(vector); duckdb::vector children; - for (int64_t child_idx = 0; child_idx < array.n_children; child_idx++) { + for (idx_t child_idx = 0; child_idx < NumericCast(array.n_children); child_idx++) { Vector child(members[child_idx].second, size); auto &child_array = *array.children[child_idx]; auto &child_state = array_state.GetChild(child_idx); auto &child_type = arrow_type[child_idx]; - SetValidityMask(child, child_array, scan_state, size, parent_offset, nested_offset); + SetValidityMask(child, child_array, scan_state, size, NumericCast(parent_offset), nested_offset); auto array_physical_type = GetArrowArrayPhysicalType(child_type); switch (array_physical_type) { @@ -960,7 +983,7 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca } for (idx_t row_idx = 0; row_idx < size; row_idx++) { - auto tag = type_ids[row_idx]; + auto tag = NumericCast(type_ids[row_idx]); auto out_of_range = tag < 0 || tag >= array.n_children; if (out_of_range) { @@ -982,7 +1005,7 @@ template static void SetSelectionVectorLoop(SelectionVector &sel, data_ptr_t indices_p, idx_t size) { auto indices = reinterpret_cast(indices_p); for (idx_t row = 0; row < size; row++) { - sel.set_index(row, indices[row]); + sel.set_index(row, UnsafeNumericCast(indices[row])); } } @@ -994,7 +1017,7 @@ static void SetSelectionVectorLoopWithChecks(SelectionVector &sel, data_ptr_t in if (indices[row] > NumericLimits::Maximum()) { throw ConversionException("DuckDB only supports indices that fit on an uint32"); } - sel.set_index(row, indices[row]); + sel.set_index(row, NumericCast(indices[row])); } } @@ -1004,7 +1027,7 @@ static void SetMaskedSelectionVectorLoop(SelectionVector &sel, data_ptr_t indice auto indices = reinterpret_cast(indices_p); for (idx_t row = 0; row < size; row++) { if (mask.RowIsValid(row)) { - sel.set_index(row, indices[row]); + sel.set_index(row, UnsafeNumericCast(indices[row])); } else { //! 
Need to point out to last element sel.set_index(row, last_element_pos); @@ -1119,22 +1142,23 @@ static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, Arr const bool has_nulls = CanContainNull(array, parent_mask); if (array_state.CacheOutdated(array.dictionary)) { //! We need to set the dictionary data for this column - auto base_vector = make_uniq(vector.GetType(), array.dictionary->length); - SetValidityMask(*base_vector, *array.dictionary, scan_state, array.dictionary->length, 0, 0, has_nulls); + auto base_vector = make_uniq(vector.GetType(), NumericCast(array.dictionary->length)); + SetValidityMask(*base_vector, *array.dictionary, scan_state, NumericCast(array.dictionary->length), 0, 0, + has_nulls); auto &dictionary_type = arrow_type.GetDictionary(); auto arrow_physical_type = GetArrowArrayPhysicalType(dictionary_type); switch (arrow_physical_type) { case ArrowArrayPhysicalType::DICTIONARY_ENCODED: - ColumnArrowToDuckDBDictionary(*base_vector, *array.dictionary, array_state, array.dictionary->length, - dictionary_type); + ColumnArrowToDuckDBDictionary(*base_vector, *array.dictionary, array_state, + NumericCast(array.dictionary->length), dictionary_type); break; case ArrowArrayPhysicalType::RUN_END_ENCODED: - ColumnArrowToDuckDBRunEndEncoded(*base_vector, *array.dictionary, array_state, array.dictionary->length, - dictionary_type); + ColumnArrowToDuckDBRunEndEncoded(*base_vector, *array.dictionary, array_state, + NumericCast(array.dictionary->length), dictionary_type); break; case ArrowArrayPhysicalType::DEFAULT: - ColumnArrowToDuckDB(*base_vector, *array.dictionary, array_state, array.dictionary->length, - dictionary_type); + ColumnArrowToDuckDB(*base_vector, *array.dictionary, array_state, + NumericCast(array.dictionary->length), dictionary_type); break; default: throw NotImplementedException("ArrowArrayPhysicalType not recognized"); @@ -1143,14 +1167,14 @@ static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, Arr } auto offset_type = arrow_type.GetDuckType(); //! 
Get Pointer to Indices of Dictionary - auto indices = - ArrowBufferData(array, 1) + - GetTypeIdSize(offset_type.InternalType()) * GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + auto indices = ArrowBufferData(array, 1) + + GetTypeIdSize(offset_type.InternalType()) * + GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); SelectionVector sel; if (has_nulls) { ValidityMask indices_validity; - GetValidityMask(indices_validity, array, scan_state, size, parent_offset); + GetValidityMask(indices_validity, array, scan_state, size, NumericCast(parent_offset)); if (parent_mask && !parent_mask->AllValid()) { auto &struct_validity_mask = *parent_mask; for (idx_t i = 0; i < size; i++) { @@ -1159,7 +1183,8 @@ static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, Arr } } } - SetSelectionVector(sel, indices, offset_type, size, &indices_validity, array.dictionary->length); + SetSelectionVector(sel, indices, offset_type, size, &indices_validity, + NumericCast(array.dictionary->length)); } else { SetSelectionVector(sel, indices, offset_type, size); } From 483033e0b0b4a73843f6ec98df6af4e4c5452572 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Thu, 4 Apr 2024 10:09:17 +0200 Subject: [PATCH 042/201] more casts, schema reflection should really use unsiged types everywhere --- src/function/table/copy_csv.cpp | 6 ++++-- src/function/table/read_file.cpp | 3 ++- src/function/table/repeat.cpp | 2 +- src/function/table/repeat_row.cpp | 2 +- src/function/table/system/duckdb_columns.cpp | 6 +++--- .../table/system/duckdb_constraints.cpp | 15 ++++++++------- src/function/table/system/duckdb_databases.cpp | 2 +- .../table/system/duckdb_dependencies.cpp | 4 ++-- src/function/table/system/duckdb_functions.cpp | 4 ++-- src/function/table/system/duckdb_indexes.cpp | 8 ++++---- src/function/table/system/duckdb_memory.cpp | 4 ++-- src/function/table/system/duckdb_schemas.cpp | 4 ++-- src/function/table/system/duckdb_sequences.cpp | 6 +++--- src/function/table/system/duckdb_tables.cpp | 17 +++++++++-------- .../table/system/duckdb_temporary_files.cpp | 2 +- src/function/table/system/duckdb_types.cpp | 12 +++++++----- src/function/table/system/duckdb_views.cpp | 8 ++++---- .../table/system/pragma_database_size.cpp | 8 ++++---- .../table/system/pragma_metadata_info.cpp | 6 +++--- .../table/system/pragma_storage_info.cpp | 12 ++++++------ src/function/table/unnest.cpp | 2 +- src/include/duckdb/execution/index/art/node.hpp | 2 +- 22 files changed, 71 insertions(+), 64 deletions(-) diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp index 1fa2e46e7694..12ea326f1a31 100644 --- a/src/function/table/copy_csv.cpp +++ b/src/function/table/copy_csv.cpp @@ -111,8 +111,10 @@ static unique_ptr WriteCSVBind(ClientContext &context, CopyFunctio memset(bind_data->requires_quotes.get(), 0, sizeof(bool) * 256); bind_data->requires_quotes['\n'] = true; bind_data->requires_quotes['\r'] = true; - bind_data->requires_quotes[bind_data->options.dialect_options.state_machine_options.delimiter.GetValue()] = true; - bind_data->requires_quotes[bind_data->options.dialect_options.state_machine_options.quote.GetValue()] = true; + bind_data->requires_quotes[NumericCast( + bind_data->options.dialect_options.state_machine_options.delimiter.GetValue())] = true; + bind_data->requires_quotes[NumericCast( + bind_data->options.dialect_options.state_machine_options.quote.GetValue())] = true; if (!bind_data->options.write_newline.empty()) { 
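The dictionary path above materialises the dictionary values once and then drives a SelectionVector with the per-row indices, so no value is copied per row. The shape of the idea with plain standard containers (reading aid only, not patch code):

#include <cstdint>
#include <iostream>
#include <vector>

// A dictionary-encoded array stores each distinct value once plus an index
// per row; rows are resolved by indirection through the index buffer.
int main() {
	std::vector<const char *> dictionary = {"red", "green", "blue"};
	std::vector<int32_t> indices = {2, 0, 0, 1}; // one entry per row
	for (auto idx : indices) {
		std::cout << dictionary[static_cast<size_t>(idx)] << "\n";
	}
	return 0;
}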
bind_data->newline = TransformNewLine(bind_data->options.write_newline); diff --git a/src/function/table/read_file.cpp b/src/function/table/read_file.cpp index 7655d34c3ea8..c6aa707d026e 100644 --- a/src/function/table/read_file.cpp +++ b/src/function/table/read_file.cpp @@ -160,7 +160,8 @@ static void ReadFileExecute(ClientContext &context, TableFunctionInput &input, D } break; case ReadFileBindData::FILE_SIZE_COLUMN: { auto &file_size_vector = output.data[col_idx]; - FlatVector::GetData(file_size_vector)[out_idx] = file_handle->GetFileSize(); + FlatVector::GetData(file_size_vector)[out_idx] = + NumericCast(file_handle->GetFileSize()); } break; case ReadFileBindData::FILE_LAST_MODIFIED_COLUMN: { auto &last_modified_vector = output.data[col_idx]; diff --git a/src/function/table/repeat.cpp b/src/function/table/repeat.cpp index b62cbe15c87e..25fa43dc0675 100644 --- a/src/function/table/repeat.cpp +++ b/src/function/table/repeat.cpp @@ -26,7 +26,7 @@ static unique_ptr RepeatBind(ClientContext &context, TableFunction if (inputs[1].IsNull()) { throw BinderException("Repeat second parameter cannot be NULL"); } - return make_uniq(inputs[0], inputs[1].GetValue()); + return make_uniq(inputs[0], NumericCast(inputs[1].GetValue())); } static unique_ptr RepeatInit(ClientContext &context, TableFunctionInitInput &input) { diff --git a/src/function/table/repeat_row.cpp b/src/function/table/repeat_row.cpp index a81d8720691f..f68bf7cae439 100644 --- a/src/function/table/repeat_row.cpp +++ b/src/function/table/repeat_row.cpp @@ -32,7 +32,7 @@ static unique_ptr RepeatRowBind(ClientContext &context, TableFunct if (inputs.empty()) { throw BinderException("repeat_rows requires at least one column to be specified"); } - return make_uniq(inputs, entry->second.GetValue()); + return make_uniq(inputs, NumericCast(entry->second.GetValue())); } static unique_ptr RepeatRowInit(ClientContext &context, TableFunctionInitInput &input) { diff --git a/src/function/table/system/duckdb_columns.cpp b/src/function/table/system/duckdb_columns.cpp index 2d7790e1109c..1632a9e5c6c8 100644 --- a/src/function/table/system/duckdb_columns.cpp +++ b/src/function/table/system/duckdb_columns.cpp @@ -209,15 +209,15 @@ void ColumnHelper::WriteColumns(idx_t start_index, idx_t start_col, idx_t end_co // database_name, VARCHAR output.SetValue(col++, index, entry.catalog.GetName()); // database_oid, BIGINT - output.SetValue(col++, index, Value::BIGINT(entry.catalog.GetOid())); + output.SetValue(col++, index, Value::BIGINT(NumericCast(entry.catalog.GetOid()))); // schema_name, VARCHAR output.SetValue(col++, index, entry.schema.name); // schema_oid, BIGINT - output.SetValue(col++, index, Value::BIGINT(entry.schema.oid)); + output.SetValue(col++, index, Value::BIGINT(NumericCast(entry.schema.oid))); // table_name, VARCHAR output.SetValue(col++, index, entry.name); // table_oid, BIGINT - output.SetValue(col++, index, Value::BIGINT(entry.oid)); + output.SetValue(col++, index, Value::BIGINT(NumericCast(entry.oid))); // column_name, VARCHAR output.SetValue(col++, index, Value(ColumnName(i))); // column_index, INTEGER diff --git a/src/function/table/system/duckdb_constraints.cpp b/src/function/table/system/duckdb_constraints.cpp index 64aff39469ed..467aa2ae7605 100644 --- a/src/function/table/system/duckdb_constraints.cpp +++ b/src/function/table/system/duckdb_constraints.cpp @@ -180,15 +180,15 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ // database_name, LogicalType::VARCHAR output.SetValue(col++, count, 
Value(table.schema.catalog.GetName())); // database_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.schema.catalog.GetOid())); + output.SetValue(col++, count, Value::BIGINT(NumericCast(table.schema.catalog.GetOid()))); // schema_name, LogicalType::VARCHAR output.SetValue(col++, count, Value(table.schema.name)); // schema_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.schema.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(table.schema.oid))); // table_name, LogicalType::VARCHAR output.SetValue(col++, count, Value(table.name)); // table_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(table.oid))); // constraint_index, BIGINT UniqueKeyInfo uk_info; @@ -224,15 +224,16 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ } if (uk_info.columns.empty()) { - output.SetValue(col++, count, Value::BIGINT(data.unique_constraint_offset++)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(data.unique_constraint_offset++))); } else { auto known_unique_constraint_offset = data.known_fk_unique_constraint_offsets.find(uk_info); if (known_unique_constraint_offset == data.known_fk_unique_constraint_offsets.end()) { data.known_fk_unique_constraint_offsets.insert(make_pair(uk_info, data.unique_constraint_offset)); - output.SetValue(col++, count, Value::BIGINT(data.unique_constraint_offset)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(data.unique_constraint_offset))); data.unique_constraint_offset++; } else { - output.SetValue(col++, count, Value::BIGINT(known_unique_constraint_offset->second)); + output.SetValue(col++, count, + Value::BIGINT(NumericCast(known_unique_constraint_offset->second))); } } output.SetValue(col++, count, Value(constraint_type)); @@ -286,7 +287,7 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ vector index_list; vector column_name_list; for (auto column_index : column_index_list) { - index_list.push_back(Value::BIGINT(column_index.index)); + index_list.push_back(Value::BIGINT(NumericCast(column_index.index))); column_name_list.emplace_back(table.GetColumn(column_index).Name()); } diff --git a/src/function/table/system/duckdb_databases.cpp b/src/function/table/system/duckdb_databases.cpp index f0a1fc99b5b8..0aee93897a08 100644 --- a/src/function/table/system/duckdb_databases.cpp +++ b/src/function/table/system/duckdb_databases.cpp @@ -66,7 +66,7 @@ void DuckDBDatabasesFunction(ClientContext &context, TableFunctionInput &data_p, // database_name, VARCHAR output.SetValue(col++, count, attached.GetName()); // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(attached.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(attached.oid))); bool is_internal = attached.IsSystem() || attached.IsTemporary(); bool is_readonly = attached.IsReadOnly(); // path, VARCHAR diff --git a/src/function/table/system/duckdb_dependencies.cpp b/src/function/table/system/duckdb_dependencies.cpp index c262d7d47ae2..6f7370194ed5 100644 --- a/src/function/table/system/duckdb_dependencies.cpp +++ b/src/function/table/system/duckdb_dependencies.cpp @@ -85,13 +85,13 @@ void DuckDBDependenciesFunction(ClientContext &context, TableFunctionInput &data // classid, LogicalType::BIGINT output.SetValue(0, count, Value::BIGINT(0)); // objid, LogicalType::BIGINT - output.SetValue(1, count, Value::BIGINT(entry.object.oid)); + output.SetValue(1, 
count, Value::BIGINT(NumericCast(entry.object.oid))); // objsubid, LogicalType::INTEGER output.SetValue(2, count, Value::INTEGER(0)); // refclassid, LogicalType::BIGINT output.SetValue(3, count, Value::BIGINT(0)); // refobjid, LogicalType::BIGINT - output.SetValue(4, count, Value::BIGINT(entry.dependent.oid)); + output.SetValue(4, count, Value::BIGINT(NumericCast(entry.dependent.oid))); // refobjsubid, LogicalType::INTEGER output.SetValue(5, count, Value::INTEGER(0)); // deptype, LogicalType::VARCHAR diff --git a/src/function/table/system/duckdb_functions.cpp b/src/function/table/system/duckdb_functions.cpp index 107da17e2fde..4dd287fd3ff7 100644 --- a/src/function/table/system/duckdb_functions.cpp +++ b/src/function/table/system/duckdb_functions.cpp @@ -453,7 +453,7 @@ bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &ou output.SetValue(col++, output_offset, Value(function.schema.catalog.GetName())); // database_oid, BIGINT - output.SetValue(col++, output_offset, Value::BIGINT(function.schema.catalog.GetOid())); + output.SetValue(col++, output_offset, Value::BIGINT(NumericCast(function.schema.catalog.GetOid()))); // schema_name, LogicalType::VARCHAR output.SetValue(col++, output_offset, Value(function.schema.name)); @@ -497,7 +497,7 @@ bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &ou output.SetValue(col++, output_offset, Value::BOOLEAN(function.internal)); // function_oid, LogicalType::BIGINT - output.SetValue(col++, output_offset, Value::BIGINT(function.oid)); + output.SetValue(col++, output_offset, Value::BIGINT(NumericCast(function.oid))); // example, LogicalType::VARCHAR output.SetValue(col++, output_offset, entry.example.empty() ? Value() : entry.example); diff --git a/src/function/table/system/duckdb_indexes.cpp b/src/function/table/system/duckdb_indexes.cpp index 7cbb6e904544..0b8e78853b4c 100644 --- a/src/function/table/system/duckdb_indexes.cpp +++ b/src/function/table/system/duckdb_indexes.cpp @@ -91,22 +91,22 @@ void DuckDBIndexesFunction(ClientContext &context, TableFunctionInput &data_p, D // database_name, VARCHAR output.SetValue(col++, count, index.catalog.GetName()); // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(index.catalog.GetOid())); + output.SetValue(col++, count, Value::BIGINT(NumericCast(index.catalog.GetOid()))); // schema_name, VARCHAR output.SetValue(col++, count, Value(index.schema.name)); // schema_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(index.schema.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(index.schema.oid))); // index_name, VARCHAR output.SetValue(col++, count, Value(index.name)); // index_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(index.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(index.oid))); // find the table in the catalog auto &table_entry = index.schema.catalog.GetEntry(context, index.GetSchemaName(), index.GetTableName()); // table_name, VARCHAR output.SetValue(col++, count, Value(table_entry.name)); // table_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(table_entry.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(table_entry.oid))); // comment, VARCHAR output.SetValue(col++, count, Value(index.comment)); // is_unique, BOOLEAN diff --git a/src/function/table/system/duckdb_memory.cpp b/src/function/table/system/duckdb_memory.cpp index 9117c91a9a2d..e67fa51062b7 100644 --- a/src/function/table/system/duckdb_memory.cpp +++ b/src/function/table/system/duckdb_memory.cpp 
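// The recurring pattern in this commit: internal counters are idx_t (an alias for uint64_t),
// while the schema-reflection columns expose them as signed BIGINT (int64_t). NumericCast is
// the checked narrowing used throughout, with UnsafeNumericCast as the unchecked variant. A
// minimal sketch of the unsigned-to-signed check; CheckedNarrow is a hypothetical stand-in,
// not the actual duckdb implementation:
//
//     #include <cstdint>
//     #include <limits>
//     #include <stdexcept>
//
//     template <typename TO, typename FROM>
//     static TO CheckedNarrow(FROM value) {
//         // reject unsigned inputs too large for the signed target type
//         if (value > static_cast<FROM>(std::numeric_limits<TO>::max())) {
//             throw std::out_of_range("narrowing cast would overflow");
//         }
//         return static_cast<TO>(value);
//     }
//
//     // usage: int64_t bigint_value = CheckedNarrow<int64_t>(uint64_t {memory_usage_bytes});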
@@ -48,9 +48,9 @@ void DuckDBMemoryFunction(ClientContext &context, TableFunctionInput &data_p, Da // tag, VARCHAR output.SetValue(col++, count, EnumUtil::ToString(entry.tag)); // memory_usage_bytes, BIGINT - output.SetValue(col++, count, Value::BIGINT(entry.size)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(entry.size))); // temporary_storage_bytes, BIGINT - output.SetValue(col++, count, Value::BIGINT(entry.evicted_data)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(entry.evicted_data))); count++; } output.SetCardinality(count); diff --git a/src/function/table/system/duckdb_schemas.cpp b/src/function/table/system/duckdb_schemas.cpp index 65d4a48ca1d4..c8d9828e3603 100644 --- a/src/function/table/system/duckdb_schemas.cpp +++ b/src/function/table/system/duckdb_schemas.cpp @@ -66,11 +66,11 @@ void DuckDBSchemasFunction(ClientContext &context, TableFunctionInput &data_p, D // return values: idx_t col = 0; // "oid", PhysicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(entry.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(entry.oid))); // database_name, VARCHAR output.SetValue(col++, count, entry.catalog.GetName()); // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(entry.catalog.GetOid())); + output.SetValue(col++, count, Value::BIGINT(NumericCast(entry.catalog.GetOid()))); // "schema_name", PhysicalType::VARCHAR output.SetValue(col++, count, Value(entry.name)); // "comment", PhysicalType::VARCHAR diff --git a/src/function/table/system/duckdb_sequences.cpp b/src/function/table/system/duckdb_sequences.cpp index 7a38e1c65e96..44128380b37f 100644 --- a/src/function/table/system/duckdb_sequences.cpp +++ b/src/function/table/system/duckdb_sequences.cpp @@ -98,15 +98,15 @@ void DuckDBSequencesFunction(ClientContext &context, TableFunctionInput &data_p, // database_name, VARCHAR output.SetValue(col++, count, seq.catalog.GetName()); // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq.catalog.GetOid())); + output.SetValue(col++, count, Value::BIGINT(NumericCast(seq.catalog.GetOid()))); // schema_name, VARCHAR output.SetValue(col++, count, Value(seq.schema.name)); // schema_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq.schema.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(seq.schema.oid))); // sequence_name, VARCHAR output.SetValue(col++, count, Value(seq.name)); // sequence_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(seq.oid))); // comment, VARCHAR output.SetValue(col++, count, Value(seq.comment)); // temporary, BOOLEAN diff --git a/src/function/table/system/duckdb_tables.cpp b/src/function/table/system/duckdb_tables.cpp index 2503cfbf80c9..a6f7d523aa3f 100644 --- a/src/function/table/system/duckdb_tables.cpp +++ b/src/function/table/system/duckdb_tables.cpp @@ -127,15 +127,15 @@ void DuckDBTablesFunction(ClientContext &context, TableFunctionInput &data_p, Da // database_name, VARCHAR output.SetValue(col++, count, table.catalog.GetName()); // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(table.catalog.GetOid())); + output.SetValue(col++, count, Value::BIGINT(NumericCast(table.catalog.GetOid()))); // schema_name, LogicalType::VARCHAR output.SetValue(col++, count, Value(table.schema.name)); // schema_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.schema.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(table.schema.oid))); 
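 		// catalog and schema oids are unsigned idx_t values internally; the checked cast keeps an
 		// out-of-range oid from silently wrapping when it is exposed through a signed BIGINT column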
// table_name, LogicalType::VARCHAR output.SetValue(col++, count, Value(table.name)); // table_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(table.oid))); // comment, LogicalType::VARCHAR output.SetValue(col++, count, Value(table.comment)); // internal, LogicalType::BOOLEAN @@ -145,15 +145,16 @@ void DuckDBTablesFunction(ClientContext &context, TableFunctionInput &data_p, Da // has_primary_key, LogicalType::BOOLEAN output.SetValue(col++, count, Value::BOOLEAN(TableHasPrimaryKey(table))); // estimated_size, LogicalType::BIGINT - Value card_val = - storage_info.cardinality == DConstants::INVALID_INDEX ? Value() : Value::BIGINT(storage_info.cardinality); + Value card_val = storage_info.cardinality == DConstants::INVALID_INDEX + ? Value() + : Value::BIGINT(NumericCast(storage_info.cardinality)); output.SetValue(col++, count, card_val); // column_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(table.GetColumns().LogicalColumnCount())); + output.SetValue(col++, count, Value::BIGINT(NumericCast(table.GetColumns().LogicalColumnCount()))); // index_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(storage_info.index_info.size())); + output.SetValue(col++, count, Value::BIGINT(NumericCast(storage_info.index_info.size()))); // check_constraint_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(CheckConstraintCount(table))); + output.SetValue(col++, count, Value::BIGINT(NumericCast(CheckConstraintCount(table)))); // sql, LogicalType::VARCHAR output.SetValue(col++, count, Value(table.ToSQL())); diff --git a/src/function/table/system/duckdb_temporary_files.cpp b/src/function/table/system/duckdb_temporary_files.cpp index d4f043aff746..f1886639441c 100644 --- a/src/function/table/system/duckdb_temporary_files.cpp +++ b/src/function/table/system/duckdb_temporary_files.cpp @@ -45,7 +45,7 @@ void DuckDBTemporaryFilesFunction(ClientContext &context, TableFunctionInput &da // database_name, VARCHAR output.SetValue(col++, count, entry.path); // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(entry.size)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(entry.size))); count++; } output.SetCardinality(count); diff --git a/src/function/table/system/duckdb_types.cpp b/src/function/table/system/duckdb_types.cpp index 2a7fc1fd1d61..1b7ca0c97488 100644 --- a/src/function/table/system/duckdb_types.cpp +++ b/src/function/table/system/duckdb_types.cpp @@ -89,17 +89,17 @@ void DuckDBTypesFunction(ClientContext &context, TableFunctionInput &data_p, Dat // database_name, VARCHAR output.SetValue(col++, count, type_entry.catalog.GetName()); // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(type_entry.catalog.GetOid())); + output.SetValue(col++, count, Value::BIGINT(NumericCast(type_entry.catalog.GetOid()))); // schema_name, LogicalType::VARCHAR output.SetValue(col++, count, Value(type_entry.schema.name)); // schema_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(type_entry.schema.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(type_entry.schema.oid))); // type_oid, BIGINT int64_t oid; if (type_entry.internal) { - oid = int64_t(type.id()); + oid = NumericCast(type.id()); } else { - oid = type_entry.oid; + oid = NumericCast(type_entry.oid); } Value oid_val; if (data.oids.find(oid) == data.oids.end()) { @@ -114,7 +114,9 @@ void DuckDBTypesFunction(ClientContext &context, 
TableFunctionInput &data_p, Dat // type_size, BIGINT auto internal_type = type.InternalType(); output.SetValue(col++, count, - internal_type == PhysicalType::INVALID ? Value() : Value::BIGINT(GetTypeIdSize(internal_type))); + internal_type == PhysicalType::INVALID + ? Value() + : Value::BIGINT(NumericCast(GetTypeIdSize(internal_type)))); // logical_type, VARCHAR output.SetValue(col++, count, Value(EnumUtil::ToString(type.id()))); // type_category, VARCHAR diff --git a/src/function/table/system/duckdb_views.cpp b/src/function/table/system/duckdb_views.cpp index bd17a7ef6a50..e2eb26603468 100644 --- a/src/function/table/system/duckdb_views.cpp +++ b/src/function/table/system/duckdb_views.cpp @@ -89,15 +89,15 @@ void DuckDBViewsFunction(ClientContext &context, TableFunctionInput &data_p, Dat // database_name, VARCHAR output.SetValue(col++, count, view.catalog.GetName()); // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(view.catalog.GetOid())); + output.SetValue(col++, count, Value::BIGINT(NumericCast(view.catalog.GetOid()))); // schema_name, LogicalType::VARCHAR output.SetValue(col++, count, Value(view.schema.name)); // schema_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(view.schema.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(view.schema.oid))); // view_name, LogicalType::VARCHAR output.SetValue(col++, count, Value(view.name)); // view_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(view.oid)); + output.SetValue(col++, count, Value::BIGINT(NumericCast(view.oid))); // comment, LogicalType::VARCHARs output.SetValue(col++, count, Value(view.comment)); // internal, LogicalType::BOOLEAN @@ -105,7 +105,7 @@ void DuckDBViewsFunction(ClientContext &context, TableFunctionInput &data_p, Dat // temporary, LogicalType::BOOLEAN output.SetValue(col++, count, Value::BOOLEAN(view.temporary)); // column_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(view.types.size())); + output.SetValue(col++, count, Value::BIGINT(NumericCast(view.types.size()))); // sql, LogicalType::VARCHAR output.SetValue(col++, count, Value(view.ToSQL())); diff --git a/src/function/table/system/pragma_database_size.cpp b/src/function/table/system/pragma_database_size.cpp index c10eae0c9b60..ba8f2020e3d5 100644 --- a/src/function/table/system/pragma_database_size.cpp +++ b/src/function/table/system/pragma_database_size.cpp @@ -76,10 +76,10 @@ void PragmaDatabaseSizeFunction(ClientContext &context, TableFunctionInput &data idx_t col = 0; output.data[col++].SetValue(row, Value(db.GetName())); output.data[col++].SetValue(row, Value(StringUtil::BytesToHumanReadableString(ds.bytes))); - output.data[col++].SetValue(row, Value::BIGINT(ds.block_size)); - output.data[col++].SetValue(row, Value::BIGINT(ds.total_blocks)); - output.data[col++].SetValue(row, Value::BIGINT(ds.used_blocks)); - output.data[col++].SetValue(row, Value::BIGINT(ds.free_blocks)); + output.data[col++].SetValue(row, Value::BIGINT(NumericCast(ds.block_size))); + output.data[col++].SetValue(row, Value::BIGINT(NumericCast(ds.total_blocks))); + output.data[col++].SetValue(row, Value::BIGINT(NumericCast(ds.used_blocks))); + output.data[col++].SetValue(row, Value::BIGINT(NumericCast(ds.free_blocks))); output.data[col++].SetValue( row, ds.wal_size == idx_t(-1) ? 
Value() : Value(StringUtil::BytesToHumanReadableString(ds.wal_size))); output.data[col++].SetValue(row, data.memory_usage); diff --git a/src/function/table/system/pragma_metadata_info.cpp b/src/function/table/system/pragma_metadata_info.cpp index 92c180308db6..741a1d6c93c7 100644 --- a/src/function/table/system/pragma_metadata_info.cpp +++ b/src/function/table/system/pragma_metadata_info.cpp @@ -57,13 +57,13 @@ static void PragmaMetadataInfoFunction(ClientContext &context, TableFunctionInpu // block_id output.SetValue(col_idx++, count, Value::BIGINT(entry.block_id)); // total_blocks - output.SetValue(col_idx++, count, Value::BIGINT(entry.total_blocks)); + output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.total_blocks))); // free_blocks - output.SetValue(col_idx++, count, Value::BIGINT(entry.free_list.size())); + output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.free_list.size()))); // free_list vector list_values; for (auto &free_id : entry.free_list) { - list_values.push_back(Value::BIGINT(free_id)); + list_values.push_back(Value::BIGINT(NumericCast(free_id))); } output.SetValue(col_idx++, count, Value::LIST(LogicalType::BIGINT, std::move(list_values))); count++; diff --git a/src/function/table/system/pragma_storage_info.cpp b/src/function/table/system/pragma_storage_info.cpp index 90c60d15c040..0c33ebf01949 100644 --- a/src/function/table/system/pragma_storage_info.cpp +++ b/src/function/table/system/pragma_storage_info.cpp @@ -103,22 +103,22 @@ static void PragmaStorageInfoFunction(ClientContext &context, TableFunctionInput idx_t col_idx = 0; // row_group_id - output.SetValue(col_idx++, count, Value::BIGINT(entry.row_group_index)); + output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.row_group_index))); // column_name auto &col = columns.GetColumn(PhysicalIndex(entry.column_id)); output.SetValue(col_idx++, count, Value(col.Name())); // column_id - output.SetValue(col_idx++, count, Value::BIGINT(entry.column_id)); + output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.column_id))); // column_path output.SetValue(col_idx++, count, Value(entry.column_path)); // segment_id - output.SetValue(col_idx++, count, Value::BIGINT(entry.segment_idx)); + output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.segment_idx))); // segment_type output.SetValue(col_idx++, count, Value(entry.segment_type)); // start - output.SetValue(col_idx++, count, Value::BIGINT(entry.segment_start)); + output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.segment_start))); // count - output.SetValue(col_idx++, count, Value::BIGINT(entry.segment_count)); + output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.segment_count))); // compression output.SetValue(col_idx++, count, Value(entry.compression_type)); // stats @@ -131,7 +131,7 @@ static void PragmaStorageInfoFunction(ClientContext &context, TableFunctionInput // block_offset if (entry.persistent) { output.SetValue(col_idx++, count, Value::BIGINT(entry.block_id)); - output.SetValue(col_idx++, count, Value::BIGINT(entry.block_offset)); + output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.block_offset))); } else { output.SetValue(col_idx++, count, Value()); output.SetValue(col_idx++, count, Value()); diff --git a/src/function/table/unnest.cpp b/src/function/table/unnest.cpp index 15f3950893bf..b6485bc3d2aa 100644 --- a/src/function/table/unnest.cpp +++ b/src/function/table/unnest.cpp @@ -63,7 +63,7 @@ static unique_ptr UnnestLocalInit(ExecutionContext &con 
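// In the change below, the column-index argument 0 becomes the unsigned literal 0U, so the
// BoundReferenceExpression constructor receives an unsigned value directly rather than an
// implicit int conversion (the class of warning this patch series appears to be removing).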
static unique_ptr UnnestInit(ClientContext &context, TableFunctionInitInput &input) { auto &bind_data = input.bind_data->Cast(); auto result = make_uniq(); - auto ref = make_uniq(bind_data.input_type, 0); + auto ref = make_uniq(bind_data.input_type, 0U); auto bound_unnest = make_uniq(ListType::GetChildType(bind_data.input_type)); bound_unnest->child = std::move(ref); result->select_list.push_back(std::move(bound_unnest)); diff --git a/src/include/duckdb/execution/index/art/node.hpp b/src/include/duckdb/execution/index/art/node.hpp index 2e170d410d91..34e2498dd01d 100644 --- a/src/include/duckdb/execution/index/art/node.hpp +++ b/src/include/duckdb/execution/index/art/node.hpp @@ -117,7 +117,7 @@ class Node : public IndexPointer { } //! Set the row ID (8th to 63rd bit) inline void SetRowId(const row_t row_id) { - Set((Get() & AND_METADATA) | row_id); + Set((Get() & AND_METADATA) | UnsafeNumericCast(row_id)); } //! Returns the type of the node, which is held in the metadata From f529ef99d5370478a3456018c9e6408185e87f37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Thu, 4 Apr 2024 11:02:10 +0200 Subject: [PATCH 043/201] moaaar --- src/common/adbc/nanoarrow/allocator.cpp | 4 +-- src/common/adbc/nanoarrow/metadata.cpp | 6 ++--- src/common/adbc/nanoarrow/schema.cpp | 13 +++++----- src/common/arrow/appender/bool_data.cpp | 1 + src/common/arrow/appender/struct_data.cpp | 2 +- src/common/arrow/appender/union_data.cpp | 3 ++- src/common/arrow/arrow_appender.cpp | 8 +++--- src/common/arrow/arrow_converter.cpp | 6 ++--- src/common/box_renderer.cpp | 2 +- src/common/compressed_file_system.cpp | 18 +++++++------ src/common/exception_format_value.cpp | 1 + src/common/file_system.cpp | 12 ++++----- src/common/gzip_file_system.cpp | 25 +++++++++++-------- src/common/hive_partitioning.cpp | 4 +-- src/common/local_file_system.cpp | 16 ++++++------ src/common/multi_file_reader.cpp | 2 +- src/common/pipe_file_system.cpp | 5 ++-- src/common/random_engine.cpp | 3 ++- src/common/re2_regex.cpp | 8 +++--- src/common/string_util.cpp | 5 ++-- .../common/arrow/appender/append_data.hpp | 1 + .../common/arrow/appender/enum_data.hpp | 6 ++--- .../duckdb/common/arrow/appender/map_data.hpp | 2 +- .../common/arrow/appender/varchar_data.hpp | 5 ++-- .../duckdb/storage/string_uncompressed.hpp | 4 +-- third_party/fsst/fsst.h | 2 +- 26 files changed, 92 insertions(+), 72 deletions(-) diff --git a/src/common/adbc/nanoarrow/allocator.cpp b/src/common/adbc/nanoarrow/allocator.cpp index 692cb58a6230..cea53b1880a7 100644 --- a/src/common/adbc/nanoarrow/allocator.cpp +++ b/src/common/adbc/nanoarrow/allocator.cpp @@ -23,11 +23,11 @@ namespace duckdb_nanoarrow { void *ArrowMalloc(int64_t size) { - return malloc(size); + return malloc(size_t(size)); } void *ArrowRealloc(void *ptr, int64_t size) { - return realloc(ptr, size); + return realloc(ptr, size_t(size)); } void ArrowFree(void *ptr) { diff --git a/src/common/adbc/nanoarrow/metadata.cpp b/src/common/adbc/nanoarrow/metadata.cpp index 742bbe4186c6..cb3009c01f3c 100644 --- a/src/common/adbc/nanoarrow/metadata.cpp +++ b/src/common/adbc/nanoarrow/metadata.cpp @@ -78,7 +78,7 @@ int64_t ArrowMetadataSizeOf(const char *metadata) { int64_t size = sizeof(int32_t); while (ArrowMetadataReaderRead(&reader, &key, &value) == NANOARROW_OK) { - size += sizeof(int32_t) + key.n_bytes + sizeof(int32_t) + value.n_bytes; + size += sizeof(int32_t) + uint64_t(key.n_bytes) + sizeof(int32_t) + uint64_t(value.n_bytes); } return size; @@ -89,7 +89,7 @@ ArrowErrorCode 
ArrowMetadataGetValue(const char *metadata, const char *key, cons struct ArrowStringView target_key_view = {key, static_cast(strlen(key))}; value_out->data = default_value; if (default_value != NULL) { - value_out->n_bytes = strlen(default_value); + value_out->n_bytes = int64_t(strlen(default_value)); } else { value_out->n_bytes = 0; } @@ -101,7 +101,7 @@ ArrowErrorCode ArrowMetadataGetValue(const char *metadata, const char *key, cons while (ArrowMetadataReaderRead(&reader, &key_view, &value) == NANOARROW_OK) { int key_equal = target_key_view.n_bytes == key_view.n_bytes && - strncmp(target_key_view.data, key_view.data, key_view.n_bytes) == 0; + strncmp(target_key_view.data, key_view.data, size_t(key_view.n_bytes)) == 0; if (key_equal) { value_out->data = value.data; value_out->n_bytes = value.n_bytes; diff --git a/src/common/adbc/nanoarrow/schema.cpp b/src/common/adbc/nanoarrow/schema.cpp index 1ed36f1f8f5f..38d1b314ff90 100644 --- a/src/common/adbc/nanoarrow/schema.cpp +++ b/src/common/adbc/nanoarrow/schema.cpp @@ -318,7 +318,7 @@ ArrowErrorCode ArrowSchemaSetFormat(struct ArrowSchema *schema, const char *form if (format != NULL) { size_t format_size = strlen(format) + 1; - schema->format = (const char *)ArrowMalloc(format_size); + schema->format = (const char *)ArrowMalloc(int64_t(format_size)); if (schema->format == NULL) { return ENOMEM; } @@ -338,7 +338,7 @@ ArrowErrorCode ArrowSchemaSetName(struct ArrowSchema *schema, const char *name) if (name != NULL) { size_t name_size = strlen(name) + 1; - schema->name = (const char *)ArrowMalloc(name_size); + schema->name = (const char *)ArrowMalloc(int64_t(name_size)); if (schema->name == NULL) { return ENOMEM; } @@ -357,13 +357,13 @@ ArrowErrorCode ArrowSchemaSetMetadata(struct ArrowSchema *schema, const char *me } if (metadata != NULL) { - size_t metadata_size = ArrowMetadataSizeOf(metadata); + auto metadata_size = ArrowMetadataSizeOf(metadata); schema->metadata = (const char *)ArrowMalloc(metadata_size); if (schema->metadata == NULL) { return ENOMEM; } - memcpy((void *)schema->metadata, metadata, metadata_size); + memcpy((void *)schema->metadata, metadata, size_t(metadata_size)); } else { schema->metadata = NULL; } @@ -377,7 +377,8 @@ ArrowErrorCode ArrowSchemaAllocateChildren(struct ArrowSchema *schema, int64_t n } if (n_children > 0) { - schema->children = (struct ArrowSchema **)ArrowMalloc(n_children * sizeof(struct ArrowSchema *)); + schema->children = + (struct ArrowSchema **)ArrowMalloc(int64_t(uint64_t(n_children) * sizeof(struct ArrowSchema *))); if (schema->children == NULL) { return ENOMEM; @@ -385,7 +386,7 @@ ArrowErrorCode ArrowSchemaAllocateChildren(struct ArrowSchema *schema, int64_t n schema->n_children = n_children; - memset(schema->children, 0, n_children * sizeof(struct ArrowSchema *)); + memset(schema->children, 0, uint64_t(n_children) * sizeof(struct ArrowSchema *)); for (int64_t i = 0; i < n_children; i++) { schema->children[i] = (struct ArrowSchema *)ArrowMalloc(sizeof(struct ArrowSchema)); diff --git a/src/common/arrow/appender/bool_data.cpp b/src/common/arrow/appender/bool_data.cpp index 6a5f67287387..d30b39339076 100644 --- a/src/common/arrow/appender/bool_data.cpp +++ b/src/common/arrow/appender/bool_data.cpp @@ -6,6 +6,7 @@ namespace duckdb { void ArrowBoolData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { auto byte_count = (capacity + 7) / 8; result.main_buffer.reserve(byte_count); + (void)AppendValidity; // silence a compiler warning about unused static function } void 
ArrowBoolData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { diff --git a/src/common/arrow/appender/struct_data.cpp b/src/common/arrow/appender/struct_data.cpp index ce74a92a7b51..b2afa62d145e 100644 --- a/src/common/arrow/appender/struct_data.cpp +++ b/src/common/arrow/appender/struct_data.cpp @@ -35,7 +35,7 @@ void ArrowStructData::Finalize(ArrowAppendData &append_data, const LogicalType & auto &child_types = StructType::GetChildTypes(type); ArrowAppender::AddChildren(append_data, child_types.size()); result->children = append_data.child_pointers.data(); - result->n_children = child_types.size(); + result->n_children = NumericCast(child_types.size()); for (idx_t i = 0; i < child_types.size(); i++) { auto &child_type = child_types[i].second; append_data.child_arrays[i] = *ArrowAppender::FinalizeChild(child_type, std::move(append_data.child_data[i])); diff --git a/src/common/arrow/appender/union_data.cpp b/src/common/arrow/appender/union_data.cpp index 9797aa9f4772..270129c14dd1 100644 --- a/src/common/arrow/appender/union_data.cpp +++ b/src/common/arrow/appender/union_data.cpp @@ -14,6 +14,7 @@ void ArrowUnionData::Initialize(ArrowAppendData &result, const LogicalType &type auto child_buffer = ArrowAppender::InitializeChild(child.second, capacity, result.options); result.child_data.push_back(std::move(child_buffer)); } + (void)AppendValidity; // silence a compiler warning about unused static function } void ArrowUnionData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { @@ -61,7 +62,7 @@ void ArrowUnionData::Finalize(ArrowAppendData &append_data, const LogicalType &t auto &child_types = UnionType::CopyMemberTypes(type); ArrowAppender::AddChildren(append_data, child_types.size()); result->children = append_data.child_pointers.data(); - result->n_children = child_types.size(); + result->n_children = NumericCast(child_types.size()); for (idx_t i = 0; i < child_types.size(); i++) { auto &child_type = child_types[i].second; append_data.child_arrays[i] = *ArrowAppender::FinalizeChild(child_type, std::move(append_data.child_data[i])); diff --git a/src/common/arrow/arrow_appender.cpp b/src/common/arrow/arrow_appender.cpp index 6dc0c14b3d24..e5ed4343c322 100644 --- a/src/common/arrow/arrow_appender.cpp +++ b/src/common/arrow/arrow_appender.cpp @@ -70,8 +70,8 @@ ArrowArray *ArrowAppender::FinalizeChild(const LogicalType &type, unique_ptroffset = 0; result->dictionary = nullptr; result->buffers = append_data.buffers.data(); - result->null_count = append_data.null_count; - result->length = append_data.row_count; + result->null_count = NumericCast(append_data.null_count); + result->length = NumericCast(append_data.row_count); result->buffers[0] = append_data.validity.data(); if (append_data.finalize) { @@ -90,10 +90,10 @@ ArrowArray ArrowAppender::Finalize() { ArrowArray result; AddChildren(*root_holder, types.size()); result.children = root_holder->child_pointers.data(); - result.n_children = types.size(); + result.n_children = NumericCast(types.size()); // Configure root array - result.length = row_count; + result.length = NumericCast(row_count); result.n_buffers = 1; result.buffers = root_holder->buffers.data(); // there is no actual buffer there since we don't have NULLs result.offset = 0; diff --git a/src/common/arrow/arrow_converter.cpp b/src/common/arrow/arrow_converter.cpp index 62be691b484f..b9f9ef836ed5 100644 --- a/src/common/arrow/arrow_converter.cpp +++ b/src/common/arrow/arrow_converter.cpp @@ 
-206,7 +206,7 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co case LogicalTypeId::STRUCT: { child.format = "+s"; auto &child_types = StructType::GetChildTypes(type); - child.n_children = child_types.size(); + child.n_children = NumericCast(child_types.size()); root_holder.nested_children.emplace_back(); root_holder.nested_children.back().resize(child_types.size()); root_holder.nested_children_ptr.emplace_back(); @@ -251,7 +251,7 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co std::string format = "+us:"; auto &child_types = UnionType::CopyMemberTypes(type); - child.n_children = child_types.size(); + child.n_children = NumericCast(child_types.size()); root_holder.nested_children.emplace_back(); root_holder.nested_children.back().resize(child_types.size()); root_holder.nested_children_ptr.emplace_back(); @@ -323,7 +323,7 @@ void ArrowConverter::ToArrowSchema(ArrowSchema *out_schema, const vectorchildren_ptrs[i] = &root_holder->children[i]; } out_schema->children = root_holder->children_ptrs.data(); - out_schema->n_children = column_count; + out_schema->n_children = NumericCast(column_count); // Store the schema out_schema->format = "+s"; // struct apparently diff --git a/src/common/box_renderer.cpp b/src/common/box_renderer.cpp index fc309150e1b7..629f6349fe69 100644 --- a/src/common/box_renderer.cpp +++ b/src/common/box_renderer.cpp @@ -399,7 +399,7 @@ vector BoxRenderer::ComputeRenderWidths(const vector &names, cons // e.g. if we have 10 columns, we remove #5, then #4, then #6, then #3, then #7, etc int64_t offset = 0; while (total_length > max_width) { - idx_t c = column_count / 2 + offset; + idx_t c = column_count / 2 + NumericCast(offset); total_length -= widths[c] + 3; pruned_columns.insert(c); if (offset >= 0) { diff --git a/src/common/compressed_file_system.cpp b/src/common/compressed_file_system.cpp index 725393904127..5727d4d7db93 100644 --- a/src/common/compressed_file_system.cpp +++ b/src/common/compressed_file_system.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/compressed_file_system.hpp" +#include "duckdb/common/numeric_utils.hpp" namespace duckdb { @@ -37,7 +38,9 @@ int64_t CompressedFile::ReadData(void *buffer, int64_t remaining) { // first check if there are input bytes available in the output buffers if (stream_data.out_buff_start != stream_data.out_buff_end) { // there is! copy it into the output buffer - idx_t available = MinValue(remaining, stream_data.out_buff_end - stream_data.out_buff_start); + auto available = + MinValue(UnsafeNumericCast(remaining), + UnsafeNumericCast(stream_data.out_buff_end - stream_data.out_buff_start)); memcpy(data_ptr_t(buffer) + total_read, stream_data.out_buff_start, available); // increment the total read variables as required @@ -46,11 +49,11 @@ int64_t CompressedFile::ReadData(void *buffer, int64_t remaining) { remaining -= available; if (remaining == 0) { // done! 
read enough - return total_read; + return UnsafeNumericCast(total_read); } } if (!stream_wrapper) { - return total_read; + return UnsafeNumericCast(total_read); } // ran out of buffer: read more data from the child stream @@ -63,10 +66,11 @@ int64_t CompressedFile::ReadData(void *buffer, int64_t remaining) { if (stream_data.refresh && (stream_data.in_buff_end == stream_data.in_buff.get() + stream_data.in_buf_size)) { auto bufrem = stream_data.in_buff_end - stream_data.in_buff_start; // buffer not empty, move remaining bytes to the beginning - memmove(stream_data.in_buff.get(), stream_data.in_buff_start, bufrem); + memmove(stream_data.in_buff.get(), stream_data.in_buff_start, UnsafeNumericCast(bufrem)); stream_data.in_buff_start = stream_data.in_buff.get(); // refill the rest of input buffer - auto sz = child_handle->Read(stream_data.in_buff_start + bufrem, stream_data.in_buf_size - bufrem); + auto sz = child_handle->Read(stream_data.in_buff_start + bufrem, + stream_data.in_buf_size - UnsafeNumericCast(bufrem)); stream_data.in_buff_end = stream_data.in_buff_start + bufrem + sz; if (sz <= 0) { stream_wrapper.reset(); @@ -92,7 +96,7 @@ int64_t CompressedFile::ReadData(void *buffer, int64_t remaining) { stream_wrapper.reset(); } } - return total_read; + return UnsafeNumericCast(total_read); } int64_t CompressedFile::WriteData(data_ptr_t buffer, int64_t nr_bytes) { @@ -134,7 +138,7 @@ void CompressedFileSystem::Reset(FileHandle &handle) { int64_t CompressedFileSystem::GetFileSize(FileHandle &handle) { auto &compressed_file = handle.Cast(); - return compressed_file.child_handle->GetFileSize(); + return NumericCast(compressed_file.child_handle->GetFileSize()); } bool CompressedFileSystem::OnDiskFile(FileHandle &handle) { diff --git a/src/common/exception_format_value.cpp b/src/common/exception_format_value.cpp index 1eb9d45939f5..ddef4e10c872 100644 --- a/src/common/exception_format_value.cpp +++ b/src/common/exception_format_value.cpp @@ -1,5 +1,6 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/types.hpp" +#include "duckdb/common/helper.hpp" // defines DUCKDB_EXPLICIT_FALLTHROUGH which fmt will use to annotate #include "fmt/format.h" #include "fmt/printf.h" #include "duckdb/common/types/hugeint.hpp" diff --git a/src/common/file_system.cpp b/src/common/file_system.cpp index 7450c1fbec9a..ed912d590bc7 100644 --- a/src/common/file_system.cpp +++ b/src/common/file_system.cpp @@ -517,7 +517,7 @@ FileHandle::~FileHandle() { } int64_t FileHandle::Read(void *buffer, idx_t nr_bytes) { - return file_system.Read(*this, buffer, nr_bytes); + return file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes)); } bool FileHandle::Trim(idx_t offset_bytes, idx_t length_bytes) { @@ -525,15 +525,15 @@ bool FileHandle::Trim(idx_t offset_bytes, idx_t length_bytes) { } int64_t FileHandle::Write(void *buffer, idx_t nr_bytes) { - return file_system.Write(*this, buffer, nr_bytes); + return file_system.Write(*this, buffer, UnsafeNumericCast(nr_bytes)); } void FileHandle::Read(void *buffer, idx_t nr_bytes, idx_t location) { - file_system.Read(*this, buffer, nr_bytes, location); + file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes), location); } void FileHandle::Write(void *buffer, idx_t nr_bytes, idx_t location) { - file_system.Write(*this, buffer, nr_bytes, location); + file_system.Write(*this, buffer, UnsafeNumericCast(nr_bytes), location); } void FileHandle::Seek(idx_t location) { @@ -560,7 +560,7 @@ string FileHandle::ReadLine() { string result; char buffer[1]; while (true) { - idx_t 
tuples_read = Read(buffer, 1); + auto tuples_read = UnsafeNumericCast(Read(buffer, 1)); if (tuples_read == 0 || buffer[0] == '\n') { return result; } @@ -575,7 +575,7 @@ bool FileHandle::OnDiskFile() { } idx_t FileHandle::GetFileSize() { - return file_system.GetFileSize(*this); + return NumericCast(file_system.GetFileSize(*this)); } void FileHandle::Sync() { diff --git a/src/common/gzip_file_system.cpp b/src/common/gzip_file_system.cpp index d24d5e57149f..a4e39b90d39b 100644 --- a/src/common/gzip_file_system.cpp +++ b/src/common/gzip_file_system.cpp @@ -120,13 +120,13 @@ void MiniZStreamWrapper::Initialize(CompressedFile &file, bool write) { } else { idx_t data_start = GZIP_HEADER_MINSIZE; auto read_count = file.child_handle->Read(gzip_hdr, GZIP_HEADER_MINSIZE); - GZipFileSystem::VerifyGZIPHeader(gzip_hdr, read_count); + GZipFileSystem::VerifyGZIPHeader(gzip_hdr, NumericCast(read_count)); // Skip over the extra field if necessary if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { uint8_t gzip_xlen[2]; file.child_handle->Seek(data_start); file.child_handle->Read(gzip_xlen, 2); - idx_t xlen = (uint8_t)gzip_xlen[0] | (uint8_t)gzip_xlen[1] << 8; + auto xlen = NumericCast((uint8_t)gzip_xlen[0] | (uint8_t)gzip_xlen[1] << 8); data_start += xlen + 2; } // Skip over the file name if necessary @@ -160,7 +160,7 @@ bool MiniZStreamWrapper::Read(StreamData &sd) { GZipFileSystem::VerifyGZIPHeader(gzip_hdr, GZIP_HEADER_MINSIZE); body_ptr += GZIP_HEADER_MINSIZE; if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { - idx_t xlen = (uint8_t)*body_ptr | (uint8_t) * (body_ptr + 1) << 8; + auto xlen = NumericCast((uint8_t)*body_ptr | (uint8_t) * (body_ptr + 1) << 8); body_ptr += xlen + 2; if (GZIP_FOOTER_SIZE + GZIP_HEADER_MINSIZE + 2 + xlen >= GZIP_HEADER_MAXSIZE) { throw InternalException("Extra field resulting in GZIP header larger than defined maximum (%d)", @@ -170,7 +170,7 @@ bool MiniZStreamWrapper::Read(StreamData &sd) { if (gzip_hdr[3] & GZIP_FLAG_NAME) { char c; do { - c = *body_ptr; + c = UnsafeNumericCast(*body_ptr); body_ptr++; } while (c != '\0' && body_ptr < sd.in_buff_end); if ((idx_t)(body_ptr - sd.in_buff_start) >= GZIP_HEADER_MAXSIZE) { @@ -217,12 +217,13 @@ bool MiniZStreamWrapper::Read(StreamData &sd) { void MiniZStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t uncompressed_data, int64_t uncompressed_size) { // update the src and the total size - crc = duckdb_miniz::mz_crc32(crc, reinterpret_cast(uncompressed_data), uncompressed_size); - total_size += uncompressed_size; + crc = duckdb_miniz::mz_crc32(crc, reinterpret_cast(uncompressed_data), + UnsafeNumericCast(uncompressed_size)); + total_size += UnsafeNumericCast(uncompressed_size); auto remaining = uncompressed_size; while (remaining > 0) { - idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start; + auto output_remaining = UnsafeNumericCast((sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start); mz_stream_ptr->next_in = reinterpret_cast(uncompressed_data); mz_stream_ptr->avail_in = NumericCast(remaining); @@ -237,10 +238,11 @@ void MiniZStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t sd.out_buff_start += output_remaining - mz_stream_ptr->avail_out; if (mz_stream_ptr->avail_out == 0) { // no more output buffer available: flush - file.child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get()); + file.child_handle->Write(sd.out_buff.get(), + UnsafeNumericCast(sd.out_buff_start - sd.out_buff.get())); sd.out_buff_start = sd.out_buff.get(); } - idx_t written = remaining - 
mz_stream_ptr->avail_in; + auto written = UnsafeNumericCast(remaining - mz_stream_ptr->avail_in); uncompressed_data += written; remaining = mz_stream_ptr->avail_in; } @@ -258,7 +260,8 @@ void MiniZStreamWrapper::FlushStream() { auto res = mz_deflate(mz_stream_ptr.get(), duckdb_miniz::MZ_FINISH); sd.out_buff_start += (output_remaining - mz_stream_ptr->avail_out); if (sd.out_buff_start > sd.out_buff.get()) { - file->child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get()); + file->child_handle->Write(sd.out_buff.get(), + UnsafeNumericCast(sd.out_buff_start - sd.out_buff.get())); sd.out_buff_start = sd.out_buff.get(); } if (res == duckdb_miniz::MZ_STREAM_END) { @@ -354,7 +357,7 @@ string GZipFileSystem::UncompressGZIPString(const string &in) { throw InternalException("Failed to initialize miniz"); } - auto bytes_remaining = in.size() - (body_ptr - in.data()); + auto bytes_remaining = in.size() - NumericCast(body_ptr - in.data()); mz_stream_ptr->next_in = const_uchar_ptr_cast(body_ptr); mz_stream_ptr->avail_in = NumericCast(bytes_remaining); diff --git a/src/common/hive_partitioning.cpp b/src/common/hive_partitioning.cpp index 0bb3306b7805..5dc8248cc3be 100644 --- a/src/common/hive_partitioning.cpp +++ b/src/common/hive_partitioning.cpp @@ -359,8 +359,8 @@ void HivePartitionedColumnData::GrowPartitions(PartitionedColumnDataAppendState void HivePartitionedColumnData::SynchronizeLocalMap() { // Synchronise global map into local, may contain changes from other threads too - for (auto it = global_state->partitions.begin() + local_partition_map.size(); it < global_state->partitions.end(); - it++) { + for (auto it = global_state->partitions.begin() + NumericCast(local_partition_map.size()); + it < global_state->partitions.end(); it++) { local_partition_map[(*it)->first] = (*it)->second; } } diff --git a/src/common/local_file_system.cpp b/src/common/local_file_system.cpp index b05189b4ea0f..43ef0a8badda 100644 --- a/src/common/local_file_system.cpp +++ b/src/common/local_file_system.cpp @@ -389,7 +389,7 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF void LocalFileSystem::SetFilePointer(FileHandle &handle, idx_t location) { int fd = handle.Cast().fd; - off_t offset = lseek(fd, location, SEEK_SET); + off_t offset = lseek(fd, UnsafeNumericCast(location), SEEK_SET); if (offset == (off_t)-1) { throw IOException("Could not seek to location %lld for file \"%s\": %s", {{"errno", std::to_string(errno)}}, location, handle.path, strerror(errno)); @@ -403,14 +403,15 @@ idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { throw IOException("Could not get file position file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, strerror(errno)); } - return position; + return UnsafeNumericCast(position); } void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { int fd = handle.Cast().fd; auto read_buffer = char_ptr_cast(buffer); while (nr_bytes > 0) { - int64_t bytes_read = pread(fd, read_buffer, nr_bytes, location); + int64_t bytes_read = + pread(fd, read_buffer, UnsafeNumericCast(nr_bytes), UnsafeNumericCast(location)); if (bytes_read == -1) { throw IOException("Could not read from file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, strerror(errno)); @@ -422,13 +423,13 @@ void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, i } read_buffer += bytes_read; nr_bytes -= bytes_read; - location += bytes_read; + location += UnsafeNumericCast(bytes_read); } } int64_t 
LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { int fd = handle.Cast().fd; - int64_t bytes_read = read(fd, buffer, nr_bytes); + int64_t bytes_read = read(fd, buffer, UnsafeNumericCast(nr_bytes)); if (bytes_read == -1) { throw IOException("Could not read from file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, strerror(errno)); @@ -440,7 +441,8 @@ void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, int fd = handle.Cast().fd; auto write_buffer = char_ptr_cast(buffer); while (nr_bytes > 0) { - int64_t bytes_written = pwrite(fd, write_buffer, nr_bytes, location); + int64_t bytes_written = + pwrite(fd, write_buffer, UnsafeNumericCast(nr_bytes), UnsafeNumericCast(location)); if (bytes_written < 0) { throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, strerror(errno)); @@ -451,7 +453,7 @@ void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, } write_buffer += bytes_written; nr_bytes -= bytes_written; - location += bytes_written; + location += UnsafeNumericCast(bytes_written); } } diff --git a/src/common/multi_file_reader.cpp b/src/common/multi_file_reader.cpp index bbb59705e29f..05ff0cb8ffa3 100644 --- a/src/common/multi_file_reader.cpp +++ b/src/common/multi_file_reader.cpp @@ -176,7 +176,7 @@ MultiFileReaderBindData MultiFileReader::BindOptions(MultiFileReaderOptions &opt auto lookup = std::find(names.begin(), names.end(), part.first); if (lookup != names.end()) { // hive partitioning column also exists in file - override - auto idx = lookup - names.begin(); + auto idx = NumericCast(lookup - names.begin()); hive_partitioning_index = idx; return_types[idx] = options.GetHiveLogicalType(part.first); } else { diff --git a/src/common/pipe_file_system.cpp b/src/common/pipe_file_system.cpp index 39a1877dd9bf..d6eb2c6aff48 100644 --- a/src/common/pipe_file_system.cpp +++ b/src/common/pipe_file_system.cpp @@ -2,6 +2,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/file_system.hpp" #include "duckdb/common/helper.hpp" +#include "duckdb/common/numeric_utils.hpp" namespace duckdb { class PipeFile : public FileHandle { @@ -22,10 +23,10 @@ class PipeFile : public FileHandle { }; int64_t PipeFile::ReadChunk(void *buffer, int64_t nr_bytes) { - return child_handle->Read(buffer, nr_bytes); + return child_handle->Read(buffer, UnsafeNumericCast(nr_bytes)); } int64_t PipeFile::WriteChunk(void *buffer, int64_t nr_bytes) { - return child_handle->Write(buffer, nr_bytes); + return child_handle->Write(buffer, UnsafeNumericCast(nr_bytes)); } void PipeFileSystem::Reset(FileHandle &handle) { diff --git a/src/common/random_engine.cpp b/src/common/random_engine.cpp index 0c9aec4e5119..acbda8b3e194 100644 --- a/src/common/random_engine.cpp +++ b/src/common/random_engine.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/random_engine.hpp" +#include "duckdb/common/numeric_utils.hpp" #include "pcg_random.hpp" #include @@ -15,7 +16,7 @@ RandomEngine::RandomEngine(int64_t seed) : random_state(make_uniq() if (seed < 0) { random_state->pcg.seed(pcg_extras::seed_seq_from()); } else { - random_state->pcg.seed(seed); + random_state->pcg.seed(NumericCast(seed)); } } diff --git a/src/common/re2_regex.cpp b/src/common/re2_regex.cpp index f22b681acfe6..b67b3c6524c8 100644 --- a/src/common/re2_regex.cpp +++ b/src/common/re2_regex.cpp @@ -17,10 +17,11 @@ bool RegexSearchInternal(const char *input, Match &match, const Regex &r, RE2::A size_t end) { auto ®ex = r.GetRegex(); 
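 	// NumberOfCapturingGroups() returns int; the +1 reserves group 0, re2's whole-match entry.
 	// The count is used both as a vector size (hence the cast to an unsigned size type) and,
 	// cast back to int, as the nvec argument of RE2::Match below.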
duckdb::vector target_groups; - auto group_count = regex.NumberOfCapturingGroups() + 1; + auto group_count = duckdb::UnsafeNumericCast(regex.NumberOfCapturingGroups() + 1); target_groups.resize(group_count); match.groups.clear(); - if (!regex.Match(StringPiece(input), start, end, anchor, target_groups.data(), group_count)) { + if (!regex.Match(StringPiece(input), start, end, anchor, target_groups.data(), + duckdb::UnsafeNumericCast(group_count))) { return false; } for (auto &group : target_groups) { @@ -41,7 +42,8 @@ bool RegexMatch(const std::string &input, Match &match, const Regex ®ex) { } bool RegexMatch(const char *start, const char *end, Match &match, const Regex ®ex) { - return RegexSearchInternal(start, match, regex, RE2::ANCHOR_BOTH, 0, end - start); + return RegexSearchInternal(start, match, regex, RE2::ANCHOR_BOTH, 0, + duckdb::UnsafeNumericCast(end - start)); } bool RegexMatch(const std::string &input, const Regex ®ex) { diff --git a/src/common/string_util.cpp b/src/common/string_util.cpp index 7898a343a927..7d33b1cd80f7 100644 --- a/src/common/string_util.cpp +++ b/src/common/string_util.cpp @@ -203,7 +203,8 @@ string StringUtil::Upper(const string &str) { string StringUtil::Lower(const string &str) { string copy(str); - transform(copy.begin(), copy.end(), copy.begin(), [](unsigned char c) { return StringUtil::CharacterToLower(c); }); + transform(copy.begin(), copy.end(), copy.begin(), + [](unsigned char c) { return StringUtil::CharacterToLower(UnsafeNumericCast(c)); }); return (copy); } @@ -215,7 +216,7 @@ bool StringUtil::IsLower(const string &str) { uint64_t StringUtil::CIHash(const string &str) { uint32_t hash = 0; for (auto c : str) { - hash += StringUtil::CharacterToLower(c); + hash += UnsafeNumericCast(StringUtil::CharacterToLower(UnsafeNumericCast(c))); hash += hash << 10; hash ^= hash >> 6; } diff --git a/src/include/duckdb/common/arrow/appender/append_data.hpp b/src/include/duckdb/common/arrow/appender/append_data.hpp index 8cbbf02c8ee7..52845cbd2dc6 100644 --- a/src/include/duckdb/common/arrow/appender/append_data.hpp +++ b/src/include/duckdb/common/arrow/appender/append_data.hpp @@ -59,6 +59,7 @@ struct ArrowAppendData { //===--------------------------------------------------------------------===// // Append Helper Functions //===--------------------------------------------------------------------===// + static void GetBitPosition(idx_t row_idx, idx_t ¤t_byte, uint8_t ¤t_bit) { current_byte = row_idx / 8; current_bit = row_idx % 8; diff --git a/src/include/duckdb/common/arrow/appender/enum_data.hpp b/src/include/duckdb/common/arrow/appender/enum_data.hpp index e6ed5c107b0a..f7f73e055141 100644 --- a/src/include/duckdb/common/arrow/appender/enum_data.hpp +++ b/src/include/duckdb/common/arrow/appender/enum_data.hpp @@ -44,14 +44,14 @@ struct ArrowEnumData : public ArrowScalarBaseData { auto string_length = GetLength(data[i]); // append the offset data - auto current_offset = last_offset + string_length; - offset_data[offset_idx] = UnsafeNumericCast(current_offset); + auto current_offset = UnsafeNumericCast(last_offset) + string_length; + offset_data[offset_idx] = UnsafeNumericCast(current_offset); // resize the string buffer if required, and write the string data append_data.aux_buffer.resize(current_offset); WriteData(append_data.aux_buffer.data() + last_offset, data[i]); - last_offset = UnsafeNumericCast(current_offset); + last_offset = UnsafeNumericCast(current_offset); } append_data.row_count += size; } diff --git 
a/src/include/duckdb/common/arrow/appender/map_data.hpp b/src/include/duckdb/common/arrow/appender/map_data.hpp index e881c532a15a..630ff1664f9c 100644 --- a/src/include/duckdb/common/arrow/appender/map_data.hpp +++ b/src/include/duckdb/common/arrow/appender/map_data.hpp @@ -75,7 +75,7 @@ struct ArrowMapData { struct_result->children = struct_data.child_pointers.data(); struct_result->n_buffers = 1; struct_result->n_children = struct_child_count; - struct_result->length = struct_data.child_data[0]->row_count; + struct_result->length = NumericCast(struct_data.child_data[0]->row_count); append_data.child_arrays[0] = *struct_result; diff --git a/src/include/duckdb/common/arrow/appender/varchar_data.hpp b/src/include/duckdb/common/arrow/appender/varchar_data.hpp index d6ffcd9db868..753756b9ea0b 100644 --- a/src/include/duckdb/common/arrow/appender/varchar_data.hpp +++ b/src/include/duckdb/common/arrow/appender/varchar_data.hpp @@ -77,8 +77,9 @@ struct ArrowVarcharData { auto string_length = OP::GetLength(data[source_idx]); // append the offset data - auto current_offset = last_offset + string_length; - if (!LARGE_STRING && (int64_t)last_offset + string_length > NumericLimits::Maximum()) { + auto current_offset = UnsafeNumericCast(last_offset) + string_length; + if (!LARGE_STRING && + UnsafeNumericCast(last_offset) + string_length > NumericLimits::Maximum()) { D_ASSERT(append_data.options.arrow_offset_size == ArrowOffsetSize::REGULAR); throw InvalidInputException( "Arrow Appender: The maximum total string size for regular string buffers is " diff --git a/src/include/duckdb/storage/string_uncompressed.hpp b/src/include/duckdb/storage/string_uncompressed.hpp index 6ad17066e58f..dacad03a0edb 100644 --- a/src/include/duckdb/storage/string_uncompressed.hpp +++ b/src/include/duckdb/storage/string_uncompressed.hpp @@ -148,7 +148,7 @@ struct UncompressedStringStorage { // place the dictionary offset into the set of vectors // note: for overflow strings we write negative value - result_data[target_idx] = -(*dictionary_size); + result_data[target_idx] = NumericCast(-(*dictionary_size)); } else { // string fits in block, append to dictionary and increment dictionary position D_ASSERT(string_length < NumericLimits::Maximum()); @@ -159,7 +159,7 @@ struct UncompressedStringStorage { memcpy(dict_pos, source_data[source_idx].GetData(), string_length); // place the dictionary offset into the set of vectors - result_data[target_idx] = *dictionary_size; + result_data[target_idx] = NumericCast(*dictionary_size); } D_ASSERT(RemainingSpace(segment, handle) <= Storage::BLOCK_SIZE); #ifdef DEBUG diff --git a/third_party/fsst/fsst.h b/third_party/fsst/fsst.h index 6970dedc053f..553c143cd021 100644 --- a/third_party/fsst/fsst.h +++ b/third_party/fsst/fsst.h @@ -184,7 +184,7 @@ duckdb_fsst_decompress( code = strIn[posIn++]; FSST_UNALIGNED_STORE(strOut+posOut, symbol[code]); posOut += len[code]; code = strIn[posIn++]; FSST_UNALIGNED_STORE(strOut+posOut, symbol[code]); posOut += len[code]; } else { - unsigned long firstEscapePos=__builtin_ctzl((unsigned long long) escapeMask)>>3; + unsigned long firstEscapePos=static_cast(__builtin_ctzl((unsigned long long) escapeMask)>>3); switch(firstEscapePos) { /* Duff's device */ case 3: code = strIn[posIn++]; FSST_UNALIGNED_STORE(strOut+posOut, symbol[code]); posOut += len[code]; DUCKDB_FSST_EXPLICIT_FALLTHROUGH; From ae853ab2467374f1ac2508733f7f2bd5b1fbd41f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Thu, 4 Apr 2024 12:49:27 +0200 Subject: [PATCH 
044/201] moaaaar --- CMakeLists.txt | 1 - src/catalog/duck_catalog.cpp | 13 ++++++++++++ src/common/operator/cast_operators.cpp | 13 ++++++++---- src/common/operator/string_cast.cpp | 2 +- src/common/row_operations/row_aggregate.cpp | 20 +++++++++---------- src/common/row_operations/row_external.cpp | 7 ++++--- .../serializer/buffered_file_reader.cpp | 8 ++++---- .../serializer/buffered_file_writer.cpp | 17 ++++++++-------- src/common/sort/comparators.cpp | 2 +- src/common/sort/merge_sorter.cpp | 4 ++-- src/common/sort/partition_state.cpp | 12 +++++------ src/common/sort/sort_state.cpp | 6 +++--- src/common/sort/sorted_block.cpp | 10 +++++----- src/common/types/bit.cpp | 2 +- src/common/types/blob.cpp | 6 +++--- src/common/types/cast_helpers.cpp | 2 +- src/common/types/conflict_manager.cpp | 2 +- src/common/types/date.cpp | 4 ++-- src/common/types/decimal.cpp | 12 +++++------ src/common/types/hash.cpp | 2 +- src/common/types/hugeint.cpp | 4 ++-- src/function/table/system/duckdb_tables.cpp | 5 +++-- .../common/operator/integer_cast_operator.hpp | 9 +++++---- .../duckdb/common/sort/duckdb_pdqsort.hpp | 6 +++--- .../duckdb/common/types/cast_helpers.hpp | 16 ++++++++------- .../duckdb/common/types/validity_mask.hpp | 2 +- .../rule/ordered_aggregate_optimizer.cpp | 1 + 27 files changed, 106 insertions(+), 82 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index aebc060c3bd7..2d50d9c21828 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -553,7 +553,6 @@ include_directories(third_party/fast_float) include_directories(third_party/re2) include_directories(third_party/miniz) include_directories(third_party/utf8proc/include) -include_directories(third_party/miniparquet) include_directories(third_party/concurrentqueue) include_directories(third_party/pcg) include_directories(third_party/tdigest) diff --git a/src/catalog/duck_catalog.cpp b/src/catalog/duck_catalog.cpp index 56717e90f09f..ebf1bb53b72c 100644 --- a/src/catalog/duck_catalog.cpp +++ b/src/catalog/duck_catalog.cpp @@ -1,3 +1,16 @@ +#include "duckdb/catalog/duck_catalog.hpp" +#include "duckdb/catalog/dependency_manager.hpp" +#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp" +#include "duckdb/storage/storage_manager.hpp" +#include "duckdb/parser/parsed_data/drop_info.hpp" +#include "duckdb/parser/parsed_data/create_schema_info.hpp" +#include "duckdb/catalog/default/default_schemas.hpp" +#include "duckdb/function/built_in_functions.hpp" +#include "duckdb/main/attached_database.hpp" +#ifndef DISABLE_CORE_FUNCTIONS_EXTENSION +#include "duckdb/core_functions/core_functions.hpp" +#endif + namespace duckdb { DuckCatalog::DuckCatalog(AttachedDatabase &db) diff --git a/src/common/operator/cast_operators.cpp b/src/common/operator/cast_operators.cpp index c63adefcab7f..be59aabbe296 100644 --- a/src/common/operator/cast_operators.cpp +++ b/src/common/operator/cast_operators.cpp @@ -1593,12 +1593,12 @@ struct HugeIntCastData { using ResultType = T; using Operation = OP; ResultType result; - int64_t intermediate; + ResultType intermediate; uint8_t digits; ResultType decimal; uint16_t decimal_total_digits; - int64_t decimal_intermediate; + ResultType decimal_intermediate; uint16_t decimal_intermediate_digits; bool Flush() { @@ -1647,7 +1647,8 @@ struct HugeIntegerCastOperation { template static bool HandleDigit(T &state, uint8_t digit) { if (NEGATIVE) { - if (DUCKDB_UNLIKELY(state.intermediate < (NumericLimits::Minimum() + digit) / 10)) { + if (DUCKDB_UNLIKELY(UnsafeNumericCast(state.intermediate) < + (NumericLimits::Minimum() + 
digit) / 10)) { // intermediate is full: need to flush it if (!state.Flush()) { return false; @@ -1694,7 +1695,11 @@ struct HugeIntegerCastOperation { if (e < 0) { state.result = T::Operation::DivMod(state.result, T::Operation::POWERS_OF_TEN[-e], remainder); if (remainder < 0) { - remainder *= -1; + result_t negate_result; + if (!T::Operation::TryNegate(remainder, negate_result)) { + return false; + } + remainder = negate_result; } state.decimal = remainder; state.decimal_total_digits = UnsafeNumericCast(-e); diff --git a/src/common/operator/string_cast.cpp b/src/common/operator/string_cast.cpp index 2da6949abe50..0fd1f5f0a0ca 100644 --- a/src/common/operator/string_cast.cpp +++ b/src/common/operator/string_cast.cpp @@ -181,7 +181,7 @@ string_t StringCastTZ::Operation(dtime_tz_t input, Vector &vector) { auto ss = std::abs(offset); const auto hh = ss / Interval::SECS_PER_HOUR; - const auto hh_length = (hh < 100) ? 2 : NumericHelper::UnsignedLength(uint32_t(hh)); + const auto hh_length = UnsafeNumericCast((hh < 100) ? 2 : NumericHelper::UnsignedLength(uint32_t(hh))); length += hh_length; ss %= Interval::SECS_PER_HOUR; diff --git a/src/common/row_operations/row_aggregate.cpp b/src/common/row_operations/row_aggregate.cpp index f6e9e6cbb374..8be0ada4fa20 100644 --- a/src/common/row_operations/row_aggregate.cpp +++ b/src/common/row_operations/row_aggregate.cpp @@ -36,14 +36,14 @@ void RowOperations::DestroyStates(RowOperationsState &state, TupleDataLayout &la return; } // Move to the first aggregate state - VectorOperations::AddInPlace(addresses, layout.GetAggrOffset(), count); + VectorOperations::AddInPlace(addresses, UnsafeNumericCast(layout.GetAggrOffset()), count); for (const auto &aggr : layout.GetAggregates()) { if (aggr.function.destructor) { AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); aggr.function.destructor(addresses, aggr_input_data, count); } // Move to the next aggregate state - VectorOperations::AddInPlace(addresses, aggr.payload_size, count); + VectorOperations::AddInPlace(addresses, UnsafeNumericCast(aggr.payload_size), count); } } @@ -74,8 +74,8 @@ void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &la } // Move to the first aggregate states - VectorOperations::AddInPlace(sources, layout.GetAggrOffset(), count); - VectorOperations::AddInPlace(targets, layout.GetAggrOffset(), count); + VectorOperations::AddInPlace(sources, UnsafeNumericCast(layout.GetAggrOffset()), count); + VectorOperations::AddInPlace(targets, UnsafeNumericCast(layout.GetAggrOffset()), count); // Keep track of the offset idx_t offset = layout.GetAggrOffset(); @@ -87,16 +87,16 @@ void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &la aggr.function.combine(sources, targets, aggr_input_data, count); // Move to the next aggregate states - VectorOperations::AddInPlace(sources, aggr.payload_size, count); - VectorOperations::AddInPlace(targets, aggr.payload_size, count); + VectorOperations::AddInPlace(sources, UnsafeNumericCast(aggr.payload_size), count); + VectorOperations::AddInPlace(targets, UnsafeNumericCast(aggr.payload_size), count); // Increment the offset offset += aggr.payload_size; } // Now subtract the offset to get back to the original position - VectorOperations::AddInPlace(sources, -offset, count); - VectorOperations::AddInPlace(targets, -offset, count); + VectorOperations::AddInPlace(sources, UnsafeNumericCast(-offset), count); + VectorOperations::AddInPlace(targets, UnsafeNumericCast(-offset), count); } void 
RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, @@ -106,7 +106,7 @@ void RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &l VectorOperations::Copy(addresses, addresses_copy, result.size(), 0, 0); // Move to the first aggregate state - VectorOperations::AddInPlace(addresses_copy, layout.GetAggrOffset(), result.size()); + VectorOperations::AddInPlace(addresses_copy, UnsafeNumericCast(layout.GetAggrOffset()), result.size()); auto &aggregates = layout.GetAggregates(); for (idx_t i = 0; i < aggregates.size(); i++) { @@ -116,7 +116,7 @@ void RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &l aggr.function.finalize(addresses_copy, aggr_input_data, target, result.size(), 0); // Move to the next aggregate state - VectorOperations::AddInPlace(addresses_copy, aggr.payload_size, result.size()); + VectorOperations::AddInPlace(addresses_copy, UnsafeNumericCast(aggr.payload_size), result.size()); } } diff --git a/src/common/row_operations/row_external.cpp b/src/common/row_operations/row_external.cpp index 9e2fa071f86d..5aef76a00dd7 100644 --- a/src/common/row_operations/row_external.cpp +++ b/src/common/row_operations/row_external.cpp @@ -37,7 +37,8 @@ void RowOperations::SwizzleColumns(const RowLayout &layout, const data_ptr_t bas for (idx_t i = 0; i < next; i++) { if (Load(col_ptr) > string_t::INLINE_LENGTH) { // Overwrite the string pointer with the within-row offset (if not inlined) - Store(Load(string_ptr) - heap_row_ptrs[i], string_ptr); + Store(UnsafeNumericCast(Load(string_ptr) - heap_row_ptrs[i]), + string_ptr); } col_ptr += row_width; string_ptr += row_width; @@ -46,7 +47,7 @@ void RowOperations::SwizzleColumns(const RowLayout &layout, const data_ptr_t bas // Non-varchar blob columns for (idx_t i = 0; i < next; i++) { // Overwrite the column data pointer with the within-row offset - Store(Load(col_ptr) - heap_row_ptrs[i], col_ptr); + Store(UnsafeNumericCast(Load(col_ptr) - heap_row_ptrs[i]), col_ptr); col_ptr += row_width; } } @@ -79,7 +80,7 @@ void RowOperations::CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_p // Copy and swizzle memcpy(heap_ptr, source_heap_ptr, size); - Store(heap_ptr - heap_base_ptr, row_ptr + heap_offset); + Store(UnsafeNumericCast(heap_ptr - heap_base_ptr), row_ptr + heap_offset); // Increment for next iteration row_ptr += row_width; diff --git a/src/common/serializer/buffered_file_reader.cpp b/src/common/serializer/buffered_file_reader.cpp index 762e1ca889a0..a6ed87e9f4bc 100644 --- a/src/common/serializer/buffered_file_reader.cpp +++ b/src/common/serializer/buffered_file_reader.cpp @@ -11,20 +11,20 @@ BufferedFileReader::BufferedFileReader(FileSystem &fs, const char *path, FileLoc optional_ptr opener) : fs(fs), data(make_unsafe_uniq_array(FILE_BUFFER_SIZE)), offset(0), read_data(0), total_read(0) { handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ | lock_type, opener.get()); - file_size = fs.GetFileSize(*handle); + file_size = NumericCast(fs.GetFileSize(*handle)); } BufferedFileReader::BufferedFileReader(FileSystem &fs, unique_ptr handle_p) : fs(fs), data(make_unsafe_uniq_array(FILE_BUFFER_SIZE)), offset(0), read_data(0), handle(std::move(handle_p)), total_read(0) { - file_size = fs.GetFileSize(*handle); + file_size = NumericCast(fs.GetFileSize(*handle)); } void BufferedFileReader::ReadData(data_ptr_t target_buffer, uint64_t read_size) { // first copy anything we can from the buffer data_ptr_t end_ptr = target_buffer + read_size; while (true) { - idx_t 
to_read = MinValue(end_ptr - target_buffer, read_data - offset); + idx_t to_read = MinValue(UnsafeNumericCast(end_ptr - target_buffer), read_data - offset); if (to_read > 0) { memcpy(target_buffer, data.get() + offset, to_read); offset += to_read; @@ -36,7 +36,7 @@ void BufferedFileReader::ReadData(data_ptr_t target_buffer, uint64_t read_size) // did not finish reading yet but exhausted buffer // read data into buffer offset = 0; - read_data = fs.Read(*handle, data.get(), FILE_BUFFER_SIZE); + read_data = UnsafeNumericCast(fs.Read(*handle, data.get(), FILE_BUFFER_SIZE)); if (read_data == 0) { throw SerializationException("not enough data in file to deserialize result"); } diff --git a/src/common/serializer/buffered_file_writer.cpp b/src/common/serializer/buffered_file_writer.cpp index dcbe7d4f0a45..62d237e63ea3 100644 --- a/src/common/serializer/buffered_file_writer.cpp +++ b/src/common/serializer/buffered_file_writer.cpp @@ -15,7 +15,7 @@ BufferedFileWriter::BufferedFileWriter(FileSystem &fs, const string &path_p, Fil } int64_t BufferedFileWriter::GetFileSize() { - return fs.GetFileSize(*handle) + offset; + return fs.GetFileSize(*handle) + NumericCast(offset); } idx_t BufferedFileWriter::GetTotalWritten() { @@ -37,13 +37,14 @@ void BufferedFileWriter::WriteData(const_data_ptr_t buffer, idx_t write_size) { Flush(); // Flush buffer before writing every things else } idx_t remaining_to_write = write_size - to_copy; - fs.Write(*handle, const_cast(buffer + to_copy), remaining_to_write); // NOLINT: wrong API in Write + fs.Write(*handle, const_cast(buffer + to_copy), + UnsafeNumericCast(remaining_to_write)); // NOLINT: wrong API in Write total_written += remaining_to_write; } else { // first copy anything we can from the buffer const_data_ptr_t end_ptr = buffer + write_size; while (buffer < end_ptr) { - idx_t to_write = MinValue((end_ptr - buffer), FILE_BUFFER_SIZE - offset); + idx_t to_write = MinValue(UnsafeNumericCast((end_ptr - buffer)), FILE_BUFFER_SIZE - offset); D_ASSERT(to_write > 0); memcpy(data.get() + offset, buffer, to_write); offset += to_write; @@ -59,7 +60,7 @@ void BufferedFileWriter::Flush() { if (offset == 0) { return; } - fs.Write(*handle, data.get(), offset); + fs.Write(*handle, data.get(), UnsafeNumericCast(offset)); total_written += offset; offset = 0; } @@ -70,11 +71,11 @@ void BufferedFileWriter::Sync() { } void BufferedFileWriter::Truncate(int64_t size) { - uint64_t persistent = fs.GetFileSize(*handle); - D_ASSERT((uint64_t)size <= persistent + offset); - if (persistent <= (uint64_t)size) { + auto persistent = fs.GetFileSize(*handle); + D_ASSERT(size <= persistent + NumericCast(offset)); + if (persistent <= size) { // truncating into the pending write buffer. 
- offset = size - persistent; + offset = NumericCast(size - persistent); } else { // truncate the physical file on disk handle->Truncate(size); diff --git a/src/common/sort/comparators.cpp b/src/common/sort/comparators.cpp index 7084a3f993a3..82e8069d9211 100644 --- a/src/common/sort/comparators.cpp +++ b/src/common/sort/comparators.cpp @@ -501,7 +501,7 @@ void Comparators::SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap if (type.InternalType() == PhysicalType::VARCHAR) { data_ptr += string_t::HEADER_SIZE; } - Store(Load(data_ptr) - heap_ptr, data_ptr); + Store(UnsafeNumericCast(Load(data_ptr) - heap_ptr), data_ptr); } } // namespace duckdb diff --git a/src/common/sort/merge_sorter.cpp b/src/common/sort/merge_sorter.cpp index 7d2f6a1dc4e1..b36887e66a06 100644 --- a/src/common/sort/merge_sorter.cpp +++ b/src/common/sort/merge_sorter.cpp @@ -516,9 +516,9 @@ void MergeSorter::MergeData(SortedData &result_data, SortedData &l_data, SortedD entry_size = l_smaller * Load(l_heap_ptr_copy) + r_smaller * Load(r_heap_ptr_copy); D_ASSERT(entry_size >= sizeof(uint32_t)); - D_ASSERT(l_heap_ptr_copy - l.BaseHeapPtr(l_data) + l_smaller * entry_size <= + D_ASSERT(NumericCast(l_heap_ptr_copy - l.BaseHeapPtr(l_data)) + l_smaller * entry_size <= l_data.heap_blocks[l.block_idx]->byte_offset); - D_ASSERT(r_heap_ptr_copy - r.BaseHeapPtr(r_data) + r_smaller * entry_size <= + D_ASSERT(NumericCast(r_heap_ptr_copy - r.BaseHeapPtr(r_data)) + r_smaller * entry_size <= r_data.heap_blocks[r.block_idx]->byte_offset); l_heap_ptr_copy += l_smaller * entry_size; r_heap_ptr_copy += r_smaller * entry_size; diff --git a/src/common/sort/partition_state.cpp b/src/common/sort/partition_state.cpp index c123f275556f..5346d057a381 100644 --- a/src/common/sort/partition_state.cpp +++ b/src/common/sort/partition_state.cpp @@ -318,7 +318,7 @@ void PartitionLocalSinkState::Sink(DataChunk &input_chunk) { const auto entry_size = payload_layout.GetRowWidth(); const auto capacity = MaxValue(STANDARD_VECTOR_SIZE, (Storage::BLOCK_SIZE / entry_size) + 1); rows = make_uniq(gstate.buffer_manager, capacity, entry_size); - strings = make_uniq(gstate.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); + strings = make_uniq(gstate.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U, true); } const auto row_count = input_chunk.size(); const auto row_sel = FlatVector::IncrementalSelectionVector(); @@ -402,8 +402,8 @@ void PartitionLocalSinkState::Combine() { PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data_p, hash_t hash_bin) : sink(sink), group_data(std::move(group_data_p)), memory_per_thread(sink.memory_per_thread), - num_threads(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()), stage(PartitionSortStage::INIT), - total_tasks(0), tasks_assigned(0), tasks_completed(0) { + num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), + stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { const auto group_idx = sink.hash_groups.size(); auto new_group = make_uniq(sink.buffer_manager, sink.partitions, sink.orders, @@ -424,8 +424,8 @@ PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &s PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink) : sink(sink), memory_per_thread(sink.memory_per_thread), - num_threads(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()), stage(PartitionSortStage::INIT), - total_tasks(0), tasks_assigned(0), 
tasks_completed(0) { + num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), + stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { const hash_t hash_bin = 0; const size_t group_idx = 0; @@ -661,7 +661,7 @@ void PartitionMergeEvent::Schedule() { // Schedule tasks equal to the number of threads, which will each merge multiple partitions auto &ts = TaskScheduler::GetScheduler(context); - idx_t num_threads = ts.NumberOfThreads(); + auto num_threads = NumericCast(ts.NumberOfThreads()); vector> merge_tasks; for (idx_t tnum = 0; tnum < num_threads; tnum++) { diff --git a/src/common/sort/sort_state.cpp b/src/common/sort/sort_state.cpp index 9c8594866e8a..27650b46e76b 100644 --- a/src/common/sort/sort_state.cpp +++ b/src/common/sort/sort_state.cpp @@ -171,13 +171,13 @@ void LocalSortState::Initialize(GlobalSortState &global_sort_state, BufferManage auto blob_row_width = sort_layout->blob_layout.GetRowWidth(); blob_sorting_data = make_uniq( *buffer_manager, RowDataCollection::EntriesPerBlock(blob_row_width), blob_row_width); - blob_sorting_heap = make_uniq(*buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); + blob_sorting_heap = make_uniq(*buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U, true); } // Payload data auto payload_row_width = payload_layout->GetRowWidth(); payload_data = make_uniq(*buffer_manager, RowDataCollection::EntriesPerBlock(payload_row_width), payload_row_width); - payload_heap = make_uniq(*buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); + payload_heap = make_uniq(*buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U, true); // Init done initialized = true; } @@ -323,7 +323,7 @@ void LocalSortState::ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataColl std::accumulate(heap.blocks.begin(), heap.blocks.end(), (idx_t)0, [](idx_t a, const unique_ptr &b) { return a + b->byte_offset; }); idx_t heap_block_size = MaxValue(total_byte_offset, (idx_t)Storage::BLOCK_SIZE); - auto ordered_heap_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, heap_block_size, 1); + auto ordered_heap_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, heap_block_size, 1U); ordered_heap_block->count = count; ordered_heap_block->byte_offset = total_byte_offset; auto ordered_heap_handle = buffer_manager->Pin(ordered_heap_block->block); diff --git a/src/common/sort/sorted_block.cpp b/src/common/sort/sorted_block.cpp index 22127abe1032..9539302c3b4a 100644 --- a/src/common/sort/sorted_block.cpp +++ b/src/common/sort/sorted_block.cpp @@ -30,7 +30,7 @@ void SortedData::CreateBlock() { data_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, layout.GetRowWidth())); if (!layout.AllConstant() && state.external) { heap_blocks.push_back( - make_uniq(MemoryTag::ORDER_BY, buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1)); + make_uniq(MemoryTag::ORDER_BY, buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U)); D_ASSERT(data_blocks.size() == heap_blocks.size()); } } @@ -291,10 +291,10 @@ PayloadScanner::PayloadScanner(SortedData &sorted_data, GlobalSortState &global_ auto &layout = sorted_data.layout; // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); + rows = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U); rows->count = count; - heap = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); + heap = make_uniq(global_sort_state.buffer_manager, 
(idx_t)Storage::BLOCK_SIZE, 1U); if (!sorted_data.layout.AllConstant()) { heap->count = count; } @@ -330,7 +330,7 @@ PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_i auto &layout = sorted_data.layout; // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); + rows = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U); if (flush_p) { rows->blocks.emplace_back(std::move(sorted_data.data_blocks[block_idx])); } else { @@ -338,7 +338,7 @@ PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_i } rows->count = count; - heap = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1); + heap = make_uniq(global_sort_state.buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U); if (!sorted_data.layout.AllConstant() && sorted_data.swizzled) { if (flush_p) { heap->blocks.emplace_back(std::move(sorted_data.heap_blocks[block_idx])); diff --git a/src/common/types/bit.cpp b/src/common/types/bit.cpp index 6f2bd8a0d066..83f0ce8365df 100644 --- a/src/common/types/bit.cpp +++ b/src/common/types/bit.cpp @@ -179,7 +179,7 @@ void Bit::BitToBlob(string_t bit, string_t &output_blob) { auto output = output_blob.GetDataWriteable(); idx_t size = output_blob.GetSize(); - output[0] = GetFirstByte(bit); + output[0] = UnsafeNumericCast(GetFirstByte(bit)); if (size > 2) { ++output; // First byte in bitstring contains amount of padded bits, diff --git a/src/common/types/blob.cpp b/src/common/types/blob.cpp index 5cb8e5df606b..6f472c0e3b92 100644 --- a/src/common/types/blob.cpp +++ b/src/common/types/blob.cpp @@ -47,7 +47,7 @@ void Blob::ToString(string_t blob, char *output) { for (idx_t i = 0; i < len; i++) { if (IsRegularCharacter(data[i])) { // ascii characters are rendered as-is - output[str_idx++] = data[i]; + output[str_idx++] = UnsafeNumericCast(data[i]); } else { auto byte_a = data[i] >> 4; auto byte_b = data[i] & 0x0F; @@ -244,8 +244,8 @@ uint32_t DecodeBase64Bytes(const string_t &str, const_data_ptr_t input_data, idx input_data[base_idx + decode_idx], base_idx + decode_idx); } } - return (decoded_bytes[0] << 3 * 6) + (decoded_bytes[1] << 2 * 6) + (decoded_bytes[2] << 1 * 6) + - (decoded_bytes[3] << 0 * 6); + return UnsafeNumericCast((decoded_bytes[0] << 3 * 6) + (decoded_bytes[1] << 2 * 6) + + (decoded_bytes[2] << 1 * 6) + (decoded_bytes[3] << 0 * 6)); } void Blob::FromBase64(string_t str, data_ptr_t output, idx_t output_size) { diff --git a/src/common/types/cast_helpers.cpp b/src/common/types/cast_helpers.cpp index 1e37fbc79ea6..f37fbaa971e9 100644 --- a/src/common/types/cast_helpers.cpp +++ b/src/common/types/cast_helpers.cpp @@ -67,7 +67,7 @@ int NumericHelper::UnsignedLength(uint32_t value) { } template <> -idx_t NumericHelper::UnsignedLength(uint64_t value) { +int NumericHelper::UnsignedLength(uint64_t value) { if (value >= 10000000000ULL) { if (value >= 1000000000000000ULL) { int length = 16; diff --git a/src/common/types/conflict_manager.cpp b/src/common/types/conflict_manager.cpp index 38d61240dcca..171d45115fb4 100644 --- a/src/common/types/conflict_manager.cpp +++ b/src/common/types/conflict_manager.cpp @@ -159,7 +159,7 @@ bool ConflictManager::AddNull(idx_t chunk_index) { if (!IsConflict(LookupResultType::LOOKUP_NULL)) { return false; } - return AddHit(chunk_index, DConstants::INVALID_INDEX); + return AddHit(chunk_index, UnsafeNumericCast(DConstants::INVALID_INDEX)); } bool 
ConflictManager::SingleIndexTarget() const { diff --git a/src/common/types/date.cpp b/src/common/types/date.cpp index 1f06b884dd27..be3349035199 100644 --- a/src/common/types/date.cpp +++ b/src/common/types/date.cpp @@ -307,7 +307,7 @@ bool Date::TryConvertDate(const char *buf, idx_t len, idx_t &pos, date_t &result // in strict mode, check remaining string for non-space characters if (strict) { // skip trailing spaces - while (pos < len && StringUtil::CharacterIsSpace((unsigned char)buf[pos])) { + while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { pos++; } // check position. if end was not reached, non-space chars remaining @@ -316,7 +316,7 @@ bool Date::TryConvertDate(const char *buf, idx_t len, idx_t &pos, date_t &result } } else { // in non-strict mode, check for any direct trailing digits - if (pos < len && StringUtil::CharacterIsDigit((unsigned char)buf[pos])) { + if (pos < len && StringUtil::CharacterIsDigit(buf[pos])) { return false; } } diff --git a/src/common/types/decimal.cpp b/src/common/types/decimal.cpp index 323cec4e381d..2f38d76ea786 100644 --- a/src/common/types/decimal.cpp +++ b/src/common/types/decimal.cpp @@ -6,9 +6,9 @@ namespace duckdb { template string TemplatedDecimalToString(SIGNED value, uint8_t width, uint8_t scale) { auto len = DecimalToString::DecimalLength(value, width, scale); - auto data = make_unsafe_uniq_array(len + 1); - DecimalToString::FormatDecimal(value, width, scale, data.get(), len); - return string(data.get(), len); + auto data = make_unsafe_uniq_array(UnsafeNumericCast(len + 1)); + DecimalToString::FormatDecimal(value, width, scale, data.get(), UnsafeNumericCast(len)); + return string(data.get(), UnsafeNumericCast(len)); } string Decimal::ToString(int16_t value, uint8_t width, uint8_t scale) { @@ -25,9 +25,9 @@ string Decimal::ToString(int64_t value, uint8_t width, uint8_t scale) { string Decimal::ToString(hugeint_t value, uint8_t width, uint8_t scale) { auto len = HugeintToStringCast::DecimalLength(value, width, scale); - auto data = make_unsafe_uniq_array(len + 1); - HugeintToStringCast::FormatDecimal(value, width, scale, data.get(), len); - return string(data.get(), len); + auto data = make_unsafe_uniq_array(UnsafeNumericCast(len + 1)); + HugeintToStringCast::FormatDecimal(value, width, scale, data.get(), UnsafeNumericCast(len)); + return string(data.get(), UnsafeNumericCast(len)); } } // namespace duckdb diff --git a/src/common/types/hash.cpp b/src/common/types/hash.cpp index cdc9ba3020ec..83a1ef22310e 100644 --- a/src/common/types/hash.cpp +++ b/src/common/types/hash.cpp @@ -22,7 +22,7 @@ hash_t Hash(int64_t val) { template <> hash_t Hash(hugeint_t val) { - return MurmurHash64(val.lower) ^ MurmurHash64(val.upper); + return MurmurHash64(val.lower) ^ MurmurHash64(static_cast(val.upper)); } template <> diff --git a/src/common/types/hugeint.cpp b/src/common/types/hugeint.cpp index d5343fc6860d..bc4a13e0b98d 100644 --- a/src/common/types/hugeint.cpp +++ b/src/common/types/hugeint.cpp @@ -85,7 +85,7 @@ static uint8_t PositiveHugeintHighestBit(hugeint_t bits) { uint8_t out = 0; if (bits.upper) { out = 64; - uint64_t up = bits.upper; + uint64_t up = static_cast(bits.upper); while (up) { up >>= 1; out++; @@ -104,7 +104,7 @@ static bool PositiveHugeintIsBitSet(hugeint_t lhs, uint8_t bit_position) { if (bit_position < 64) { return lhs.lower & (uint64_t(1) << uint64_t(bit_position)); } else { - return lhs.upper & (uint64_t(1) << uint64_t(bit_position - 64)); + return static_cast(lhs.upper) & (uint64_t(1) << uint64_t(bit_position - 64)); } } 
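Note on the recurring pattern in the hunks above and below: every implicit narrowing or sign-changing conversion is made explicit, with NumericCast at call sites that need a runtime range check and UnsafeNumericCast where the surrounding logic (an assertion or branch) already guarantees the range. A minimal sketch of the checked-cast semantics these call sites rely on, assuming a hypothetical CheckedCast helper that is not part of this patch:

// Hypothetical sketch, not the DuckDB implementation: a narrowing cast that
// fails loudly when the value does not survive the round trip.
#include <stdexcept>

template <class TO, class FROM>
TO CheckedCast(FROM value) {
	auto result = static_cast<TO>(value);
	// The cast is lossless iff casting back recovers the input and the sign
	// of the value is unchanged (this catches signed/unsigned mismatches).
	if (static_cast<FROM>(result) != value || (value < FROM(0)) != (result < TO(0))) {
		throw std::out_of_range("checked cast: value out of target range");
	}
	return result;
}

// Example mirroring the buffered_file_reader hunk earlier in this commit,
// where a signed int64_t file size is narrowed into an unsigned index type:
//   auto file_size = CheckedCast<uint64_t>(fs_reported_size); // throws if negative

The unchecked variant reduces to a bare static_cast, which is what the UnsafeNumericCast call sites amount to once the bounds have been established.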
diff --git a/src/function/table/system/duckdb_tables.cpp b/src/function/table/system/duckdb_tables.cpp index 95a0e79ad0f5..bcc0668cde01 100644 --- a/src/function/table/system/duckdb_tables.cpp +++ b/src/function/table/system/duckdb_tables.cpp @@ -146,8 +146,9 @@ void DuckDBTablesFunction(ClientContext &context, TableFunctionInput &data_p, Da output.SetValue(col++, count, Value::BOOLEAN(TableHasPrimaryKey(table))); // estimated_size, LogicalType::BIGINT - Value card_val = - !storage_info.cardinality.IsValid() ? Value() : Value::BIGINT(NumericCast(storage_info.cardinality.GetIndex())); + Value card_val = !storage_info.cardinality.IsValid() + ? Value() + : Value::BIGINT(NumericCast(storage_info.cardinality.GetIndex())); output.SetValue(col++, count, card_val); // column_count, LogicalType::BIGINT output.SetValue(col++, count, Value::BIGINT(NumericCast(table.GetColumns().LogicalColumnCount()))); diff --git a/src/include/duckdb/common/operator/integer_cast_operator.hpp b/src/include/duckdb/common/operator/integer_cast_operator.hpp index 0fdca6b15b16..decec991f7e6 100644 --- a/src/include/duckdb/common/operator/integer_cast_operator.hpp +++ b/src/include/duckdb/common/operator/integer_cast_operator.hpp @@ -234,7 +234,8 @@ static bool IntegerCastLoop(const char *buf, idx_t len, T &result, bool strict) if (!StringUtil::CharacterIsDigit(buf[pos])) { break; } - if (!OP::template HandleDecimal(result, buf[pos] - '0')) { + if (!OP::template HandleDecimal( + result, UnsafeNumericCast(buf[pos] - '0'))) { return false; } pos++; @@ -296,7 +297,7 @@ static bool IntegerCastLoop(const char *buf, idx_t len, T &result, bool strict) } return false; } - uint8_t digit = buf[pos++] - '0'; + auto digit = UnsafeNumericCast(buf[pos++] - '0'); if (!OP::template HandleDigit(result, digit)) { return false; } @@ -330,9 +331,9 @@ static bool IntegerHexCastLoop(const char *buf, idx_t len, T &result, bool stric } uint8_t digit; if (current_char >= 'a') { - digit = current_char - 'a' + 10; + digit = UnsafeNumericCast(current_char - 'a' + 10); } else { - digit = current_char - '0'; + digit = UnsafeNumericCast(current_char - '0'); } pos++; diff --git a/src/include/duckdb/common/sort/duckdb_pdqsort.hpp b/src/include/duckdb/common/sort/duckdb_pdqsort.hpp index 7b239a15ab9f..cae339b20f2e 100644 --- a/src/include/duckdb/common/sort/duckdb_pdqsort.hpp +++ b/src/include/duckdb/common/sort/duckdb_pdqsort.hpp @@ -154,9 +154,9 @@ struct PDQIterator { } inline friend idx_t operator-(const PDQIterator &lhs, const PDQIterator &rhs) { - D_ASSERT((*lhs - *rhs) % lhs.entry_size == 0); + D_ASSERT(duckdb::NumericCast(*lhs - *rhs) % lhs.entry_size == 0); D_ASSERT(*lhs - *rhs >= 0); - return (*lhs - *rhs) / lhs.entry_size; + return duckdb::NumericCast(*lhs - *rhs) / lhs.entry_size; } inline friend bool operator<(const PDQIterator &lhs, const PDQIterator &rhs) { @@ -320,7 +320,7 @@ inline T *align_cacheline(T *p) { #else std::size_t ip = reinterpret_cast(p); #endif - ip = (ip + cacheline_size - 1) & -cacheline_size; + ip = (ip + cacheline_size - 1) & duckdb::UnsafeNumericCast(-cacheline_size); return reinterpret_cast(ip); } diff --git a/src/include/duckdb/common/types/cast_helpers.hpp b/src/include/duckdb/common/types/cast_helpers.hpp index 2e071f3798d4..03c45b5a09df 100644 --- a/src/include/duckdb/common/types/cast_helpers.hpp +++ b/src/include/duckdb/common/types/cast_helpers.hpp @@ -62,7 +62,7 @@ class NumericHelper { template static string_t FormatSigned(SIGNED value, Vector &vector) { int sign = -(value < 0); - UNSIGNED unsigned_value = 
UnsafeNumericCast(UNSIGNED(value ^ sign) - sign); + UNSIGNED unsigned_value = UNSIGNED(value ^ sign) - UNSIGNED(sign); int length = UnsignedLength(unsigned_value) - sign; string_t result = StringVector::EmptyString(vector, NumericCast(length)); auto dataptr = result.GetDataWriteable(); @@ -123,16 +123,18 @@ struct DecimalToString { *dst = '-'; } if (scale == 0) { - NumericHelper::FormatUnsigned(value, end); + NumericHelper::FormatUnsigned(UnsafeNumericCast(value), end); return; } // we write two numbers: // the numbers BEFORE the decimal (major) // and the numbers AFTER the decimal (minor) - UNSIGNED minor = value % (UNSIGNED)NumericHelper::POWERS_OF_TEN[scale]; - UNSIGNED major = value / (UNSIGNED)NumericHelper::POWERS_OF_TEN[scale]; + auto minor = + UnsafeNumericCast(value) % UnsafeNumericCast(NumericHelper::POWERS_OF_TEN[scale]); + auto major = + UnsafeNumericCast(value) / UnsafeNumericCast(NumericHelper::POWERS_OF_TEN[scale]); // write the number after the decimal - dst = NumericHelper::FormatUnsigned(minor, end); + dst = NumericHelper::FormatUnsigned(UnsafeNumericCast(minor), end); // (optionally) pad with zeros and add the decimal point while (dst > (end - scale)) { *--dst = '0'; @@ -142,7 +144,7 @@ struct DecimalToString { D_ASSERT(width > scale || major == 0); if (width > scale) { // there are numbers after the comma - dst = NumericHelper::FormatUnsigned(major, dst); + dst = NumericHelper::FormatUnsigned(UnsafeNumericCast(major), dst); } } @@ -150,7 +152,7 @@ struct DecimalToString { static string_t Format(SIGNED value, uint8_t width, uint8_t scale, Vector &vector) { int len = DecimalLength(value, width, scale); string_t result = StringVector::EmptyString(vector, NumericCast(len)); - FormatDecimal(value, width, scale, result.GetDataWriteable(), len); + FormatDecimal(value, width, scale, result.GetDataWriteable(), UnsafeNumericCast(len)); result.Finalize(); return result; } diff --git a/src/include/duckdb/common/types/validity_mask.hpp b/src/include/duckdb/common/types/validity_mask.hpp index b7dd548cf68a..fce5f0d1a09d 100644 --- a/src/include/duckdb/common/types/validity_mask.hpp +++ b/src/include/duckdb/common/types/validity_mask.hpp @@ -19,7 +19,7 @@ struct ValidityMask; template struct TemplatedValidityData { static constexpr const int BITS_PER_VALUE = sizeof(V) * 8; - static constexpr const V MAX_ENTRY = ~V(0); + static constexpr const V MAX_ENTRY = V(~V(0)); public: inline explicit TemplatedValidityData(idx_t count) { diff --git a/src/optimizer/rule/ordered_aggregate_optimizer.cpp b/src/optimizer/rule/ordered_aggregate_optimizer.cpp index b8f52e4916d7..93c9c80166a8 100644 --- a/src/optimizer/rule/ordered_aggregate_optimizer.cpp +++ b/src/optimizer/rule/ordered_aggregate_optimizer.cpp @@ -6,6 +6,7 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp" namespace duckdb { From 1cdfbd85d2d70fd0d79920252bf52f1508ac98f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Thu, 4 Apr 2024 12:59:15 +0200 Subject: [PATCH 045/201] hugeint --- src/common/types/hugeint.cpp | 15 ++++++++------- src/common/types/uhugeint.cpp | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/common/types/hugeint.cpp b/src/common/types/hugeint.cpp index bc4a13e0b98d..c99abaa9ee1e 100644 --- a/src/common/types/hugeint.cpp +++ b/src/common/types/hugeint.cpp @@ -112,7 +112,8 @@ 
static hugeint_t PositiveHugeintLeftShift(hugeint_t lhs, uint32_t amount) { D_ASSERT(amount > 0 && amount < 64); hugeint_t result; result.lower = lhs.lower << amount; - result.upper = (lhs.upper << amount) + (lhs.lower >> (64 - amount)); + result.upper = + UnsafeNumericCast((UnsafeNumericCast(lhs.upper) << amount) + (lhs.lower >> (64 - amount))); return result; } @@ -625,7 +626,7 @@ bool Hugeint::TryCast(hugeint_t input, uhugeint_t &result) { } result.lower = input.lower; - result.upper = input.upper; + result.upper = UnsafeNumericCast(input.upper); return true; } @@ -744,7 +745,7 @@ bool ConvertFloatingToBigint(REAL_T value, hugeint_t &result) { value = -value; } result.lower = (uint64_t)fmod(value, REAL_T(NumericLimits::Maximum())); - result.upper = (uint64_t)(value / REAL_T(NumericLimits::Maximum())); + result.upper = (int64_t)(value / REAL_T(NumericLimits::Maximum())); if (negative) { Hugeint::NegateInPlace(result); } @@ -829,14 +830,14 @@ hugeint_t hugeint_t::operator>>(const hugeint_t &rhs) const { return *this; } else if (shift == 64) { result.upper = (upper < 0) ? -1 : 0; - result.lower = upper; + result.lower = UnsafeNumericCast(upper); } else if (shift < 64) { // perform lower shift in unsigned integer, and mask away the most significant bit result.lower = (uint64_t(upper) << (64 - shift)) | (lower >> shift); result.upper = upper >> shift; } else { D_ASSERT(shift < 128); - result.lower = upper >> (shift - 64); + result.lower = UnsafeNumericCast(upper >> (shift - 64)); result.upper = (upper < 0) ? -1 : 0; } return result; @@ -851,7 +852,7 @@ hugeint_t hugeint_t::operator<<(const hugeint_t &rhs) const { if (rhs.upper != 0 || shift >= 128) { return hugeint_t(0); } else if (shift == 64) { - result.upper = lower; + result.upper = UnsafeNumericCast(lower); result.lower = 0; } else if (shift == 0) { return *this; @@ -859,7 +860,7 @@ hugeint_t hugeint_t::operator<<(const hugeint_t &rhs) const { // perform upper shift in unsigned integer, and mask away the most significant bit uint64_t upper_shift = ((uint64_t(upper) << shift) + (lower >> (64 - shift))) & 0x7FFFFFFFFFFFFFFF; result.lower = lower << shift; - result.upper = upper_shift; + result.upper = UnsafeNumericCast(upper_shift); } else { D_ASSERT(shift < 128); result.lower = 0; diff --git a/src/common/types/uhugeint.cpp b/src/common/types/uhugeint.cpp index 479f756d4300..7469dd809642 100644 --- a/src/common/types/uhugeint.cpp +++ b/src/common/types/uhugeint.cpp @@ -399,7 +399,7 @@ bool Uhugeint::TryCast(uhugeint_t input, hugeint_t &result) { } result.lower = input.lower; - result.upper = input.upper; + result.upper = UnsafeNumericCast(input.upper); return true; } From 7eef6a705c78e2c2e8cdc9ea849be81f200bda02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Thu, 4 Apr 2024 15:05:47 +0200 Subject: [PATCH 046/201] moaar42 --- .../types/row/partitioned_tuple_data.cpp | 4 +- .../types/row/row_data_collection_scanner.cpp | 11 +-- src/common/types/uuid.cpp | 13 ++-- src/common/types/value.cpp | 8 +- src/common/types/vector.cpp | 6 +- .../vector_operations/null_operations.cpp | 2 +- .../numeric_inplace_operators.cpp | 4 +- src/execution/adaptive_filter.cpp | 2 +- src/execution/aggregate_hashtable.cpp | 6 +- .../aggregate/distinct_aggregate_data.cpp | 2 +- .../aggregate/physical_hash_aggregate.cpp | 2 +- .../aggregate/physical_streaming_window.cpp | 6 +- .../physical_ungrouped_aggregate.cpp | 2 +- .../operator/aggregate/physical_window.cpp | 2 +- .../buffer_manager/csv_file_handle.cpp | 3 +- 
.../scanner/column_count_scanner.cpp | 2 +- .../scanner/string_value_scanner.cpp | 2 +- .../csv_scanner/sniffer/header_detection.cpp | 17 ++-- .../csv_scanner/sniffer/type_detection.cpp | 2 +- src/execution/perfect_aggregate_hashtable.cpp | 8 +- src/execution/physical_operator.cpp | 4 +- src/execution/radix_partitioned_hashtable.cpp | 20 ++--- src/execution/window_executor.cpp | 78 +++++++++---------- src/include/duckdb/common/vector_size.hpp | 2 +- .../duckdb/execution/merge_sort_tree.hpp | 22 +++--- src/include/duckdb/storage/storage_info.hpp | 4 +- 26 files changed, 120 insertions(+), 114 deletions(-) diff --git a/src/common/types/row/partitioned_tuple_data.cpp b/src/common/types/row/partitioned_tuple_data.cpp index 979b292294e7..f3b6469812d0 100644 --- a/src/common/types/row/partitioned_tuple_data.cpp +++ b/src/common/types/row/partitioned_tuple_data.cpp @@ -329,8 +329,8 @@ void PartitionedTupleData::Repartition(PartitionedTupleData &new_partitioned_dat const int64_t update = reverse ? -1 : 1; const int64_t adjustment = reverse ? -1 : 0; - for (idx_t partition_idx = start_idx; partition_idx != end_idx; partition_idx += update) { - auto actual_partition_idx = partition_idx + adjustment; + for (idx_t partition_idx = start_idx; partition_idx != end_idx; partition_idx += idx_t(update)) { + auto actual_partition_idx = partition_idx + idx_t(adjustment); auto &partition = *partitions[actual_partition_idx]; if (partition.Count() > 0) { diff --git a/src/common/types/row/row_data_collection_scanner.cpp b/src/common/types/row/row_data_collection_scanner.cpp index 50135d887707..efbd072e5406 100644 --- a/src/common/types/row/row_data_collection_scanner.cpp +++ b/src/common/types/row/row_data_collection_scanner.cpp @@ -56,7 +56,8 @@ void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block if (!swizzled_string_heap.keep_pinned) { auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block->count, heap_offset); + RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block->count, + NumericCast(heap_offset)); } else { swizzled_string_heap.pinned_blocks.emplace_back(std::move(heap_handle)); } @@ -84,7 +85,7 @@ void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block auto heap_start_ptr = Load(data_ptr + layout.GetHeapOffset()); auto heap_end_ptr = Load(data_ptr + layout.GetHeapOffset() + (next - 1) * layout.GetRowWidth()); - idx_t size = heap_end_ptr - heap_start_ptr + Load(heap_end_ptr); + auto size = NumericCast(heap_end_ptr - heap_start_ptr + Load(heap_end_ptr)); ptrs_and_sizes.emplace_back(heap_start_ptr, size); D_ASSERT(size <= heap_blocks[heap_block_idx]->byte_offset); @@ -100,7 +101,7 @@ void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block // Finally, we allocate a new heap block and copy data to it swizzled_string_heap.blocks.emplace_back(make_uniq( - MemoryTag::ORDER_BY, buffer_manager, MaxValue(total_size, (idx_t)Storage::BLOCK_SIZE), 1)); + MemoryTag::ORDER_BY, buffer_manager, MaxValue(total_size, (idx_t)Storage::BLOCK_SIZE), 1U)); auto new_heap_handle = buffer_manager.Pin(swizzled_string_heap.blocks.back()->block); auto new_heap_ptr = new_heap_handle.Ptr(); for (auto &ptr_and_size : ptrs_and_sizes) { @@ -174,7 +175,7 @@ RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, Ro // Pretend that we have scanned up to the start block // and will stop at the end auto 
begin = rows.blocks.begin(); - auto end = begin + block_idx; + auto end = begin + NumericCast(block_idx); total_scanned = std::accumulate(begin, end, idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); total_count = total_scanned + (*end)->count; @@ -194,7 +195,7 @@ void RowDataCollectionScanner::SwizzleBlock(RowDataBlock &data_block, RowDataBlo auto heap_handle = heap.buffer_manager.Pin(heap_block.block); auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block.count, heap_offset); + RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block.count, NumericCast(heap_offset)); } void RowDataCollectionScanner::ReSwizzle() { diff --git a/src/common/types/uuid.cpp b/src/common/types/uuid.cpp index 818d9e8ddeaa..82583fe91b36 100644 --- a/src/common/types/uuid.cpp +++ b/src/common/types/uuid.cpp @@ -6,13 +6,13 @@ namespace duckdb { bool UUID::FromString(const string &str, hugeint_t &result) { auto hex2char = [](char ch) -> unsigned char { if (ch >= '0' && ch <= '9') { - return ch - '0'; + return UnsafeNumericCast(ch - '0'); } if (ch >= 'a' && ch <= 'f') { - return 10 + ch - 'a'; + return UnsafeNumericCast(10 + ch - 'a'); } if (ch >= 'A' && ch <= 'F') { - return 10 + ch - 'A'; + return UnsafeNumericCast(10 + ch - 'A'); } return 0; }; @@ -23,7 +23,7 @@ bool UUID::FromString(const string &str, hugeint_t &result) { if (str.empty()) { return false; } - int has_braces = 0; + idx_t has_braces = 0; if (str.front() == '{') { has_braces = 1; } @@ -54,14 +54,15 @@ bool UUID::FromString(const string &str, hugeint_t &result) { } void UUID::ToString(hugeint_t input, char *buf) { - auto byte_to_hex = [](char byte_val, char *buf, idx_t &pos) { + auto byte_to_hex = [](uint64_t byte_val, char *buf, idx_t &pos) { + D_ASSERT(byte_val <= 0xFF); static char const HEX_DIGITS[] = "0123456789abcdef"; buf[pos++] = HEX_DIGITS[(byte_val >> 4) & 0xf]; buf[pos++] = HEX_DIGITS[byte_val & 0xf]; }; // Flip back before convert to string - int64_t upper = input.upper ^ (uint64_t(1) << 63); + int64_t upper = int64_t(uint64_t(input.upper) ^ (uint64_t(1) << 63)); idx_t pos = 0; byte_to_hex(upper >> 56 & 0xFF, buf, pos); byte_to_hex(upper >> 48 & 0xFF, buf, pos); diff --git a/src/common/types/value.cpp b/src/common/types/value.cpp index fb3ee10cea4d..94c7c3415134 100644 --- a/src/common/types/value.cpp +++ b/src/common/types/value.cpp @@ -1209,11 +1209,11 @@ Value Value::Numeric(const LogicalType &type, int64_t value) { return Value::UINTEGER((uint32_t)value); case LogicalTypeId::UBIGINT: D_ASSERT(value >= 0); - return Value::UBIGINT(value); + return Value::UBIGINT(NumericCast(value)); case LogicalTypeId::HUGEINT: return Value::HUGEINT(value); case LogicalTypeId::UHUGEINT: - return Value::UHUGEINT(value); + return Value::UHUGEINT(NumericCast(value)); case LogicalTypeId::DECIMAL: return Value::DECIMAL(value, DecimalType::GetWidth(type), DecimalType::GetScale(type)); case LogicalTypeId::FLOAT: @@ -1221,7 +1221,7 @@ Value Value::Numeric(const LogicalType &type, int64_t value) { case LogicalTypeId::DOUBLE: return Value((double)value); case LogicalTypeId::POINTER: - return Value::POINTER(value); + return Value::POINTER(NumericCast(value)); case LogicalTypeId::DATE: D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); return Value::DATE(date_t(NumericCast(value))); @@ -1662,7 +1662,7 @@ hugeint_t IntegralValue::Get(const Value &value) { case 
PhysicalType::UINT32: return UIntegerValue::Get(value); case PhysicalType::UINT64: - return UBigIntValue::Get(value); + return NumericCast(UBigIntValue::Get(value)); case PhysicalType::UINT128: return static_cast(UhugeIntValue::Get(value)); default: diff --git a/src/common/types/vector.cpp b/src/common/types/vector.cpp index b177911278b0..e61ad46dd259 100644 --- a/src/common/types/vector.cpp +++ b/src/common/types/vector.cpp @@ -555,7 +555,7 @@ Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { case VectorType::SEQUENCE_VECTOR: { int64_t start, increment; SequenceVector::GetSequence(*vector, start, increment); - return Value::Numeric(vector->GetType(), start + increment * index); + return Value::Numeric(vector->GetType(), start + increment * NumericCast(index)); } default: throw InternalException("Unimplemented vector type for Vector::GetValue"); @@ -788,7 +788,7 @@ string Vector::ToString(idx_t count) const { int64_t start, increment; SequenceVector::GetSequence(*this, start, increment); for (idx_t i = 0; i < count; i++) { - retval += to_string(start + increment * i) + (i == count - 1 ? "" : ", "); + retval += to_string(start + increment * UnsafeNumericCast(i)) + (i == count - 1 ? "" : ", "); } break; } @@ -1006,7 +1006,7 @@ void Vector::Flatten(idx_t count) { buffer = VectorBuffer::CreateStandardVector(GetType()); data = buffer->GetData(); - VectorOperations::GenerateSequence(*this, sequence_count, start, increment); + VectorOperations::GenerateSequence(*this, NumericCast(sequence_count), start, increment); break; } default: diff --git a/src/common/vector_operations/null_operations.cpp b/src/common/vector_operations/null_operations.cpp index 48bc904da474..dd34ac8edbc1 100644 --- a/src/common/vector_operations/null_operations.cpp +++ b/src/common/vector_operations/null_operations.cpp @@ -102,7 +102,7 @@ idx_t VectorOperations::CountNotNull(Vector &input, const idx_t count) { default: for (idx_t i = 0; i < count; ++i) { const auto row_idx = vdata.sel->get_index(i); - valid += int(vdata.validity.RowIsValid(row_idx)); + valid += idx_t(vdata.validity.RowIsValid(row_idx)); } break; } diff --git a/src/common/vector_operations/numeric_inplace_operators.cpp b/src/common/vector_operations/numeric_inplace_operators.cpp index d2bd0f313db2..86b507a3ce9a 100644 --- a/src/common/vector_operations/numeric_inplace_operators.cpp +++ b/src/common/vector_operations/numeric_inplace_operators.cpp @@ -23,14 +23,14 @@ void VectorOperations::AddInPlace(Vector &input, int64_t right, idx_t count) { case VectorType::CONSTANT_VECTOR: { D_ASSERT(!ConstantVector::IsNull(input)); auto data = ConstantVector::GetData(input); - *data += right; + *data += UnsafeNumericCast(right); break; } default: { D_ASSERT(input.GetVectorType() == VectorType::FLAT_VECTOR); auto data = FlatVector::GetData(input); for (idx_t i = 0; i < count; i++) { - data[i] += right; + data[i] += UnsafeNumericCast(right); } break; } diff --git a/src/execution/adaptive_filter.cpp b/src/execution/adaptive_filter.cpp index 37ffad99b240..166174d00b89 100644 --- a/src/execution/adaptive_filter.cpp +++ b/src/execution/adaptive_filter.cpp @@ -58,7 +58,7 @@ void AdaptiveFilter::AdaptRuntimeStatistics(double duration) { // get swap index and swap likeliness std::uniform_int_distribution distribution(1, NumericCast(right_random_border)); // a <= i <= b - idx_t random_number = distribution(generator) - 1; + auto random_number = UnsafeNumericCast(distribution(generator) - 1); swap_idx = random_number / 100; // index to be swapped idx_t likeliness = 
random_number - 100 * swap_idx; // random number between [0, 100) diff --git a/src/execution/aggregate_hashtable.cpp b/src/execution/aggregate_hashtable.cpp index 90029b0a4b6c..c8ab1ce6b965 100644 --- a/src/execution/aggregate_hashtable.cpp +++ b/src/execution/aggregate_hashtable.cpp @@ -246,7 +246,7 @@ idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, Vector &group_hashe #endif const auto new_group_count = FindOrCreateGroups(groups, group_hashes, state.addresses, state.new_groups); - VectorOperations::AddInPlace(state.addresses, layout.GetAggrOffset(), payload.size()); + VectorOperations::AddInPlace(state.addresses, NumericCast(layout.GetAggrOffset()), payload.size()); // Now every cell has an entry, update the aggregates auto &aggregates = layout.GetAggregates(); @@ -258,7 +258,7 @@ idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, Vector &group_hashe if (filter_idx >= filter.size() || i < filter[filter_idx]) { // Skip all the aggregates that are not in the filter payload_idx += aggr.child_count; - VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size()); + VectorOperations::AddInPlace(state.addresses, NumericCast(aggr.payload_size), payload.size()); continue; } D_ASSERT(i == filter[filter_idx]); @@ -272,7 +272,7 @@ idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, Vector &group_hashe // Move to the next aggregate payload_idx += aggr.child_count; - VectorOperations::AddInPlace(state.addresses, aggr.payload_size, payload.size()); + VectorOperations::AddInPlace(state.addresses, NumericCast(aggr.payload_size), payload.size()); filter_idx++; } diff --git a/src/execution/operator/aggregate/distinct_aggregate_data.cpp b/src/execution/operator/aggregate/distinct_aggregate_data.cpp index cd76deee6482..d1d0d20d6eb5 100644 --- a/src/execution/operator/aggregate/distinct_aggregate_data.cpp +++ b/src/execution/operator/aggregate/distinct_aggregate_data.cpp @@ -151,7 +151,7 @@ idx_t DistinctAggregateCollectionInfo::CreateTableIndexMap() { std::find_if(table_inputs.begin(), table_inputs.end(), FindMatchingAggregate(std::ref(aggregate))); if (matching_inputs != table_inputs.end()) { //! 
Assign the existing table to the aggregate - idx_t found_idx = std::distance(table_inputs.begin(), matching_inputs); + auto found_idx = NumericCast(std::distance(table_inputs.begin(), matching_inputs)); table_map[agg_idx] = found_idx; continue; } diff --git a/src/execution/operator/aggregate/physical_hash_aggregate.cpp b/src/execution/operator/aggregate/physical_hash_aggregate.cpp index 5217c110bd26..bfed85859f4d 100644 --- a/src/execution/operator/aggregate/physical_hash_aggregate.cpp +++ b/src/execution/operator/aggregate/physical_hash_aggregate.cpp @@ -565,7 +565,7 @@ class HashAggregateDistinctFinalizeTask : public ExecutorTask { void HashAggregateDistinctFinalizeEvent::Schedule() { auto n_tasks = CreateGlobalSources(); - n_tasks = MinValue(n_tasks, TaskScheduler::GetScheduler(context).NumberOfThreads()); + n_tasks = MinValue(n_tasks, NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())); vector> tasks; for (idx_t i = 0; i < n_tasks; i++) { tasks.push_back(make_uniq(*pipeline, shared_from_this(), op, gstate)); diff --git a/src/execution/operator/aggregate/physical_streaming_window.cpp b/src/execution/operator/aggregate/physical_streaming_window.cpp index 7f4974681d67..f1310350e623 100644 --- a/src/execution/operator/aggregate/physical_streaming_window.cpp +++ b/src/execution/operator/aggregate/physical_streaming_window.cpp @@ -142,7 +142,7 @@ OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, D auto data = FlatVector::GetData(result); int64_t start_row = gstate.row_number; for (idx_t i = 0; i < input.size(); ++i) { - data[i] = start_row + i; + data[i] = NumericCast(start_row + NumericCast(i)); } break; } @@ -192,7 +192,7 @@ OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, D int64_t start_row = gstate.row_number; auto rdata = FlatVector::GetData(chunk.data[col_idx]); for (idx_t i = 0; i < count; i++) { - rdata[i] = start_row + i; + rdata[i] = NumericCast(start_row + NumericCast(i)); } break; } @@ -200,7 +200,7 @@ OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, D throw NotImplementedException("%s for StreamingWindow", ExpressionTypeToString(expr.GetExpressionType())); } } - gstate.row_number += count; + gstate.row_number += NumericCast(count); chunk.SetCardinality(count); return OperatorResultType::NEED_MORE_INPUT; } diff --git a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp index f1ceeab4bc7c..b542d4b6362f 100644 --- a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -454,7 +454,7 @@ void UngroupedDistinctAggregateFinalizeEvent::Schedule() { global_source_states.push_back(radix_table_p.GetGlobalSourceState(context)); } n_tasks = MaxValue(n_tasks, 1); - n_tasks = MinValue(n_tasks, TaskScheduler::GetScheduler(context).NumberOfThreads()); + n_tasks = MinValue(n_tasks, NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())); vector> tasks; for (idx_t i = 0; i < n_tasks; i++) { diff --git a/src/execution/operator/aggregate/physical_window.cpp b/src/execution/operator/aggregate/physical_window.cpp index bcfe0a56bd3b..01e6752a2cd0 100644 --- a/src/execution/operator/aggregate/physical_window.cpp +++ b/src/execution/operator/aggregate/physical_window.cpp @@ -327,7 +327,7 @@ void WindowPartitionSourceState::MaterializeSortedData() { heap->blocks = std::move(sd.heap_blocks); hash_group.reset(); } else { - 
heap = make_uniq(buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1, true); + heap = make_uniq(buffer_manager, Storage::BLOCK_SIZE, 1U, true); } heap->count = std::accumulate(heap->blocks.begin(), heap->blocks.end(), idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); diff --git a/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp b/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp index 5a73815e2763..90d32eebaab8 100644 --- a/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp +++ b/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp @@ -1,5 +1,6 @@ #include "duckdb/execution/operator/csv_scanner/csv_file_handle.hpp" #include "duckdb/common/exception/binder_exception.hpp" +#include "duckdb/common/numeric_utils.hpp" namespace duckdb { @@ -57,7 +58,7 @@ idx_t CSVFileHandle::Read(void *buffer, idx_t nr_bytes) { if (!finished) { finished = bytes_read == 0; } - return bytes_read; + return UnsafeNumericCast(bytes_read); } string CSVFileHandle::ReadLine() { diff --git a/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp b/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp index 8894cbcf4e76..70542ef4b472 100644 --- a/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp +++ b/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp @@ -50,7 +50,7 @@ ColumnCountScanner::ColumnCountScanner(shared_ptr buffer_manag } unique_ptr ColumnCountScanner::UpgradeToStringValueScanner() { - auto scanner = make_uniq(0, buffer_manager, state_machine, error_handler, nullptr, true); + auto scanner = make_uniq(0U, buffer_manager, state_machine, error_handler, nullptr, true); return scanner; } diff --git a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp index 9582e1c1af2f..3cff693b1bee 100644 --- a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp +++ b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp @@ -1075,7 +1075,7 @@ void StringValueScanner::SetStart() { } scan_finder = make_uniq( - 0, buffer_manager, state_machine, make_shared(true), csv_file_scan, false, iterator, 1); + 0U, buffer_manager, state_machine, make_shared(true), csv_file_scan, false, iterator, 1U); auto &tuples = scan_finder->ParseChunk(); line_found = true; if (tuples.number_of_rows != 1) { diff --git a/src/execution/operator/csv_scanner/sniffer/header_detection.cpp b/src/execution/operator/csv_scanner/sniffer/header_detection.cpp index a52e2aeae884..6617421e5726 100644 --- a/src/execution/operator/csv_scanner/sniffer/header_detection.cpp +++ b/src/execution/operator/csv_scanner/sniffer/header_detection.cpp @@ -7,9 +7,9 @@ namespace duckdb { // Helper function to generate column names static string GenerateColumnName(const idx_t total_cols, const idx_t col_number, const string &prefix = "column") { - int max_digits = NumericHelper::UnsignedLength(total_cols - 1); - int digits = NumericHelper::UnsignedLength(col_number); - string leading_zeros = string(max_digits - digits, '0'); + auto max_digits = NumericHelper::UnsignedLength(total_cols - 1); + auto digits = NumericHelper::UnsignedLength(col_number); + string leading_zeros = string(NumericCast(max_digits - digits), '0'); string value = to_string(col_number); return string(prefix + leading_zeros + value); } @@ -22,21 +22,21 @@ static string TrimWhitespace(const string &col_name) { // Find the first character that is not left 
trimmed idx_t begin = 0; while (begin < size) { - auto bytes = utf8proc_iterate(str + begin, size - begin, &codepoint); + auto bytes = utf8proc_iterate(str + begin, NumericCast(size - begin), &codepoint); D_ASSERT(bytes > 0); if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { break; } - begin += bytes; + begin += NumericCast(bytes); } // Find the last character that is not right trimmed idx_t end; end = begin; for (auto next = begin; next < col_name.size();) { - auto bytes = utf8proc_iterate(str + next, size - next, &codepoint); + auto bytes = utf8proc_iterate(str + next, NumericCast(size - next), &codepoint); D_ASSERT(bytes > 0); - next += bytes; + next += NumericCast(bytes); if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { end = next; } @@ -48,7 +48,8 @@ static string TrimWhitespace(const string &col_name) { static string NormalizeColumnName(const string &col_name) { // normalize UTF8 characters to NFKD - auto nfkd = utf8proc_NFKD(reinterpret_cast(col_name.c_str()), col_name.size()); + auto nfkd = utf8proc_NFKD(reinterpret_cast(col_name.c_str()), + NumericCast(col_name.size())); const string col_name_nfkd = string(const_char_ptr_cast(nfkd), strlen(const_char_ptr_cast(nfkd))); free(nfkd); diff --git a/src/execution/operator/csv_scanner/sniffer/type_detection.cpp b/src/execution/operator/csv_scanner/sniffer/type_detection.cpp index da9f7ed6983f..077cbdd2f8a0 100644 --- a/src/execution/operator/csv_scanner/sniffer/type_detection.cpp +++ b/src/execution/operator/csv_scanner/sniffer/type_detection.cpp @@ -50,7 +50,7 @@ static bool StartsWithNumericDate(string &separator, const string &value) { } // second literal must match first - if (((field3 - literal2) != (field2 - literal1)) || strncmp(literal1, literal2, (field2 - literal1)) != 0) { + if (((field3 - literal2) != (field2 - literal1)) || strncmp(literal1, literal2, NumericCast((field2 - literal1))) != 0) { return false; } diff --git a/src/execution/perfect_aggregate_hashtable.cpp b/src/execution/perfect_aggregate_hashtable.cpp index da7c20192452..a46e9499b5a3 100644 --- a/src/execution/perfect_aggregate_hashtable.cpp +++ b/src/execution/perfect_aggregate_hashtable.cpp @@ -64,7 +64,7 @@ static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value // we only need to handle non-null values here if (group_data.validity.RowIsValid(index)) { D_ASSERT(data[index] >= min_val); - uintptr_t adjusted_value = (data[index] - min_val) + 1; + auto adjusted_value = UnsafeNumericCast((data[index] - min_val) + 1); address_data[i] += adjusted_value << current_shift; } } @@ -72,7 +72,7 @@ static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value // no null values: we can directly compute the addresses for (idx_t i = 0; i < count; i++) { auto index = group_data.sel->get_index(i); - uintptr_t adjusted_value = (data[index] - min_val) + 1; + auto adjusted_value = UnsafeNumericCast((data[index] - min_val) + 1); address_data[i] += adjusted_value << current_shift; } } @@ -149,7 +149,7 @@ void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) } // move to the next aggregate payload_idx += input_count; - VectorOperations::AddInPlace(addresses, aggregate.payload_size, payload.size()); + VectorOperations::AddInPlace(addresses, NumericCast(aggregate.payload_size), payload.size()); } } @@ -199,7 +199,7 @@ static void ReconstructGroupVectorTemplated(uint32_t group_values[], Value &min, auto min_data = min.GetValueUnsafe(); for (idx_t i = 0; i < entry_count; i++) { // extract the value 
of this group from the total group index - auto group_index = UnsafeNumericCast((group_values[i] >> shift) & mask); + auto group_index = UnsafeNumericCast((group_values[i] >> shift) & mask); if (group_index == 0) { // if it is 0, the value is NULL validity_mask.SetInvalid(i); diff --git a/src/execution/physical_operator.cpp b/src/execution/physical_operator.cpp index ba5ba3e22158..4934789ecfe7 100644 --- a/src/execution/physical_operator.cpp +++ b/src/execution/physical_operator.cpp @@ -122,8 +122,8 @@ unique_ptr PhysicalOperator::GetGlobalSinkState(ClientContext & idx_t PhysicalOperator::GetMaxThreadMemory(ClientContext &context) { // Memory usage per thread should scale with max mem / num threads // We take 1/4th of this, to be conservative - idx_t max_memory = BufferManager::GetBufferManager(context).GetQueryMaxMemory(); - idx_t num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + auto max_memory = BufferManager::GetBufferManager(context).GetQueryMaxMemory(); + auto num_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); return (max_memory / num_threads) / 4; } diff --git a/src/execution/radix_partitioned_hashtable.cpp b/src/execution/radix_partitioned_hashtable.cpp index fe8575cb11ae..f148fefd1709 100644 --- a/src/execution/radix_partitioned_hashtable.cpp +++ b/src/execution/radix_partitioned_hashtable.cpp @@ -195,7 +195,7 @@ RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context_p, const R auto ht_size = blocks_per_partition * Storage::BLOCK_ALLOC_SIZE + config.sink_capacity * sizeof(aggr_ht_entry_t); // This really is the minimum reservation that we can do - idx_t num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + auto num_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); auto minimum_reservation = num_threads * ht_size; temporary_memory_state->SetMinimumReservation(minimum_reservation); @@ -273,12 +273,12 @@ void RadixHTConfig::SetRadixBitsInternal(const idx_t radix_bits_p, bool external } idx_t RadixHTConfig::InitialSinkRadixBits(ClientContext &context) { - const idx_t active_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + const auto active_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); return MinValue(RadixPartitioning::RadixBits(NextPowerOfTwo(active_threads)), MAXIMUM_INITIAL_SINK_RADIX_BITS); } idx_t RadixHTConfig::MaximumSinkRadixBits(ClientContext &context) { - const idx_t active_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + const auto active_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); return MinValue(RadixPartitioning::RadixBits(NextPowerOfTwo(active_threads)), MAXIMUM_FINAL_SINK_RADIX_BITS); } @@ -288,7 +288,7 @@ idx_t RadixHTConfig::ExternalRadixBits(const idx_t &maximum_sink_radix_bits_p) { idx_t RadixHTConfig::SinkCapacity(ClientContext &context) { // Get active and maximum number of threads - const idx_t active_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + const auto active_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); // Compute cache size per active thread (assuming cache is shared) const auto total_shared_cache_size = active_threads * L3_CACHE_SIZE; @@ -527,8 +527,8 @@ void RadixPartitionedHashTable::Finalize(ClientContext &context, GlobalSinkState // Minimum of combining one partition at a time gstate.temporary_memory_state->SetMinimumReservation(gstate.max_partition_size); // Maximum of 
combining all partitions - auto max_threads = - MinValue(TaskScheduler::GetScheduler(context).NumberOfThreads(), gstate.partitions.size()); + auto max_threads = MinValue(NumericCast<idx_t>(TaskScheduler::GetScheduler(context).NumberOfThreads()), + gstate.partitions.size()); gstate.temporary_memory_state->SetRemainingSize(context, max_threads * gstate.max_partition_size); gstate.finalized = true; } @@ -545,8 +545,8 @@ idx_t RadixPartitionedHashTable::MaxThreads(GlobalSinkState &sink_p) const { // This many partitions will fit given our reservation (at least 1) auto partitions_fit = MaxValue<idx_t>(sink.temporary_memory_state->GetReservation() / sink.max_partition_size, 1); // Maximum is either the number of partitions, or the number of threads - auto max_possible = - MinValue(sink.partitions.size(), TaskScheduler::GetScheduler(sink.context).NumberOfThreads()); + auto max_possible = MinValue( + sink.partitions.size(), NumericCast<idx_t>(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())); // Minimum of the two return MinValue(partitions_fit, max_possible); @@ -726,8 +726,8 @@ void RadixHTLocalSourceState::Finalize(RadixHTGlobalSinkState &sink, RadixHTGlob const auto capacity = GroupedAggregateHashTable::GetCapacityForCount(partition.data->Count()); // However, we will limit the initial capacity so we don't do a huge over-allocation - const idx_t n_threads = TaskScheduler::GetScheduler(gstate.context).NumberOfThreads(); - const idx_t memory_limit = BufferManager::GetBufferManager(gstate.context).GetMaxMemory(); + const auto n_threads = NumericCast<idx_t>(TaskScheduler::GetScheduler(gstate.context).NumberOfThreads()); + const auto memory_limit = BufferManager::GetBufferManager(gstate.context).GetMaxMemory(); const idx_t thread_limit = 0.6 * memory_limit / n_threads; const idx_t size_per_entry = partition.data->SizeInBytes() / partition.data->Count() + diff --git a/src/execution/window_executor.cpp b/src/execution/window_executor.cpp index 627e8df630da..4d6b9b099c37 100644 --- a/src/execution/window_executor.cpp +++ b/src/execution/window_executor.cpp @@ -139,11 +139,11 @@ struct WindowColumnIterator { // Random Access inline iterator &operator+=(difference_type n) { - pos += n; + pos += UnsafeNumericCast<idx_t>(n); return *this; } inline iterator &operator-=(difference_type n) { - pos -= n; + pos -= UnsafeNumericCast<idx_t>(n); return *this; } @@ -232,7 +232,7 @@ static idx_t FindTypedRangeBound(const WindowInputColumn &over, const idx_t orde const auto first = over.GetCell<T>(prev.start); if (!comp(val, first)) { // prev.first <= val, so we can start further forward - begin += (prev.start - order_begin); + begin += UnsafeNumericCast<idx_t>(prev.start - order_begin); } } if (order_begin < prev.end && prev.end < order_end) { @@ -240,7 +240,7 @@ static idx_t FindTypedRangeBound(const WindowInputColumn &over, const idx_t orde if (!comp(second, val)) { // val <= prev.second, so we can end further back // (prev.second is the largest peer) - end -= (order_end - prev.end - 1); + end -= UnsafeNumericCast<idx_t>(order_end - prev.end - 1); } } }
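// Editorial note: window_start and window_end are deliberately signed int64_t, because the
// EXPR_PRECEDING/EXPR_FOLLOWING arithmetic in the hunks that follow can go negative before the
// bounds are clamped to the partition; the NumericCast calls below therefore turn each
// idx_t-to-int64_t conversion into an explicit, checked one instead of an implicit sign change.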
@@ -452,13 +452,13 @@ void WindowBoundariesState::Update(const idx_t row_idx, const WindowInputColumn switch (start_boundary) { case WindowBoundary::UNBOUNDED_PRECEDING: - window_start = partition_start; + window_start = NumericCast<int64_t>(partition_start); break; case WindowBoundary::CURRENT_ROW_ROWS: - window_start = row_idx; + window_start = NumericCast<int64_t>(row_idx); break; case WindowBoundary::CURRENT_ROW_RANGE: - window_start = peer_start; + window_start = NumericCast<int64_t>(peer_start); break; case WindowBoundary::EXPR_PRECEDING_ROWS: { if (!TrySubtractOperator::Operation(int64_t(row_idx), boundary_start.GetCell<int64_t>(chunk_idx), @@ -475,21 +475,21 @@ void WindowBoundariesState::Update(const idx_t row_idx, const WindowInputColumn } case WindowBoundary::EXPR_PRECEDING_RANGE: { if (boundary_start.CellIsNull(chunk_idx)) { - window_start = peer_start; + window_start = NumericCast<int64_t>(peer_start); } else { prev.start = FindOrderedRangeBound(range_collection, range_sense, valid_start, row_idx, start_boundary, boundary_start, chunk_idx, prev); - window_start = prev.start; + window_start = NumericCast<int64_t>(prev.start); } break; } case WindowBoundary::EXPR_FOLLOWING_RANGE: { if (boundary_start.CellIsNull(chunk_idx)) { - window_start = peer_start; + window_start = NumericCast<int64_t>(peer_start); } else { prev.start = FindOrderedRangeBound(range_collection, range_sense, row_idx, valid_end, start_boundary, boundary_start, chunk_idx, prev); - window_start = prev.start; + window_start = NumericCast<int64_t>(prev.start); } break; } @@ -499,13 +499,13 @@ switch (end_boundary) { case WindowBoundary::CURRENT_ROW_ROWS: - window_end = row_idx + 1; + window_end = NumericCast<int64_t>(row_idx + 1); break; case WindowBoundary::CURRENT_ROW_RANGE: - window_end = peer_end; + window_end = NumericCast<int64_t>(peer_end); break; case WindowBoundary::UNBOUNDED_FOLLOWING: - window_end = partition_end; + window_end = NumericCast<int64_t>(partition_end); break; case WindowBoundary::EXPR_PRECEDING_ROWS: if (!TrySubtractOperator::Operation(int64_t(row_idx + 1), boundary_end.GetCell<int64_t>(chunk_idx), @@ -520,21 +520,21 @@ break; case WindowBoundary::EXPR_PRECEDING_RANGE: { if (boundary_end.CellIsNull(chunk_idx)) { - window_end = peer_end; + window_end = NumericCast<int64_t>(peer_end); } else { prev.end = FindOrderedRangeBound(range_collection, range_sense, valid_start, row_idx, end_boundary, boundary_end, chunk_idx, prev); - window_end = prev.end; + window_end = NumericCast<int64_t>(prev.end); } break; } case WindowBoundary::EXPR_FOLLOWING_RANGE: { if (boundary_end.CellIsNull(chunk_idx)) { - window_end = peer_end; + window_end = NumericCast<int64_t>(peer_end); } else { prev.end = FindOrderedRangeBound(range_collection, range_sense, row_idx, valid_end, end_boundary, boundary_end, chunk_idx, prev); - window_end = prev.end; + window_end = NumericCast<int64_t>(prev.end); } break; } @@ -543,17 +543,17 @@ } // clamp windows to partitions if they should exceed - if (window_start < (int64_t)partition_start) { - window_start = partition_start; + if (window_start < NumericCast<int64_t>(partition_start)) { + window_start = NumericCast<int64_t>(partition_start); } - if (window_start > (int64_t)partition_end) { - window_start = partition_end; + if (window_start > NumericCast<int64_t>(partition_end)) { + window_start = NumericCast<int64_t>(partition_end); } - if (window_end < (int64_t)partition_start) { - window_end = partition_start; + if (window_end < NumericCast<int64_t>(partition_start)) { + window_end = NumericCast<int64_t>(partition_start); } - if (window_end > (int64_t)partition_end) { - window_end = partition_end; + if (window_end > NumericCast<int64_t>(partition_end)) { + window_end = NumericCast<int64_t>(partition_end); } if (window_start < 0 || window_end < 0) {
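// Editorial note (worked example with assumed values): for a partition spanning rows [100, 200),
// a frame of '150 PRECEDING' at row_idx = 104 gives window_start = 104 - 150 = -46; because
// window_start is a signed int64_t, the clamp above simply raises it to partition_start = 100,
// whereas unsigned arithmetic would have wrapped it around to a huge positive index.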
@@ -1029,7 +1029,7 @@ void WindowAggregateExecutor::Finalize() { // Estimate the frame statistics // Default to the entire partition if we don't know anything FrameStats stats; - const int64_t count = aggregator->GetInputs().size(); + const auto count = NumericCast<int64_t>(aggregator->GetInputs().size()); // First entry is the frame start stats[0] = FrameDelta(-count, count); @@ -1090,7 +1090,7 @@ void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorState &lstate, Vect auto partition_begin = FlatVector::GetData<const idx_t>(lbstate.bounds.data[PARTITION_BEGIN]); auto rdata = FlatVector::GetData<int64_t>(result); for (idx_t i = 0; i < count; ++i, ++row_idx) { - rdata[i] = row_idx - partition_begin[i] + 1; + rdata[i] = NumericCast<int64_t>(row_idx - partition_begin[i] + 1); } } @@ -1147,7 +1147,7 @@ void WindowRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector &r for (idx_t i = 0; i < count; ++i, ++row_idx) { lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); - rdata[i] = lpeer.rank; + rdata[i] = NumericCast<int64_t>(lpeer.rank); } } @@ -1209,7 +1209,7 @@ void WindowDenseRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Vect for (idx_t i = 0; i < count; ++i, ++row_idx) { lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); - rdata[i] = lpeer.dense_rank; + rdata[i] = NumericCast<int64_t>(lpeer.dense_rank); } } @@ -1237,7 +1237,7 @@ void WindowPercentRankExecutor::EvaluateInternal(WindowExecutorState &lstate, Ve for (idx_t i = 0; i < count; ++i, ++row_idx) { lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); - int64_t denom = partition_end[i] - partition_begin[i] - 1; + auto denom = NumericCast<int64_t>(partition_end[i] - partition_begin[i] - 1); double percent_rank = denom > 0 ? ((double)lpeer.rank - 1) / denom : 0; rdata[i] = percent_rank; }
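// Editorial note (worked example): the two denominators differ by design. In a 5-row partition
// whose rows are all distinct peers, PERCENT_RANK uses denom = 5 - 1 = 4, mapping ranks 1..5 to
// 0.0, 0.25, 0.5, 0.75, 1.0, while CUME_DIST below uses denom = 5, mapping the rows to 0.2 .. 1.0.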
@@ -1260,7 +1260,7 @@ void WindowCumeDistExecutor::EvaluateInternal(WindowExecutorState &lstate, Vecto auto peer_end = FlatVector::GetData<const idx_t>(lbstate.bounds.data[PEER_END]); auto rdata = FlatVector::GetData<double>(result); for (idx_t i = 0; i < count; ++i, ++row_idx) { - int64_t denom = partition_end[i] - partition_begin[i]; + auto denom = NumericCast<int64_t>(partition_end[i] - partition_begin[i]); double cume_dist = denom > 0 ? ((double)(peer_end[i] - partition_begin[i])) / denom : 0; rdata[i] = cume_dist; } @@ -1365,7 +1365,7 @@ void WindowNtileExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector & throw InvalidInputException("Argument for ntile must be greater than zero"); } // With thanks from SQLite's ntileValueFunc() - int64_t n_total = partition_end[i] - partition_begin[i]; + auto n_total = NumericCast<int64_t>(partition_end[i] - partition_begin[i]); if (n_param > n_total) { // more groups allowed than we have values // map every entry to a unique group @@ -1374,7 +1374,7 @@ void WindowNtileExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector & int64_t n_size = (n_total / n_param); // find the row idx within the group D_ASSERT(row_idx >= partition_begin[i]); - int64_t adjusted_row_idx = row_idx - partition_begin[i]; + auto adjusted_row_idx = NumericCast<int64_t>(row_idx - partition_begin[i]); // now compute the ntile int64_t n_large = n_total - n_param * n_size; int64_t i_small = n_large * (n_size + 1); @@ -1452,16 +1452,16 @@ void WindowLeadLagExecutor::EvaluateInternal(WindowExecutorState &lstate, Vector idx_t delta = 0; if (val_idx < (int64_t)row_idx) { // Count backwards - delta = idx_t(row_idx - val_idx); - val_idx = FindPrevStart(ignore_nulls, partition_begin[i], row_idx, delta); + delta = idx_t(row_idx - idx_t(val_idx)); + val_idx = int64_t(FindPrevStart(ignore_nulls, partition_begin[i], row_idx, delta)); } else if (val_idx > (int64_t)row_idx) { - delta = idx_t(val_idx - row_idx); - val_idx = FindNextStart(ignore_nulls, row_idx + 1, partition_end[i], delta); + delta = idx_t(idx_t(val_idx) - row_idx); + val_idx = int64_t(FindNextStart(ignore_nulls, row_idx + 1, partition_end[i], delta)); } // else offset is zero, so don't move. if (!delta) { - CopyCell(payload_collection, 0, val_idx, result, i); + CopyCell(payload_collection, 0, NumericCast<idx_t>(val_idx), result, i); } else if (wexpr.default_expr) { llstate.leadlag_default.CopyCell(result, i); } else { diff --git a/src/include/duckdb/common/vector_size.hpp b/src/include/duckdb/common/vector_size.hpp index 0fa29086b220..3eb73b36b3bd 100644 --- a/src/include/duckdb/common/vector_size.hpp +++ b/src/include/duckdb/common/vector_size.hpp @@ -13,7 +13,7 @@ namespace duckdb { //! The default standard vector size -#define DEFAULT_STANDARD_VECTOR_SIZE 2048 +#define DEFAULT_STANDARD_VECTOR_SIZE 2048U //!
The vector size used in the execution engine #ifndef STANDARD_VECTOR_SIZE diff --git a/src/include/duckdb/execution/merge_sort_tree.hpp b/src/include/duckdb/execution/merge_sort_tree.hpp index 4f752a75410b..825c05f872de 100644 --- a/src/include/duckdb/execution/merge_sort_tree.hpp +++ b/src/include/duckdb/execution/merge_sort_tree.hpp @@ -515,8 +515,8 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons entry.first = run_idx.first * level_width; entry.second = std::min(entry.first + level_width, static_cast(tree[0].first.size())); auto *level_data = tree[level].first.data(); - idx_t entry_idx = - std::lower_bound(level_data + entry.first, level_data + entry.second, needle) - level_data; + auto entry_idx = NumericCast( + std::lower_bound(level_data + entry.first, level_data + entry.second, needle) - level_data); cascading_idx.first = cascading_idx.second = (entry_idx / CASCADING + 2 * (entry.first / level_width)) * FANOUT; @@ -545,7 +545,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons const auto run_pos = std::lower_bound(search_begin, search_end, needle) - level_data; // Compute runBegin and pass it to our callback const auto run_begin = curr.first - level_width; - aggregate(level, run_begin, run_pos); + aggregate(level, run_begin, NumericCast(run_pos)); // Update state for next round curr.first -= level_width; --cascading_idx.first; @@ -554,7 +554,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons if (curr.first != lower) { const auto *search_begin = level_data + cascading_idcs[cascading_idx.first]; const auto *search_end = level_data + cascading_idcs[cascading_idx.first + FANOUT]; - auto idx = std::lower_bound(search_begin, search_end, needle) - level_data; + auto idx = NumericCast(std::lower_bound(search_begin, search_end, needle) - level_data); cascading_idx.first = (idx / CASCADING + 2 * (lower / level_width)) * FANOUT; } @@ -567,7 +567,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons const auto run_pos = std::lower_bound(search_begin, search_end, needle) - level_data; // Compute runBegin and pass it to our callback const auto run_begin = curr.second; - aggregate(level, run_begin, run_pos); + aggregate(level, run_begin, NumericCast(run_pos)); // Update state for next round curr.second += level_width; ++cascading_idx.second; @@ -576,7 +576,7 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons if (curr.second != upper) { const auto *search_begin = level_data + cascading_idcs[cascading_idx.second]; const auto *search_end = level_data + cascading_idcs[cascading_idx.second + FANOUT]; - auto idx = std::lower_bound(search_begin, search_end, needle) - level_data; + auto idx = NumericCast(std::lower_bound(search_begin, search_end, needle) - level_data); cascading_idx.second = (idx / CASCADING + 2 * (upper / level_width)) * FANOUT; } } while (level >= LowestCascadingLevel()); @@ -591,8 +591,9 @@ void MergeSortTree::AggregateLowerBound(const idx_t lower, cons while (curr.first - lower >= level_width) { const auto *search_end = level_data + curr.first; const auto *search_begin = search_end - level_width; - const auto run_pos = std::lower_bound(search_begin, search_end, needle) - level_data; - const auto run_begin = search_begin - level_data; + const auto run_pos = + NumericCast(std::lower_bound(search_begin, search_end, needle) - level_data); + const auto run_begin = NumericCast(search_begin - level_data); aggregate(level, run_begin, run_pos); curr.first -= level_width; } @@ -600,8 +601,9 @@ void 
MergeSortTree::AggregateLowerBound(const idx_t lower, cons while (upper - curr.second >= level_width) { const auto *search_begin = level_data + curr.second; const auto *search_end = search_begin + level_width; - const auto run_pos = std::lower_bound(search_begin, search_end, needle) - level_data; - const auto run_begin = search_begin - level_data; + const auto run_pos = + NumericCast(std::lower_bound(search_begin, search_end, needle) - level_data); + const auto run_begin = NumericCast(search_begin - level_data); aggregate(level, run_begin, run_pos); curr.second += level_width; } diff --git a/src/include/duckdb/storage/storage_info.hpp b/src/include/duckdb/storage/storage_info.hpp index 1acdbba67cd3..e3d36025e681 100644 --- a/src/include/duckdb/storage/storage_info.hpp +++ b/src/include/duckdb/storage/storage_info.hpp @@ -29,7 +29,7 @@ using block_id_t = int64_t; struct Storage { //! The size of a hard disk sector, only really needed for Direct IO - constexpr static idx_t SECTOR_SIZE = 4096; + constexpr static idx_t SECTOR_SIZE = 4096U; //! Block header size for blocks written to the storage constexpr static idx_t BLOCK_HEADER_SIZE = sizeof(uint64_t); //! Size of a memory slot managed by the StorageManager. This is the quantum of allocation for Blocks on DuckDB. We @@ -39,7 +39,7 @@ struct Storage { constexpr static idx_t BLOCK_SIZE = BLOCK_ALLOC_SIZE - BLOCK_HEADER_SIZE; //! The size of the headers. This should be small and written more or less atomically by the hard disk. We default //! to the page size, which is 4KB. (1 << 12) - constexpr static idx_t FILE_HEADER_SIZE = 4096; + constexpr static idx_t FILE_HEADER_SIZE = 4096U; //! The number of rows per row group (must be a multiple of the vector size) constexpr static const idx_t ROW_GROUP_SIZE = STANDARD_ROW_GROUPS_SIZE; //! 
The number of vectors per row group From 8591d72fdb0f3a1bd06cba858b74a946a8587a16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Thu, 4 Apr 2024 15:59:14 +0200 Subject: [PATCH 047/201] moo --- .../csv_scanner/sniffer/type_detection.cpp | 5 +++-- .../table_function/global_csv_state.cpp | 2 +- .../csv_scanner/util/csv_reader_options.cpp | 20 +++++++++---------- .../operator/helper/physical_execute.cpp | 2 +- .../helper/physical_reservoir_sample.cpp | 2 +- .../operator/join/physical_asof_join.cpp | 2 +- .../operator/join/physical_hash_join.cpp | 12 +++++------ .../operator/join/physical_iejoin.cpp | 8 ++++---- .../join/physical_piecewise_merge_join.cpp | 4 ++-- .../operator/join/physical_range_join.cpp | 2 +- .../operator/order/physical_order.cpp | 2 +- .../physical_batch_copy_to_file.cpp | 6 +++--- .../persistent/physical_batch_insert.cpp | 9 +++++---- .../persistent/physical_copy_to_file.cpp | 2 +- .../operator/persistent/physical_delete.cpp | 2 +- .../operator/persistent/physical_export.cpp | 2 +- .../operator/persistent/physical_insert.cpp | 4 ++-- .../operator/persistent/physical_update.cpp | 2 +- .../physical_plan/plan_comparison_join.cpp | 2 +- .../physical_plan/plan_create_table.cpp | 4 ++-- src/execution/physical_plan/plan_top_n.cpp | 4 ++-- src/main/capi/arrow-c.cpp | 8 ++++---- src/main/capi/result-c.cpp | 4 ++-- src/main/client_context.cpp | 3 ++- src/main/database.cpp | 2 +- src/main/query_profiler.cpp | 2 +- src/main/relation.cpp | 2 +- src/main/secret/secret.cpp | 2 +- src/main/settings/settings.cpp | 18 ++++++++--------- src/parallel/pipeline.cpp | 4 ++-- src/parallel/task_scheduler.cpp | 4 ++-- 31 files changed, 75 insertions(+), 72 deletions(-) diff --git a/src/execution/operator/csv_scanner/sniffer/type_detection.cpp b/src/execution/operator/csv_scanner/sniffer/type_detection.cpp index 077cbdd2f8a0..65c2b5b3534e 100644 --- a/src/execution/operator/csv_scanner/sniffer/type_detection.cpp +++ b/src/execution/operator/csv_scanner/sniffer/type_detection.cpp @@ -50,7 +50,8 @@ static bool StartsWithNumericDate(string &separator, const string &value) { } // second literal must match first - if (((field3 - literal2) != (field2 - literal1)) || strncmp(literal1, literal2, NumericCast((field2 - literal1))) != 0) { + if (((field3 - literal2) != (field2 - literal1)) || + strncmp(literal1, literal2, NumericCast((field2 - literal1))) != 0) { return false; } @@ -69,7 +70,7 @@ static bool StartsWithNumericDate(string &separator, const string &value) { string GenerateDateFormat(const string &separator, const char *format_template) { string format_specifier = format_template; - auto amount_of_dashes = std::count(format_specifier.begin(), format_specifier.end(), '-'); + auto amount_of_dashes = NumericCast(std::count(format_specifier.begin(), format_specifier.end(), '-')); // All our date formats must have at least one - D_ASSERT(amount_of_dashes); string result; diff --git a/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp b/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp index b5943a9f5b8b..f89bcfef5b18 100644 --- a/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp +++ b/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp @@ -22,7 +22,7 @@ CSVGlobalState::CSVGlobalState(ClientContext &context_p, const shared_ptr(context, files[0], options, 0, bind_data, column_ids, file_schema)); + make_uniq(context, files[0], options, 0U, bind_data, column_ids, file_schema)); }; //! 
There are situations where we only support single threaded scanning bool many_csv_files = files.size() > 1 && files.size() > system_threads * 2; diff --git a/src/execution/operator/csv_scanner/util/csv_reader_options.cpp b/src/execution/operator/csv_scanner/util/csv_reader_options.cpp index 836ddd731827..70d909941f66 100644 --- a/src/execution/operator/csv_scanner/util/csv_reader_options.cpp +++ b/src/execution/operator/csv_scanner/util/csv_reader_options.cpp @@ -90,11 +90,11 @@ void CSVReaderOptions::SetEscape(const string &input) { } int64_t CSVReaderOptions::GetSkipRows() const { - return this->dialect_options.skip_rows.GetValue(); + return NumericCast(this->dialect_options.skip_rows.GetValue()); } void CSVReaderOptions::SetSkipRows(int64_t skip_rows) { - dialect_options.skip_rows.Set(skip_rows); + dialect_options.skip_rows.Set(NumericCast(skip_rows)); } string CSVReaderOptions::GetDelimiter() const { @@ -162,7 +162,7 @@ void CSVReaderOptions::SetReadOption(const string &loption, const Value &value, if (loption == "auto_detect") { auto_detect = ParseBoolean(value, loption); } else if (loption == "sample_size") { - int64_t sample_size_option = ParseInteger(value, loption); + auto sample_size_option = ParseInteger(value, loption); if (sample_size_option < 1 && sample_size_option != -1) { throw BinderException("Unsupported parameter for SAMPLE_SIZE: cannot be smaller than 1"); } @@ -170,7 +170,7 @@ void CSVReaderOptions::SetReadOption(const string &loption, const Value &value, // If -1, we basically read the whole thing sample_size_chunks = NumericLimits().Maximum(); } else { - sample_size_chunks = sample_size_option / STANDARD_VECTOR_SIZE; + sample_size_chunks = NumericCast(sample_size_option / STANDARD_VECTOR_SIZE); if (sample_size_option % STANDARD_VECTOR_SIZE != 0) { sample_size_chunks++; } @@ -179,7 +179,7 @@ void CSVReaderOptions::SetReadOption(const string &loption, const Value &value, } else if (loption == "skip") { SetSkipRows(ParseInteger(value, loption)); } else if (loption == "max_line_size" || loption == "maximum_line_size") { - maximum_line_size = ParseInteger(value, loption); + maximum_line_size = NumericCast(ParseInteger(value, loption)); } else if (loption == "force_not_null") { force_not_null = ParseColumnList(value, expected_names, loption); } else if (loption == "date_format" || loption == "dateformat") { @@ -191,7 +191,7 @@ void CSVReaderOptions::SetReadOption(const string &loption, const Value &value, } else if (loption == "ignore_errors") { ignore_errors = ParseBoolean(value, loption); } else if (loption == "buffer_size") { - buffer_size = ParseInteger(value, loption); + buffer_size = NumericCast(ParseInteger(value, loption)); if (buffer_size == 0) { throw InvalidInputException("Buffer Size option must be higher than 0"); } @@ -221,11 +221,11 @@ void CSVReaderOptions::SetReadOption(const string &loption, const Value &value, rejects_recovery_columns.push_back(col_name); } } else if (loption == "rejects_limit") { - int64_t limit = ParseInteger(value, loption); + auto limit = ParseInteger(value, loption); if (limit < 0) { throw BinderException("Unsupported parameter for REJECTS_LIMIT: cannot be negative"); } - rejects_limit = limit; + rejects_limit = NumericCast(limit); } else { throw BinderException("Unrecognized option for CSV reader \"%s\"", loption); } @@ -523,7 +523,7 @@ void CSVReaderOptions::ToNamedParameters(named_parameter_map_t &named_params) { if (header.IsSetByUser()) { named_params["header"] = Value(GetHeader()); } - named_params["max_line_size"] = 
Value::BIGINT(maximum_line_size); + named_params["max_line_size"] = Value::BIGINT(NumericCast(maximum_line_size)); if (dialect_options.skip_rows.IsSetByUser()) { named_params["skip"] = Value::BIGINT(GetSkipRows()); } @@ -541,7 +541,7 @@ void CSVReaderOptions::ToNamedParameters(named_parameter_map_t &named_params) { named_params["column_names"] = StringVectorToValue(name_list); } named_params["all_varchar"] = Value::BOOLEAN(all_varchar); - named_params["maximum_line_size"] = Value::BIGINT(maximum_line_size); + named_params["maximum_line_size"] = Value::BIGINT(NumericCast(maximum_line_size)); } } // namespace duckdb diff --git a/src/execution/operator/helper/physical_execute.cpp b/src/execution/operator/helper/physical_execute.cpp index 4a07692181ac..22cab9be8dcc 100644 --- a/src/execution/operator/helper/physical_execute.cpp +++ b/src/execution/operator/helper/physical_execute.cpp @@ -5,7 +5,7 @@ namespace duckdb { PhysicalExecute::PhysicalExecute(PhysicalOperator &plan) - : PhysicalOperator(PhysicalOperatorType::EXECUTE, plan.types, -1), plan(plan) { + : PhysicalOperator(PhysicalOperatorType::EXECUTE, plan.types, idx_t(-1)), plan(plan) { } vector> PhysicalExecute::GetChildren() const { diff --git a/src/execution/operator/helper/physical_reservoir_sample.cpp b/src/execution/operator/helper/physical_reservoir_sample.cpp index 1bc87d25fc3b..b253025f4728 100644 --- a/src/execution/operator/helper/physical_reservoir_sample.cpp +++ b/src/execution/operator/helper/physical_reservoir_sample.cpp @@ -17,7 +17,7 @@ class SampleGlobalSinkState : public GlobalSinkState { } sample = make_uniq(allocator, percentage, options.seed); } else { - auto size = options.sample_size.GetValue(); + auto size = NumericCast(options.sample_size.GetValue()); if (size == 0) { return; } diff --git a/src/execution/operator/join/physical_asof_join.cpp b/src/execution/operator/join/physical_asof_join.cpp index 3996063315c5..8d65a1d6e7eb 100644 --- a/src/execution/operator/join/physical_asof_join.cpp +++ b/src/execution/operator/join/physical_asof_join.cpp @@ -159,7 +159,7 @@ SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, Cl // The data is all in so we can initialise the left partitioning. 
const vector> partitions_stats; gstate.lhs_sink = make_uniq(context, lhs_partitions, lhs_orders, children[0]->types, - partitions_stats, 0); + partitions_stats, 0U); gstate.lhs_sink->SyncPartitioning(gstate.rhs_sink); // Find the first group to sort diff --git a/src/execution/operator/join/physical_hash_join.cpp b/src/execution/operator/join/physical_hash_join.cpp index 1e1e9fd12384..8632c44bda13 100644 --- a/src/execution/operator/join/physical_hash_join.cpp +++ b/src/execution/operator/join/physical_hash_join.cpp @@ -88,7 +88,7 @@ PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr(TaskScheduler::GetScheduler(context).NumberOfThreads())), temporary_memory_update_count(0), temporary_memory_state(TemporaryMemoryManager::Get(context).Register(context)), finalized(false), scanned_data(false) { @@ -206,7 +206,7 @@ unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &c auto count_fun = CountFun::GetFunction(); vector> children; // this is a dummy but we need it to make the hash table understand whats going on - children.push_back(make_uniq_base(count_fun.return_type, 0)); + children.push_back(make_uniq_base(count_fun.return_type, 0U)); aggr = function_binder.BindAggregateFunction(count_fun, std::move(children), nullptr, AggregateType::NON_DISTINCT); correlated_aggregates.push_back(&*aggr); @@ -321,11 +321,11 @@ class HashJoinFinalizeEvent : public BasePipelineEvent { vector> finalize_tasks; auto &ht = *sink.hash_table; const auto chunk_count = ht.GetDataCollection().ChunkCount(); - const idx_t num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); + const auto num_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); if (num_threads == 1 || (ht.Count() < PARALLEL_CONSTRUCT_THRESHOLD && !context.config.verify_parallelism)) { // Single-threaded finalize finalize_tasks.push_back( - make_uniq(shared_from_this(), context, sink, 0, chunk_count, false)); + make_uniq(shared_from_this(), context, sink, 0U, chunk_count, false)); } else { // Parallel finalize auto chunks_per_thread = MaxValue((chunk_count + num_threads - 1) / num_threads, 1); @@ -816,7 +816,7 @@ void HashJoinGlobalSourceState::PrepareBuild(HashJoinGlobalSinkState &sink) { build_chunk_count = data_collection.ChunkCount(); build_chunk_done = 0; - auto num_threads = TaskScheduler::GetScheduler(sink.context).NumberOfThreads(); + auto num_threads = NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()); build_chunks_per_thread = MaxValue((build_chunk_count + num_threads - 1) / num_threads, 1); ht.InitializePointerTable(); @@ -847,7 +847,7 @@ void HashJoinGlobalSourceState::PrepareScanHT(HashJoinGlobalSinkState &sink) { full_outer_chunk_count = data_collection.ChunkCount(); full_outer_chunk_done = 0; - auto num_threads = TaskScheduler::GetScheduler(sink.context).NumberOfThreads(); + auto num_threads = NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()); full_outer_chunks_per_thread = MaxValue((full_outer_chunk_count + num_threads - 1) / num_threads, 1); global_stage = HashJoinSourceStage::SCAN_HT; diff --git a/src/execution/operator/join/physical_iejoin.cpp b/src/execution/operator/join/physical_iejoin.cpp index 6e21b8602919..bc99e9fbf0a8 100644 --- a/src/execution/operator/join/physical_iejoin.cpp +++ b/src/execution/operator/join/physical_iejoin.cpp @@ -325,7 +325,7 @@ idx_t IEJoinUnion::AppendKey(SortedTable &table, ExpressionExecutor &executor, S payload.data[0].Sequence(rid, increment, scan_count); payload.SetCardinality(scan_count); 
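// Editorial note: rid advances by increment * scan_count in the hunk below; increment is a signed
// int64_t while scan_count is an idx_t, and the UnsafeNumericCast spells out the signed-to-unsigned
// conversion that was previously left implicit in the mixed-type multiplication.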
keys.Fuse(payload); - rid += increment * scan_count; + rid += UnsafeNumericCast(increment) * scan_count; // Sort on the sort columns (which will no longer be needed) keys.Split(payload, payload_idx); @@ -385,7 +385,7 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte payload_layout.Initialize(types); // Sort on the first expression - auto ref = make_uniq(order1.expression->return_type, 0); + auto ref = make_uniq(order1.expression->return_type, 0U); vector orders; orders.emplace_back(order1.type, order1.null_order, std::move(ref)); @@ -426,7 +426,7 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte // Sort on the first expression orders.clear(); - ref = make_uniq(order2.expression->return_type, 0); + ref = make_uniq(order2.expression->return_type, 0U); orders.emplace_back(order2.type, order2.null_order, std::move(ref)); ExpressionExecutor executor(context); @@ -434,7 +434,7 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte l2 = make_uniq(context, orders, payload_layout); for (idx_t base = 0, block_idx = 0; block_idx < l1->BlockCount(); ++block_idx) { - base += AppendKey(*l1, executor, *l2, 1, base, block_idx); + base += AppendKey(*l1, executor, *l2, 1, NumericCast(base), block_idx); } Sort(*l2); diff --git a/src/execution/operator/join/physical_piecewise_merge_join.cpp b/src/execution/operator/join/physical_piecewise_merge_join.cpp index 433ff32982b0..287c797186b7 100644 --- a/src/execution/operator/join/physical_piecewise_merge_join.cpp +++ b/src/execution/operator/join/physical_piecewise_merge_join.cpp @@ -105,7 +105,7 @@ unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSinkState(Clien unique_ptr PhysicalPiecewiseMergeJoin::GetLocalSinkState(ExecutionContext &context) const { // We only sink the RHS - return make_uniq(context.client, *this, 1); + return make_uniq(context.client, *this, 1U); } SinkResultType PhysicalPiecewiseMergeJoin::Sink(ExecutionContext &context, DataChunk &chunk, @@ -223,7 +223,7 @@ class PiecewiseMergeJoinState : public CachingOperatorState { void ResolveJoinKeys(DataChunk &input) { // sort by join key lhs_global_state = make_uniq(buffer_manager, lhs_order, lhs_layout); - lhs_local_table = make_uniq(context, op, 0); + lhs_local_table = make_uniq(context, op, 0U); lhs_local_table->Sink(input, *lhs_global_state); // Set external (can be forced with the PRAGMA) diff --git a/src/execution/operator/join/physical_range_join.cpp b/src/execution/operator/join/physical_range_join.cpp index 5d6d44f47bff..8962663e96ad 100644 --- a/src/execution/operator/join/physical_range_join.cpp +++ b/src/execution/operator/join/physical_range_join.cpp @@ -126,7 +126,7 @@ class RangeJoinMergeEvent : public BasePipelineEvent { // Schedule tasks equal to the number of threads, which will each merge multiple partitions auto &ts = TaskScheduler::GetScheduler(context); - idx_t num_threads = ts.NumberOfThreads(); + auto num_threads = NumericCast(ts.NumberOfThreads()); vector> iejoin_tasks; for (idx_t tnum = 0; tnum < num_threads; tnum++) { diff --git a/src/execution/operator/order/physical_order.cpp b/src/execution/operator/order/physical_order.cpp index ac933f3fdc84..f74693bd86cb 100644 --- a/src/execution/operator/order/physical_order.cpp +++ b/src/execution/operator/order/physical_order.cpp @@ -143,7 +143,7 @@ class OrderMergeEvent : public BasePipelineEvent { // Schedule tasks equal to the number of threads, which will each merge multiple partitions auto &ts = 
TaskScheduler::GetScheduler(context); - idx_t num_threads = ts.NumberOfThreads(); + auto num_threads = NumericCast(ts.NumberOfThreads()); vector> merge_tasks; for (idx_t tnum = 0; tnum < num_threads; tnum++) { diff --git a/src/execution/operator/persistent/physical_batch_copy_to_file.cpp b/src/execution/operator/persistent/physical_batch_copy_to_file.cpp index 5447f433cb69..c2c4d06ff9dd 100644 --- a/src/execution/operator/persistent/physical_batch_copy_to_file.cpp +++ b/src/execution/operator/persistent/physical_batch_copy_to_file.cpp @@ -434,7 +434,7 @@ void PhysicalBatchCopyToFile::RepartitionBatches(ClientContext &context, GlobalS // create an empty collection auto new_collection = make_uniq(context, children[0]->types, ColumnDataAllocatorType::HYBRID); - append_batch = make_uniq(0, std::move(new_collection)); + append_batch = make_uniq(0U, std::move(new_collection)); } if (append_batch) { append_batch->collection->InitializeAppend(append_state); @@ -459,7 +459,7 @@ void PhysicalBatchCopyToFile::RepartitionBatches(ClientContext &context, GlobalS auto new_collection = make_uniq(context, children[0]->types, ColumnDataAllocatorType::HYBRID); - append_batch = make_uniq(0, std::move(new_collection)); + append_batch = make_uniq(0U, std::move(new_collection)); append_batch->collection->InitializeAppend(append_state); } } @@ -605,7 +605,7 @@ SourceResultType PhysicalBatchCopyToFile::GetData(ExecutionContext &context, Dat auto &g = sink_state->Cast(); chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(g.rows_copied)); + chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); return SourceResultType::FINISHED; } diff --git a/src/execution/operator/persistent/physical_batch_insert.cpp b/src/execution/operator/persistent/physical_batch_insert.cpp index f2f242a90d78..9f54ec44c6c9 100644 --- a/src/execution/operator/persistent/physical_batch_insert.cpp +++ b/src/execution/operator/persistent/physical_batch_insert.cpp @@ -166,7 +166,8 @@ class BatchInsertLocalState : public LocalSinkState { void CreateNewCollection(DuckTableEntry &table, const vector &insert_types) { auto &table_info = table.GetStorage().info; auto &block_manager = TableIOManager::Get(table.GetStorage()).GetBlockManagerForRowData(); - current_collection = make_uniq(table_info, block_manager, insert_types, MAX_ROW_ID); + current_collection = + make_uniq(table_info, block_manager, insert_types, NumericCast(MAX_ROW_ID)); current_collection->InitializeEmpty(); current_collection->InitializeAppend(current_append_state); } @@ -303,8 +304,8 @@ void BatchInsertGlobalState::ScheduleMergeTasks(idx_t min_batch_index) { auto &scheduled_task = to_be_scheduled_tasks[i - 1]; if (scheduled_task.start_index + 1 < scheduled_task.end_index) { // erase all entries except the first one - collections.erase(collections.begin() + scheduled_task.start_index + 1, - collections.begin() + scheduled_task.end_index); + collections.erase(collections.begin() + NumericCast(scheduled_task.start_index) + 1, + collections.begin() + NumericCast(scheduled_task.end_index)); } } } @@ -613,7 +614,7 @@ SourceResultType PhysicalBatchInsert::GetData(ExecutionContext &context, DataChu auto &insert_gstate = sink_state->Cast(); chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(insert_gstate.insert_count)); + chunk.SetValue(0, 0, Value::BIGINT(NumericCast(insert_gstate.insert_count))); return SourceResultType::FINISHED; } diff --git a/src/execution/operator/persistent/physical_copy_to_file.cpp 
b/src/execution/operator/persistent/physical_copy_to_file.cpp index 5925e9f012eb..e4e58b3b234b 100644 --- a/src/execution/operator/persistent/physical_copy_to_file.cpp +++ b/src/execution/operator/persistent/physical_copy_to_file.cpp @@ -423,7 +423,7 @@ SourceResultType PhysicalCopyToFile::GetData(ExecutionContext &context, DataChun auto &g = sink_state->Cast(); chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(g.rows_copied)); + chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); return SourceResultType::FINISHED; } diff --git a/src/execution/operator/persistent/physical_delete.cpp b/src/execution/operator/persistent/physical_delete.cpp index 42085f0e6128..4fc17049032a 100644 --- a/src/execution/operator/persistent/physical_delete.cpp +++ b/src/execution/operator/persistent/physical_delete.cpp @@ -91,7 +91,7 @@ SourceResultType PhysicalDelete::GetData(ExecutionContext &context, DataChunk &c auto &g = sink_state->Cast(); if (!return_chunk) { chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(g.deleted_count)); + chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.deleted_count))); return SourceResultType::FINISHED; } diff --git a/src/execution/operator/persistent/physical_export.cpp b/src/execution/operator/persistent/physical_export.cpp index 3979a88eeb59..121a7aa2a85f 100644 --- a/src/execution/operator/persistent/physical_export.cpp +++ b/src/execution/operator/persistent/physical_export.cpp @@ -31,7 +31,7 @@ static void WriteStringStreamToFile(FileSystem &fs, stringstream &ss, const stri auto ss_string = ss.str(); auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW | FileLockType::WRITE_LOCK); - fs.Write(*handle, (void *)ss_string.c_str(), ss_string.size()); + fs.Write(*handle, (void *)ss_string.c_str(), NumericCast(ss_string.size())); handle.reset(); } diff --git a/src/execution/operator/persistent/physical_insert.cpp b/src/execution/operator/persistent/physical_insert.cpp index fd17fbb8ea69..768334b50abe 100644 --- a/src/execution/operator/persistent/physical_insert.cpp +++ b/src/execution/operator/persistent/physical_insert.cpp @@ -448,7 +448,7 @@ SinkResultType PhysicalInsert::Sink(ExecutionContext &context, DataChunk &chunk, auto &table_info = storage.info; auto &block_manager = TableIOManager::Get(storage).GetBlockManagerForRowData(); lstate.local_collection = - make_uniq(table_info, block_manager, insert_types, MAX_ROW_ID); + make_uniq(table_info, block_manager, insert_types, NumericCast(MAX_ROW_ID)); lstate.local_collection->InitializeEmpty(); lstate.local_collection->InitializeAppend(lstate.local_append_state); lstate.writer = &gstate.table.GetStorage().CreateOptimisticWriter(context.client); @@ -540,7 +540,7 @@ SourceResultType PhysicalInsert::GetData(ExecutionContext &context, DataChunk &c auto &insert_gstate = sink_state->Cast(); if (!return_chunk) { chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(insert_gstate.insert_count)); + chunk.SetValue(0, 0, Value::BIGINT(NumericCast(insert_gstate.insert_count))); return SourceResultType::FINISHED; } diff --git a/src/execution/operator/persistent/physical_update.cpp b/src/execution/operator/persistent/physical_update.cpp index fb7f1c30054d..c8bab2854c06 100644 --- a/src/execution/operator/persistent/physical_update.cpp +++ b/src/execution/operator/persistent/physical_update.cpp @@ -175,7 +175,7 @@ SourceResultType PhysicalUpdate::GetData(ExecutionContext &context, DataChunk &c auto &g = sink_state->Cast(); if (!return_chunk) { 
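// Editorial note: deleted_count is an unsigned idx_t while Value::BIGINT takes a signed int64_t;
// the checked NumericCast below fails loudly on a count that cannot be represented as a BIGINT,
// rather than silently wrapping to a negative number.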
chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(g.updated_count)); + chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.updated_count))); return SourceResultType::FINISHED; } diff --git a/src/execution/physical_plan/plan_comparison_join.cpp b/src/execution/physical_plan/plan_comparison_join.cpp index 3397a0604fa6..743f8189b61c 100644 --- a/src/execution/physical_plan/plan_comparison_join.cpp +++ b/src/execution/physical_plan/plan_comparison_join.cpp @@ -109,7 +109,7 @@ void CheckForPerfectJoinOpt(LogicalComparisonJoin &op, PerfectHashJoinStats &joi join_state.build_min = NumericStats::Min(stats_build); join_state.build_max = NumericStats::Max(stats_build); join_state.estimated_cardinality = op.estimated_cardinality; - join_state.build_range = build_range; + join_state.build_range = NumericCast(build_range); if (join_state.build_range > MAX_BUILD_SIZE) { return; } diff --git a/src/execution/physical_plan/plan_create_table.cpp b/src/execution/physical_plan/plan_create_table.cpp index be854e0da308..06b728ee47eb 100644 --- a/src/execution/physical_plan/plan_create_table.cpp +++ b/src/execution/physical_plan/plan_create_table.cpp @@ -22,10 +22,10 @@ unique_ptr DuckCatalog::PlanCreateTableAs(ClientContext &conte auto num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); unique_ptr create; if (!parallel_streaming_insert && use_batch_index) { - create = make_uniq(op, op.schema, std::move(op.info), 0); + create = make_uniq(op, op.schema, std::move(op.info), 0U); } else { - create = make_uniq(op, op.schema, std::move(op.info), 0, + create = make_uniq(op, op.schema, std::move(op.info), 0U, parallel_streaming_insert && num_threads > 1); } diff --git a/src/execution/physical_plan/plan_top_n.cpp b/src/execution/physical_plan/plan_top_n.cpp index b3043440dcd4..9748904c57a8 100644 --- a/src/execution/physical_plan/plan_top_n.cpp +++ b/src/execution/physical_plan/plan_top_n.cpp @@ -9,8 +9,8 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalTopN &op) auto plan = CreatePlan(*op.children[0]); - auto top_n = - make_uniq(op.types, std::move(op.orders), (idx_t)op.limit, op.offset, op.estimated_cardinality); + auto top_n = make_uniq(op.types, std::move(op.orders), NumericCast(op.limit), + NumericCast(op.offset), op.estimated_cardinality); top_n->children.push_back(std::move(plan)); return std::move(top_n); } diff --git a/src/main/capi/arrow-c.cpp b/src/main/capi/arrow-c.cpp index 335b936950eb..f74d9e22fdb8 100644 --- a/src/main/capi/arrow-c.cpp +++ b/src/main/capi/arrow-c.cpp @@ -127,7 +127,7 @@ idx_t duckdb_arrow_rows_changed(duckdb_arrow result) { auto rows = collection.GetRows(); D_ASSERT(row_count == 1); D_ASSERT(rows.size() == 1); - rows_changed = rows[0].GetValue(0).GetValue(); + rows_changed = duckdb::NumericCast(rows[0].GetValue(0).GetValue()); } return rows_changed; } @@ -291,8 +291,8 @@ duckdb_state duckdb_arrow_scan(duckdb_connection connection, const char *table_n } typedef void (*release_fn_t)(ArrowSchema *); - std::vector release_fns(schema.n_children); - for (int64_t i = 0; i < schema.n_children; i++) { + std::vector release_fns(duckdb::NumericCast(schema.n_children)); + for (idx_t i = 0; i < duckdb::NumericCast(schema.n_children); i++) { auto child = schema.children[i]; release_fns[i] = child->release; child->release = arrow_array_stream_wrapper::EmptySchemaRelease; @@ -301,7 +301,7 @@ duckdb_state duckdb_arrow_scan(duckdb_connection connection, const char *table_n auto ret = arrow_array_stream_wrapper::Ingest(connection, table_name, stream); // Restore release 
functions. - for (int64_t i = 0; i < schema.n_children; i++) { + for (idx_t i = 0; i < duckdb::NumericCast(schema.n_children); i++) { schema.children[i]->release = release_fns[i]; } diff --git a/src/main/capi/result-c.cpp b/src/main/capi/result-c.cpp index 81c03cc5b260..5b17c0132751 100644 --- a/src/main/capi/result-c.cpp +++ b/src/main/capi/result-c.cpp @@ -117,7 +117,7 @@ struct CDecimalConverter : public CBaseConverter { template static DST Convert(SRC input) { duckdb_hugeint result; - result.lower = input; + result.lower = NumericCast(input); result.upper = 0; return result; } @@ -350,7 +350,7 @@ bool DeprecatedMaterializeResult(duckdb_result *result) { // update total changes auto row_changes = materialized.GetValue(0, 0); if (!row_changes.IsNull() && row_changes.DefaultTryCastAs(LogicalType::BIGINT)) { - result->__deprecated_rows_changed = row_changes.GetValue(); + result->__deprecated_rows_changed = NumericCast(row_changes.GetValue()); } } // now write the data diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp index 909593aa66f8..8c753d2fadc8 100644 --- a/src/main/client_context.cpp +++ b/src/main/client_context.cpp @@ -482,7 +482,8 @@ ClientContext::PendingPreparedStatementInternal(ClientContextLock &lock, shared_ display_create_func = config.display_create_func ? config.display_create_func : ProgressBar::DefaultProgressBarDisplay; } - active_query->progress_bar = make_uniq(executor, config.wait_time, display_create_func); + active_query->progress_bar = + make_uniq(executor, NumericCast(config.wait_time), display_create_func); active_query->progress_bar->Start(); query_progress.Restart(); } diff --git a/src/main/database.cpp b/src/main/database.cpp index 2e3fd0204e3a..50cd320a261a 100644 --- a/src/main/database.cpp +++ b/src/main/database.cpp @@ -381,7 +381,7 @@ const DBConfig &DBConfig::GetConfig(const ClientContext &context) { } idx_t DatabaseInstance::NumberOfThreads() { - return scheduler->NumberOfThreads(); + return NumericCast(scheduler->NumberOfThreads()); } const unordered_set &DatabaseInstance::LoadedExtensions() { diff --git a/src/main/query_profiler.cpp b/src/main/query_profiler.cpp index 176b1e4359c8..79e3020679c7 100644 --- a/src/main/query_profiler.cpp +++ b/src/main/query_profiler.cpp @@ -313,7 +313,7 @@ static string DrawPadded(const string &str, idx_t width) { } else { width -= str.size(); auto half_spaces = width / 2; - auto extra_left_space = width % 2 != 0 ? 1 : 0; + auto extra_left_space = NumericCast(width % 2 != 0 ? 
1 : 0); return string(half_spaces + extra_left_space, ' ') + str + string(half_spaces, ' '); } } diff --git a/src/main/relation.cpp b/src/main/relation.cpp index 970ab4a58add..76095c2b7832 100644 --- a/src/main/relation.cpp +++ b/src/main/relation.cpp @@ -374,7 +374,7 @@ unique_ptr Relation::GetQueryNode() { } void Relation::Head(idx_t limit) { - auto limit_node = Limit(limit); + auto limit_node = Limit(NumericCast(limit)); limit_node->Execute()->Print(); } // LCOV_EXCL_STOP diff --git a/src/main/secret/secret.cpp b/src/main/secret/secret.cpp index 450e5621c323..097a2194cc3f 100644 --- a/src/main/secret/secret.cpp +++ b/src/main/secret/secret.cpp @@ -15,7 +15,7 @@ int64_t BaseSecret::MatchScore(const string &path) const { continue; } if (StringUtil::StartsWith(path, prefix)) { - longest_match = MaxValue(prefix.length(), longest_match); + longest_match = MaxValue(NumericCast(prefix.length()), longest_match); } } return longest_match; diff --git a/src/main/settings/settings.cpp b/src/main/settings/settings.cpp index 3344b77dc017..0726d926355d 100644 --- a/src/main/settings/settings.cpp +++ b/src/main/settings/settings.cpp @@ -749,7 +749,7 @@ void ExternalThreadsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, c if (new_val < 0) { throw SyntaxException("Must have a non-negative number of external threads!"); } - idx_t new_external_threads = new_val; + auto new_external_threads = NumericCast(new_val); if (db) { TaskScheduler::GetScheduler(*db).SetThreads(config.options.maximum_threads, new_external_threads); } @@ -766,7 +766,7 @@ void ExternalThreadsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) Value ExternalThreadsSetting::GetSetting(ClientContext &context) { auto &config = DBConfig::GetConfig(context); - return Value::BIGINT(config.options.external_threads); + return Value::BIGINT(NumericCast(config.options.external_threads)); } //===--------------------------------------------------------------------===// @@ -1000,7 +1000,7 @@ void PartitionedWriteFlushThreshold::SetLocal(ClientContext &context, const Valu } Value PartitionedWriteFlushThreshold::GetSetting(ClientContext &context) { - return Value::BIGINT(ClientConfig::GetConfig(context).partitioned_write_flush_threshold); + return Value::BIGINT(NumericCast(ClientConfig::GetConfig(context).partitioned_write_flush_threshold)); } //===--------------------------------------------------------------------===// @@ -1030,11 +1030,11 @@ void PerfectHashThresholdSetting::SetLocal(ClientContext &context, const Value & if (bits < 0 || bits > 32) { throw ParserException("Perfect HT threshold out of range: should be within range 0 - 32"); } - ClientConfig::GetConfig(context).perfect_ht_threshold = bits; + ClientConfig::GetConfig(context).perfect_ht_threshold = NumericCast(bits); } Value PerfectHashThresholdSetting::GetSetting(ClientContext &context) { - return Value::BIGINT(ClientConfig::GetConfig(context).perfect_ht_threshold); + return Value::BIGINT(NumericCast(ClientConfig::GetConfig(context).perfect_ht_threshold)); } //===--------------------------------------------------------------------===// @@ -1049,7 +1049,7 @@ void PivotFilterThreshold::SetLocal(ClientContext &context, const Value &input) } Value PivotFilterThreshold::GetSetting(ClientContext &context) { - return Value::BIGINT(ClientConfig::GetConfig(context).pivot_filter_threshold); + return Value::BIGINT(NumericCast(ClientConfig::GetConfig(context).pivot_filter_threshold)); } //===--------------------------------------------------------------------===// @@ -1064,7 
+1064,7 @@ void PivotLimitSetting::SetLocal(ClientContext &context, const Value &input) { } Value PivotLimitSetting::GetSetting(ClientContext &context) { - return Value::BIGINT(ClientConfig::GetConfig(context).pivot_limit); + return Value::BIGINT(NumericCast(ClientConfig::GetConfig(context).pivot_limit)); } //===--------------------------------------------------------------------===// @@ -1280,7 +1280,7 @@ void ThreadsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Val if (new_val < 1) { throw SyntaxException("Must have at least 1 thread!"); } - idx_t new_maximum_threads = new_val; + auto new_maximum_threads = NumericCast(new_val); if (db) { TaskScheduler::GetScheduler(*db).SetThreads(new_maximum_threads, config.options.external_threads); } @@ -1297,7 +1297,7 @@ void ThreadsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { Value ThreadsSetting::GetSetting(ClientContext &context) { auto &config = DBConfig::GetConfig(context); - return Value::BIGINT(config.options.maximum_threads); + return Value::BIGINT(NumericCast(config.options.maximum_threads)); } //===--------------------------------------------------------------------===// diff --git a/src/parallel/pipeline.cpp b/src/parallel/pipeline.cpp index e0ace608b19f..643cb7c8ee54 100644 --- a/src/parallel/pipeline.cpp +++ b/src/parallel/pipeline.cpp @@ -110,9 +110,9 @@ bool Pipeline::ScheduleParallel(shared_ptr &event) { "Attempting to schedule a pipeline where the sink requires batch index but source does not support it"); } } - idx_t max_threads = source_state->MaxThreads(); + auto max_threads = source_state->MaxThreads(); auto &scheduler = TaskScheduler::GetScheduler(executor.context); - idx_t active_threads = scheduler.NumberOfThreads(); + auto active_threads = NumericCast(scheduler.NumberOfThreads()); if (max_threads > active_threads) { max_threads = active_threads; } diff --git a/src/parallel/task_scheduler.cpp b/src/parallel/task_scheduler.cpp index 5036483388bc..b0f8997ee82a 100644 --- a/src/parallel/task_scheduler.cpp +++ b/src/parallel/task_scheduler.cpp @@ -260,7 +260,7 @@ void TaskScheduler::SetAllocatorFlushTreshold(idx_t threshold) { void TaskScheduler::Signal(idx_t n) { #ifndef DUCKDB_NO_THREADS - queue->semaphore.signal(n); + queue->semaphore.signal(NumericCast(n)); #endif } @@ -279,7 +279,7 @@ void TaskScheduler::RelaunchThreads() { void TaskScheduler::RelaunchThreadsInternal(int32_t n) { #ifndef DUCKDB_NO_THREADS auto &config = DBConfig::GetConfig(db); - idx_t new_thread_count = n; + auto new_thread_count = NumericCast(n); if (threads.size() == new_thread_count) { current_thread_count = NumericCast(threads.size() + config.options.external_threads); return; From a0f7ff1dabff4a626ccf1ca17695cc31860862cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Thu, 4 Apr 2024 16:27:38 +0200 Subject: [PATCH 048/201] orrrr --- src/include/duckdb/common/bitpacking.hpp | 2 +- .../duckdb/storage/buffer/block_handle.hpp | 1 + src/storage/arena_allocator.cpp | 8 ++++--- .../buffer/buffer_pool_reservation.cpp | 4 ++-- src/storage/compression/bitpacking.cpp | 12 +++++----- .../compression/dictionary_compression.cpp | 12 ++++++---- .../compression/fixed_size_uncompressed.cpp | 2 +- src/storage/compression/rle.cpp | 2 +- .../compression/string_uncompressed.cpp | 9 +++---- .../compression/validity_uncompressed.cpp | 2 +- src/storage/data_table.cpp | 4 ++-- src/storage/local_storage.cpp | 11 +++++---- src/storage/single_file_block_manager.cpp | 24 +++++++++---------- 
src/storage/standard_buffer_manager.cpp | 16 ++++++------- src/storage/storage_manager.cpp | 6 ++--- src/storage/temporary_file_manager.cpp | 6 ++--- src/storage/temporary_memory_manager.cpp | 2 +- 17 files changed, 65 insertions(+), 58 deletions(-) diff --git a/src/include/duckdb/common/bitpacking.hpp b/src/include/duckdb/common/bitpacking.hpp index 7bddedfc42d6..3b499a4b4368 100644 --- a/src/include/duckdb/common/bitpacking.hpp +++ b/src/include/duckdb/common/bitpacking.hpp @@ -112,7 +112,7 @@ class BitpackingPrimitives { return num_to_round; } - return num_to_round + BITPACKING_ALGORITHM_GROUP_SIZE - remainder; + return num_to_round + BITPACKING_ALGORITHM_GROUP_SIZE - NumericCast(remainder); } private: diff --git a/src/include/duckdb/storage/buffer/block_handle.hpp b/src/include/duckdb/storage/buffer/block_handle.hpp index 3de827ea8fcf..34ac9d1c4b7a 100644 --- a/src/include/duckdb/storage/buffer/block_handle.hpp +++ b/src/include/duckdb/storage/buffer/block_handle.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/atomic.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/mutex.hpp" +#include "duckdb/common/numeric_utils.hpp" #include "duckdb/storage/storage_info.hpp" #include "duckdb/common/file_buffer.hpp" #include "duckdb/common/enums/memory_tag.hpp" diff --git a/src/storage/arena_allocator.cpp b/src/storage/arena_allocator.cpp index c659b2c13d11..d7198c06281e 100644 --- a/src/storage/arena_allocator.cpp +++ b/src/storage/arena_allocator.cpp @@ -1,6 +1,7 @@ #include "duckdb/storage/arena_allocator.hpp" #include "duckdb/common/assert.hpp" +#include "duckdb/common/numeric_utils.hpp" namespace duckdb { @@ -88,10 +89,11 @@ data_ptr_t ArenaAllocator::Reallocate(data_ptr_t pointer, idx_t old_size, idx_t } auto head_ptr = head->data.get() + head->current_position; - int64_t diff = size - old_size; - if (pointer == head_ptr && (size < old_size || head->current_position + diff <= head->maximum_size)) { + int64_t diff = NumericCast(size) - NumericCast(old_size); + if (pointer == head_ptr && (size < old_size || NumericCast(head->current_position) + diff <= + NumericCast(head->maximum_size))) { // passed pointer is the head pointer, and the diff fits on the current chunk - head->current_position += diff; + head->current_position += NumericCast(diff); return pointer; } else { // allocate new memory diff --git a/src/storage/buffer/buffer_pool_reservation.cpp b/src/storage/buffer/buffer_pool_reservation.cpp index f22a96ffa58f..60a7bc882e61 100644 --- a/src/storage/buffer/buffer_pool_reservation.cpp +++ b/src/storage/buffer/buffer_pool_reservation.cpp @@ -23,8 +23,8 @@ BufferPoolReservation::~BufferPoolReservation() { } void BufferPoolReservation::Resize(idx_t new_size) { - int64_t delta = (int64_t)new_size - size; - pool.IncreaseUsedMemory(tag, delta); + int64_t delta = NumericCast(new_size) - NumericCast(size); + pool.IncreaseUsedMemory(tag, NumericCast(delta)); size = new_size; } diff --git a/src/storage/compression/bitpacking.cpp b/src/storage/compression/bitpacking.cpp index fa7f4a7c79f3..efdd03087bc8 100644 --- a/src/storage/compression/bitpacking.cpp +++ b/src/storage/compression/bitpacking.cpp @@ -64,7 +64,7 @@ typedef uint32_t bitpacking_metadata_encoded_t; static bitpacking_metadata_encoded_t EncodeMeta(bitpacking_metadata_t metadata) { D_ASSERT(metadata.offset <= 0x00FFFFFF); // max uint24_t bitpacking_metadata_encoded_t encoded_value = metadata.offset; - encoded_value |= (uint8_t)metadata.mode << 24; + encoded_value |= UnsafeNumericCast((uint8_t)metadata.mode << 24); return 
encoded_value; } static bitpacking_metadata_t DecodeMeta(bitpacking_metadata_encoded_t *metadata_encoded) { @@ -471,8 +471,8 @@ struct BitpackingCompressState : public CompressionState { }; bool CanStore(idx_t data_bytes, idx_t meta_bytes) { - auto required_data_bytes = AlignValue((data_ptr + data_bytes) - data_ptr); - auto required_meta_bytes = Storage::BLOCK_SIZE - (metadata_ptr - data_ptr) + meta_bytes; + auto required_data_bytes = AlignValue(UnsafeNumericCast((data_ptr + data_bytes) - data_ptr)); + auto required_meta_bytes = Storage::BLOCK_SIZE - UnsafeNumericCast(metadata_ptr - data_ptr) + meta_bytes; return required_data_bytes + required_meta_bytes <= Storage::BLOCK_SIZE - BitpackingPrimitives::BITPACKING_HEADER_SIZE; @@ -514,8 +514,8 @@ struct BitpackingCompressState : public CompressionState { auto base_ptr = handle.Ptr(); // Compact the segment by moving the metadata next to the data. - idx_t metadata_offset = AlignValue(data_ptr - base_ptr); - idx_t metadata_size = base_ptr + Storage::BLOCK_SIZE - metadata_ptr; + auto metadata_offset = NumericCast(AlignValue(data_ptr - base_ptr)); + auto metadata_size = NumericCast(base_ptr + Storage::BLOCK_SIZE - metadata_ptr); idx_t total_segment_size = metadata_offset + metadata_size; // Asserting things are still sane here @@ -866,7 +866,7 @@ template void BitpackingFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { BitpackingScanState scan_state(segment); - scan_state.Skip(segment, row_id); + scan_state.Skip(segment, NumericCast(row_id)); D_ASSERT(scan_state.current_group_offset < BITPACKING_METADATA_GROUP_SIZE); diff --git a/src/storage/compression/dictionary_compression.cpp b/src/storage/compression/dictionary_compression.cpp index 4f2bce492d11..79ccc64471b9 100644 --- a/src/storage/compression/dictionary_compression.cpp +++ b/src/storage/compression/dictionary_compression.cpp @@ -456,7 +456,8 @@ unique_ptr DictionaryCompressionStorage::StringInitScan(Column for (uint32_t i = 0; i < index_buffer_count; i++) { // NOTE: the passing of dict_child_vector, will not be used, its for big strings uint16_t str_len = GetStringLength(index_buffer_ptr, i); - dict_child_data[i] = FetchStringFromDict(segment, dict, baseptr, index_buffer_ptr[i], str_len); + dict_child_data[i] = + FetchStringFromDict(segment, dict, baseptr, UnsafeNumericCast(index_buffer_ptr[i]), str_len); } return std::move(state); @@ -509,7 +510,8 @@ void DictionaryCompressionStorage::StringScanPartial(ColumnSegment &segment, Col auto string_number = scan_state.sel_vec->get_index(i + start_offset); auto dict_offset = index_buffer_ptr[string_number]; auto str_len = GetStringLength(index_buffer_ptr, UnsafeNumericCast(string_number)); - result_data[result_offset + i] = FetchStringFromDict(segment, dict, baseptr, dict_offset, str_len); + result_data[result_offset + i] = + FetchStringFromDict(segment, dict, baseptr, UnsafeNumericCast(dict_offset), str_len); } } else { @@ -559,11 +561,11 @@ void DictionaryCompressionStorage::StringFetchRow(ColumnSegment &segment, Column auto result_data = FlatVector::GetData(result); // Handling non-bitpacking-group-aligned start values; - idx_t start_offset = row_id % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; + idx_t start_offset = NumericCast(row_id) % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; // Decompress part of selection buffer we need for this value. 
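// (Annotation; not from the original patch.) A worked example of the
// group-aligned arithmetic in this hunk, assuming
// BITPACKING_ALGORITHM_GROUP_SIZE is 32 (its value in this codebase) and
// 'width' is the bit width of each packed value. For row_id = 70, width = 9:
//   start_offset = 70 % 32 = 6         -> position of the row inside its group
//   group start  = 70 - 6  = 64        -> first row of the aligned group
//   source byte  = (64 * 9) / 8 = 72   -> offset of the packed group in base_data
// The whole group of 32 selection values is unpacked into the buffer below and
// entry 6 is the one for the requested row; the NumericCast added here only
// narrows the signed row_t to idx_t, the arithmetic itself is unchanged.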
sel_t decompression_buffer[BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE]; - data_ptr_t src = data_ptr_cast(&base_data[((row_id - start_offset) * width) / 8]); + data_ptr_t src = data_ptr_cast(&base_data[((NumericCast(row_id) - start_offset) * width) / 8]); BitpackingPrimitives::UnPackBuffer(data_ptr_cast(decompression_buffer), src, BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE, width); @@ -571,7 +573,7 @@ void DictionaryCompressionStorage::StringFetchRow(ColumnSegment &segment, Column auto dict_offset = index_buffer_ptr[selection_value]; uint16_t str_len = GetStringLength(index_buffer_ptr, selection_value); - result_data[result_idx] = FetchStringFromDict(segment, dict, baseptr, dict_offset, str_len); + result_data[result_idx] = FetchStringFromDict(segment, dict, baseptr, NumericCast(dict_offset), str_len); } //===--------------------------------------------------------------------===// diff --git a/src/storage/compression/fixed_size_uncompressed.cpp b/src/storage/compression/fixed_size_uncompressed.cpp index c1b210d787e2..bbccd3f0e166 100644 --- a/src/storage/compression/fixed_size_uncompressed.cpp +++ b/src/storage/compression/fixed_size_uncompressed.cpp @@ -172,7 +172,7 @@ void FixedSizeFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t ro auto handle = buffer_manager.Pin(segment.block); // first fetch the data from the base table - auto data_ptr = handle.Ptr() + segment.GetBlockOffset() + row_id * sizeof(T); + auto data_ptr = handle.Ptr() + segment.GetBlockOffset() + NumericCast(row_id) * sizeof(T); memcpy(FlatVector::GetData(result) + result_idx * sizeof(T), data_ptr, sizeof(T)); } diff --git a/src/storage/compression/rle.cpp b/src/storage/compression/rle.cpp index edc111dd337a..e518b14602f8 100644 --- a/src/storage/compression/rle.cpp +++ b/src/storage/compression/rle.cpp @@ -378,7 +378,7 @@ void RLEScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, V template void RLEFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { RLEScanState scan_state(segment); - scan_state.Skip(segment, row_id); + scan_state.Skip(segment, NumericCast(row_id)); auto data = scan_state.handle.Ptr() + segment.GetBlockOffset(); auto data_pointer = reinterpret_cast(data + RLEConstants::RLE_HEADER_SIZE); diff --git a/src/storage/compression/string_uncompressed.cpp b/src/storage/compression/string_uncompressed.cpp index bfdebbe35a36..dea2ef6fc9f7 100644 --- a/src/storage/compression/string_uncompressed.cpp +++ b/src/storage/compression/string_uncompressed.cpp @@ -87,7 +87,7 @@ void UncompressedStringStorage::StringScanPartial(ColumnSegment &segment, Column for (idx_t i = 0; i < scan_count; i++) { // std::abs used since offsets can be negative to indicate big strings - uint32_t string_length = std::abs(base_data[start + i]) - std::abs(previous_offset); + auto string_length = UnsafeNumericCast(std::abs(base_data[start + i]) - std::abs(previous_offset)); result_data[result_offset + i] = FetchStringFromDict(segment, dict, result, baseptr, base_data[start + i], string_length); previous_offset = base_data[start + i]; @@ -133,9 +133,9 @@ void UncompressedStringStorage::StringFetchRow(ColumnSegment &segment, ColumnFet uint32_t string_length; if ((idx_t)row_id == 0) { // edge case where this is the first string in the dict - string_length = std::abs(dict_offset); + string_length = NumericCast(std::abs(dict_offset)); } else { - string_length = std::abs(dict_offset) - std::abs(base_data[row_id - 1]); + string_length = 
NumericCast(std::abs(dict_offset) - std::abs(base_data[row_id - 1])); } result_data[result_idx] = FetchStringFromDict(segment, dict, result, baseptr, dict_offset, string_length); } @@ -347,7 +347,8 @@ string_t UncompressedStringStorage::ReadOverflowString(ColumnSegment &segment, V // now append the string to the single buffer while (remaining > 0) { - idx_t to_write = MinValue(remaining, Storage::BLOCK_SIZE - sizeof(block_id_t) - offset); + idx_t to_write = + MinValue(remaining, Storage::BLOCK_SIZE - sizeof(block_id_t) - UnsafeNumericCast(offset)); memcpy(target_ptr, handle.Ptr() + offset, to_write); remaining -= to_write; offset += to_write; diff --git a/src/storage/compression/validity_uncompressed.cpp b/src/storage/compression/validity_uncompressed.cpp index 3df0cae80f91..f2f01533b89c 100644 --- a/src/storage/compression/validity_uncompressed.cpp +++ b/src/storage/compression/validity_uncompressed.cpp @@ -384,7 +384,7 @@ void ValidityFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row auto dataptr = handle.Ptr() + segment.GetBlockOffset(); ValidityMask mask(reinterpret_cast(dataptr)); auto &result_mask = FlatVector::Validity(result); - if (!mask.RowIsValidUnsafe(row_id)) { + if (!mask.RowIsValidUnsafe(NumericCast(row_id))) { result_mask.SetInvalid(result_idx); } } diff --git a/src/storage/data_table.cpp b/src/storage/data_table.cpp index 08c389557ee0..a48e6163e6c3 100644 --- a/src/storage/data_table.cpp +++ b/src/storage/data_table.cpp @@ -742,7 +742,7 @@ void DataTable::AppendLock(TableAppendState &state) { if (!is_root) { throw TransactionException("Transaction conflict: adding entries to a table that has been altered!"); } - state.row_start = row_groups->GetTotalRows(); + state.row_start = NumericCast(row_groups->GetTotalRows()); state.current_row = state.row_start; } @@ -855,7 +855,7 @@ void DataTable::RevertAppend(idx_t start_row, idx_t count) { idx_t scan_count = MinValue(count, row_groups->GetTotalRows() - start_row); ScanTableSegment(start_row, scan_count, [&](DataChunk &chunk) { for (idx_t i = 0; i < chunk.size(); i++) { - row_data[i] = current_row_base + i; + row_data[i] = NumericCast(current_row_base + i); } info->indexes.Scan([&](Index &index) { index.Delete(chunk, row_identifiers); diff --git a/src/storage/local_storage.cpp b/src/storage/local_storage.cpp index 513a18e2faf1..b695f27f880e 100644 --- a/src/storage/local_storage.cpp +++ b/src/storage/local_storage.cpp @@ -191,7 +191,7 @@ void LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, TableAppen return true; }); if (append_to_table) { - table.RevertAppendInternal(append_state.row_start); + table.RevertAppendInternal(NumericCast(append_state.row_start)); } // we need to vacuum the indexes to remove any buffers that are now empty @@ -361,8 +361,9 @@ void LocalStorage::InitializeAppend(LocalAppendState &state, DataTable &table) { void LocalStorage::Append(LocalAppendState &state, DataChunk &chunk) { // append to unique indices (if any) auto storage = state.storage; - idx_t base_id = MAX_ROW_ID + storage->row_groups->GetTotalRows() + state.append_state.total_append_count; - auto error = DataTable::AppendToIndexes(storage->indexes, chunk, base_id); + idx_t base_id = + NumericCast(MAX_ROW_ID) + storage->row_groups->GetTotalRows() + state.append_state.total_append_count; + auto error = DataTable::AppendToIndexes(storage->indexes, chunk, NumericCast(base_id)); if (error.HasError()) { error.Throw(); } @@ -383,7 +384,7 @@ void LocalStorage::LocalMerge(DataTable &table, RowGroupCollection 
&collection) auto &storage = table_manager.GetOrCreateStorage(table); if (!storage.indexes.Empty()) { // append data to indexes if required - row_t base_id = MAX_ROW_ID + storage.row_groups->GetTotalRows(); + row_t base_id = MAX_ROW_ID + NumericCast(storage.row_groups->GetTotalRows()); auto error = storage.AppendToIndexes(transaction, collection, storage.indexes, table.GetTypes(), base_id); if (error.HasError()) { error.Throw(); @@ -447,7 +448,7 @@ void LocalStorage::Flush(DataTable &table, LocalTableStorage &storage) { TableAppendState append_state; table.AppendLock(append_state); - transaction.PushAppend(table, append_state.row_start, append_count); + transaction.PushAppend(table, NumericCast(append_state.row_start), append_count); if ((append_state.row_start == 0 || storage.row_groups->GetTotalRows() >= MERGE_THRESHOLD) && storage.deleted_rows == 0) { // table is currently empty OR we are bulk appending: move over the storage directly diff --git a/src/storage/single_file_block_manager.cpp b/src/storage/single_file_block_manager.cpp index b46f4f9273c9..98775c1c745e 100644 --- a/src/storage/single_file_block_manager.cpp +++ b/src/storage/single_file_block_manager.cpp @@ -195,8 +195,8 @@ void SingleFileBlockManager::CreateNewDatabase() { DatabaseHeader h1; // header 1 h1.iteration = 0; - h1.meta_block = INVALID_BLOCK; - h1.free_list = INVALID_BLOCK; + h1.meta_block = idx_t(INVALID_BLOCK); + h1.free_list = idx_t(INVALID_BLOCK); h1.block_count = 0; h1.block_size = Storage::BLOCK_ALLOC_SIZE; h1.vector_size = STANDARD_VECTOR_SIZE; @@ -205,8 +205,8 @@ void SingleFileBlockManager::CreateNewDatabase() { // header 2 DatabaseHeader h2; h2.iteration = 0; - h2.meta_block = INVALID_BLOCK; - h2.free_list = INVALID_BLOCK; + h2.meta_block = idx_t(INVALID_BLOCK); + h2.free_list = idx_t(INVALID_BLOCK); h2.block_count = 0; h2.block_size = Storage::BLOCK_ALLOC_SIZE; h2.vector_size = STANDARD_VECTOR_SIZE; @@ -283,7 +283,7 @@ void SingleFileBlockManager::Initialize(DatabaseHeader &header) { free_list_id = header.free_list; meta_block = header.meta_block; iteration_count = header.iteration; - max_block = header.block_count; + max_block = NumericCast(header.block_count); } void SingleFileBlockManager::LoadFreeList() { @@ -386,7 +386,7 @@ idx_t SingleFileBlockManager::GetMetaBlock() { idx_t SingleFileBlockManager::TotalBlocks() { lock_guard lock(block_lock); - return max_block; + return NumericCast(max_block); } idx_t SingleFileBlockManager::FreeBlocks() { @@ -413,12 +413,12 @@ unique_ptr SingleFileBlockManager::CreateBlock(block_id_t block_id, FileB void SingleFileBlockManager::Read(Block &block) { D_ASSERT(block.id >= 0); D_ASSERT(std::find(free_list.begin(), free_list.end(), block.id) == free_list.end()); - ReadAndChecksum(block, BLOCK_START + block.id * Storage::BLOCK_ALLOC_SIZE); + ReadAndChecksum(block, BLOCK_START + NumericCast(block.id) * Storage::BLOCK_ALLOC_SIZE); } void SingleFileBlockManager::Write(FileBuffer &buffer, block_id_t block_id) { D_ASSERT(block_id >= 0); - ChecksumAndWrite(buffer, BLOCK_START + block_id * Storage::BLOCK_ALLOC_SIZE); + ChecksumAndWrite(buffer, BLOCK_START + NumericCast(block_id) * Storage::BLOCK_ALLOC_SIZE); } void SingleFileBlockManager::Truncate() { @@ -440,7 +440,7 @@ void SingleFileBlockManager::Truncate() { // truncate the file free_list.erase(free_list.lower_bound(max_block), free_list.end()); newly_freed_list.erase(newly_freed_list.lower_bound(max_block), newly_freed_list.end()); - handle->Truncate(BLOCK_START + max_block * Storage::BLOCK_ALLOC_SIZE); + 
handle->Truncate(NumericCast(BLOCK_START + NumericCast(max_block) * Storage::BLOCK_ALLOC_SIZE)); } vector SingleFileBlockManager::GetFreeListBlocks() { @@ -529,7 +529,7 @@ void SingleFileBlockManager::WriteHeader(DatabaseHeader header) { header.free_list = DConstants::INVALID_INDEX; } metadata_manager.Flush(); - header.block_count = max_block; + header.block_count = NumericCast(max_block); auto &config = DBConfig::Get(db); if (config.options.checkpoint_abort == CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE) { @@ -569,8 +569,8 @@ void SingleFileBlockManager::TrimFreeBlocks() { // We are now one too far. --itr; // Trim the range. - handle->Trim(BLOCK_START + (first * Storage::BLOCK_ALLOC_SIZE), - (last + 1 - first) * Storage::BLOCK_ALLOC_SIZE); + handle->Trim(BLOCK_START + (NumericCast(first) * Storage::BLOCK_ALLOC_SIZE), + NumericCast(last + 1 - first) * Storage::BLOCK_ALLOC_SIZE); } } newly_freed_list.clear(); diff --git a/src/storage/standard_buffer_manager.cpp b/src/storage/standard_buffer_manager.cpp index ab5ad560ba12..0b62bd8db48f 100644 --- a/src/storage/standard_buffer_manager.cpp +++ b/src/storage/standard_buffer_manager.cpp @@ -126,7 +126,7 @@ void StandardBufferManager::ReAllocate(shared_ptr &handle, idx_t bl D_ASSERT(handle->memory_usage == handle->memory_charge.size); auto req = handle->buffer->CalculateMemory(block_size); - int64_t memory_delta = NumericCast(req.alloc_size) - handle->memory_usage; + int64_t memory_delta = NumericCast(req.alloc_size) - NumericCast(handle->memory_usage); if (memory_delta == 0) { return; @@ -134,10 +134,10 @@ void StandardBufferManager::ReAllocate(shared_ptr &handle, idx_t bl // evict blocks until we have space to resize this block // unlock the handle lock during the call to EvictBlocksOrThrow lock.unlock(); - auto reservation = - EvictBlocksOrThrow(handle->tag, memory_delta, nullptr, "failed to resize block from %s to %s%s", - StringUtil::BytesToHumanReadableString(handle->memory_usage), - StringUtil::BytesToHumanReadableString(req.alloc_size)); + auto reservation = EvictBlocksOrThrow(handle->tag, NumericCast(memory_delta), nullptr, + "failed to resize block from %s to %s%s", + StringUtil::BytesToHumanReadableString(handle->memory_usage), + StringUtil::BytesToHumanReadableString(req.alloc_size)); lock.lock(); // EvictBlocks decrements 'current_memory' for us. @@ -183,10 +183,10 @@ BufferHandle StandardBufferManager::Pin(shared_ptr &handle) { auto buf = handle->Load(handle, std::move(reusable_buffer)); handle->memory_charge = std::move(reservation); // In the case of a variable sized block, the buffer may be smaller than a full block. 
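// (Annotation; not from the original patch.) The cast-before-subtract pattern
// in the next hunk recurs throughout this commit: idx_t is unsigned, so
// subtracting two sizes directly would wrap around instead of going negative.
// A minimal illustration with assumed values:
//   idx_t alloc = 4096, usage = 8192;
//   int64_t delta = NumericCast<int64_t>(alloc) - NumericCast<int64_t>(usage); // -4096
//   idx_t wrapped = alloc - usage; // 2^64 - 4096: the wrap the casts prevent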
- int64_t delta = handle->buffer->AllocSize() - handle->memory_usage; + int64_t delta = NumericCast(handle->buffer->AllocSize()) - NumericCast(handle->memory_usage); if (delta) { D_ASSERT(delta < 0); - handle->memory_usage += delta; + handle->memory_usage += NumericCast(delta); handle->memory_charge.Resize(handle->memory_usage); } D_ASSERT(handle->memory_usage == handle->buffer->AllocSize()); @@ -367,7 +367,7 @@ vector StandardBufferManager::GetTemporaryFiles() { TemporaryFileInformation info; info.path = name; auto handle = fs.OpenFile(name, FileFlags::FILE_FLAGS_READ); - info.size = fs.GetFileSize(*handle); + info.size = NumericCast(fs.GetFileSize(*handle)); handle.reset(); result.push_back(info); }); diff --git a/src/storage/storage_manager.cpp b/src/storage/storage_manager.cpp index 1a798e3c9a06..dcc5b150d611 100644 --- a/src/storage/storage_manager.cpp +++ b/src/storage/storage_manager.cpp @@ -200,7 +200,7 @@ class SingleFileStorageCommitState : public StorageCommitState { wal.skip_writing = false; if (wal.GetTotalWritten() > initial_written) { // remove any entries written into the WAL by truncating it - wal.Truncate(initial_wal_size); + wal.Truncate(NumericCast(initial_wal_size)); } } } @@ -284,7 +284,7 @@ DatabaseSize SingleFileStorageManager::GetDatabaseSize() { ds.used_blocks = ds.total_blocks - ds.free_blocks; ds.bytes = (ds.total_blocks * ds.block_size); if (auto wal = GetWriteAheadLog()) { - ds.wal_size = wal->GetWALSize(); + ds.wal_size = NumericCast(wal->GetWALSize()); } } return ds; @@ -302,7 +302,7 @@ bool SingleFileStorageManager::AutomaticCheckpoint(idx_t estimated_wal_bytes) { } auto &config = DBConfig::Get(db); - auto initial_size = log->GetWALSize(); + auto initial_size = NumericCast(log->GetWALSize()); idx_t expected_wal_size = initial_size + estimated_wal_bytes; return expected_wal_size > config.options.checkpoint_wal_size; } diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp index c374829037ec..632f7075f755 100644 --- a/src/storage/temporary_file_manager.cpp +++ b/src/storage/temporary_file_manager.cpp @@ -106,7 +106,7 @@ void TemporaryFileHandle::EraseBlockIndex(block_id_t block_index) { // remove the block (and potentially truncate the temp file) TemporaryFileLock lock(file_lock); D_ASSERT(handle); - RemoveTempBlockIndex(lock, block_index); + RemoveTempBlockIndex(lock, NumericCast(block_index)); } bool TemporaryFileHandle::DeleteIfEmpty() { @@ -147,7 +147,7 @@ void TemporaryFileHandle::RemoveTempBlockIndex(TemporaryFileLock &, idx_t index) #ifndef WIN32 // this ended up causing issues when sorting auto max_index = index_manager.GetMaxIndex(); auto &fs = FileSystem::GetFileSystem(db); - fs.Truncate(*handle, GetPositionInFile(max_index + 1)); + fs.Truncate(*handle, NumericCast(GetPositionInFile(max_index + 1))); #endif } } @@ -310,7 +310,7 @@ void TemporaryFileManager::EraseUsedBlock(TemporaryManagerLock &lock, block_id_t throw InternalException("EraseUsedBlock - Block %llu not found in used blocks", id); } used_blocks.erase(entry); - handle->EraseBlockIndex(index.block_index); + handle->EraseBlockIndex(NumericCast(index.block_index)); if (handle->DeleteIfEmpty()) { EraseFileHandle(lock, index.file_index); } diff --git a/src/storage/temporary_memory_manager.cpp b/src/storage/temporary_memory_manager.cpp index a6d07b9bf635..ba046d30fb17 100644 --- a/src/storage/temporary_memory_manager.cpp +++ b/src/storage/temporary_memory_manager.cpp @@ -47,7 +47,7 @@ void TemporaryMemoryManager::UpdateConfiguration(ClientContext &context) 
{ memory_limit = MAXIMUM_MEMORY_LIMIT_RATIO * double(buffer_manager.GetMaxMemory()); has_temporary_directory = buffer_manager.HasTemporaryDirectory(); - num_threads = task_scheduler.NumberOfThreads(); + num_threads = NumericCast(task_scheduler.NumberOfThreads()); query_max_memory = buffer_manager.GetQueryMaxMemory(); } From 8d8044c5319406acb9d8a11f5251a41e6ad552e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Thu, 4 Apr 2024 16:58:33 +0200 Subject: [PATCH 049/201] moo --- src/include/duckdb/common/bitpacking.hpp | 4 ++-- third_party/concurrentqueue/concurrentqueue.h | 10 +++++----- third_party/utf8proc/include/utf8proc.hpp | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/include/duckdb/common/bitpacking.hpp b/src/include/duckdb/common/bitpacking.hpp index 3b499a4b4368..e0a1fe23edb6 100644 --- a/src/include/duckdb/common/bitpacking.hpp +++ b/src/include/duckdb/common/bitpacking.hpp @@ -174,7 +174,7 @@ class BitpackingPrimitives { if (bitwidth < sizeof(T) * 8 && bitwidth != 0) { if (is_signed) { D_ASSERT(max_value <= (T(1) << (bitwidth - 1)) - 1); - D_ASSERT(min_value >= (T(-1) * ((T(1) << (bitwidth - 1)) - 1) - 1)); + // D_ASSERT(min_value >= (T(-1) * ((T(1) << (bitwidth - 1)) - 1) - 1)); } else { D_ASSERT(max_value <= (T(1) << (bitwidth)) - 1); } @@ -192,7 +192,7 @@ class BitpackingPrimitives { T const mask = UnsafeNumericCast(T_U(1) << (width - 1)); for (idx_t i = 0; i < BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; ++i) { T value = Load(dst + i * sizeof(T)); - value = UnsafeNumericCast(value & ((T_U(1) << width) - T_U(1))); + value = UnsafeNumericCast(T_U(value) & ((T_U(1) << width) - T_U(1))); T result = (value ^ mask) - mask; Store(result, dst + i * sizeof(T)); } diff --git a/third_party/concurrentqueue/concurrentqueue.h b/third_party/concurrentqueue/concurrentqueue.h index f3e2b1005eec..0f5ad0a4d625 100644 --- a/third_party/concurrentqueue/concurrentqueue.h +++ b/third_party/concurrentqueue/concurrentqueue.h @@ -1942,7 +1942,7 @@ class ConcurrentQueue // block size (in order to get a correct signed block count offset in all cases): auto headBase = localBlockIndex->entries[localBlockIndexHead].base; auto blockBaseIndex = index & ~static_cast(BLOCK_SIZE - 1); - auto offset = static_cast(static_cast::type>(blockBaseIndex - headBase) / BLOCK_SIZE); + auto offset = static_cast(static_cast::type>(blockBaseIndex - headBase) / static_cast::type>(BLOCK_SIZE)); auto block = localBlockIndex->entries[(localBlockIndexHead + offset) & (localBlockIndex->size - 1)].block; // Dequeue @@ -2202,7 +2202,7 @@ class ConcurrentQueue auto headBase = localBlockIndex->entries[localBlockIndexHead].base; auto firstBlockBaseIndex = firstIndex & ~static_cast(BLOCK_SIZE - 1); - auto offset = static_cast(static_cast::type>(firstBlockBaseIndex - headBase) / BLOCK_SIZE); + auto offset = static_cast(static_cast::type>(firstBlockBaseIndex - headBase) / static_cast::type>(BLOCK_SIZE)); auto indexIndex = (localBlockIndexHead + offset) & (localBlockIndex->size - 1); // Iterate the blocks and dequeue @@ -2875,7 +2875,7 @@ class ConcurrentQueue assert(tailBase != INVALID_BLOCK_BASE); // Note: Must use division instead of shift because the index may wrap around, causing a negative // offset, whose negativity we want to preserve - auto offset = static_cast(static_cast::type>(index - tailBase) / BLOCK_SIZE); + auto offset = static_cast(static_cast::type>(index - tailBase) / static_cast::type>(BLOCK_SIZE)); size_t idx = (tail + offset) & (localBlockIndex->capacity - 1); 
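// (Annotation; not from the original patch.) Why the concurrentqueue hunks
// above cast BLOCK_SIZE to the signed dividend type instead of dividing by it
// directly: in a mixed signed/unsigned division the signed operand is
// converted to unsigned first, destroying the sign of a wrapped-around offset.
// With assumed values and BLOCK_SIZE = 32:
//   int32_t diff = -64;                        // e.g. blockBaseIndex - headBase
//   diff / static_cast<int32_t>(BLOCK_SIZE);   // -2, negativity preserved
//   diff / static_cast<uint32_t>(BLOCK_SIZE);  // 134217726, sign destroyed
// A right shift would not help either: shifting a negative value right is
// implementation-defined before C++20, which is what the existing "division
// instead of shift" comment is guarding against.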
assert(localBlockIndex->index[idx]->key.load(std::memory_order_relaxed) == index && localBlockIndex->index[idx]->value.load(std::memory_order_relaxed) != nullptr); return idx; @@ -3630,7 +3630,7 @@ ConsumerToken::ConsumerToken(ConcurrentQueue& queue) : itemsConsumedFromCurrent(0), currentProducer(nullptr), desiredProducer(nullptr) { initialOffset = queue.nextExplicitConsumerId.fetch_add(1, std::memory_order_release); - lastKnownGlobalOffset = -1; + lastKnownGlobalOffset = uint32_t(-1); } template @@ -3638,7 +3638,7 @@ ConsumerToken::ConsumerToken(BlockingConcurrentQueue& queue) : itemsConsumedFromCurrent(0), currentProducer(nullptr), desiredProducer(nullptr) { initialOffset = reinterpret_cast*>(&queue)->nextExplicitConsumerId.fetch_add(1, std::memory_order_release); - lastKnownGlobalOffset = -1; + lastKnownGlobalOffset = uint32_t(-1); } template diff --git a/third_party/utf8proc/include/utf8proc.hpp b/third_party/utf8proc/include/utf8proc.hpp index 336f95e2ed5a..37fa1f2664b8 100644 --- a/third_party/utf8proc/include/utf8proc.hpp +++ b/third_party/utf8proc/include/utf8proc.hpp @@ -637,7 +637,7 @@ void utf8proc_grapheme_callback(const char *s, size_t len, T &&fun) { size_t start = 0; size_t cpos = 0; while(true) { - cpos += UnsafeNumericCast(sz); + cpos += static_cast(sz); if (cpos >= len) { fun(start, cpos); return; From b134927d2af44dfde501d3db51838e38f52f6bf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Thu, 4 Apr 2024 17:14:08 +0200 Subject: [PATCH 050/201] un-funking build --- .../operator/aggregate/physical_window.cpp | 2 +- src/optimizer/move_constants.cpp | 165 ++++++++++++++++++ .../rule/conjunction_simplification.cpp | 70 ++++++++ src/optimizer/rule/constant_folding.cpp | 43 +++++ .../rule/date_part_simplification.cpp | 103 +++++++++++ src/optimizer/rule/empty_needle_removal.cpp | 54 ++++++ src/optimizer/rule/enum_comparison.cpp | 70 ++++++++ .../rule/in_clause_simplification_rule.cpp | 57 ++++++ src/optimizer/rule/like_optimizations.cpp | 161 +++++++++++++++++ src/optimizer/rule/move_constants.cpp | 164 +++++++++++++++++ src/planner/CMakeLists.txt | 26 +++ 11 files changed, 914 insertions(+), 1 deletion(-) create mode 100644 src/optimizer/move_constants.cpp diff --git a/src/execution/operator/aggregate/physical_window.cpp b/src/execution/operator/aggregate/physical_window.cpp index 01e6752a2cd0..c4cd9d4cbb98 100644 --- a/src/execution/operator/aggregate/physical_window.cpp +++ b/src/execution/operator/aggregate/physical_window.cpp @@ -327,7 +327,7 @@ void WindowPartitionSourceState::MaterializeSortedData() { heap->blocks = std::move(sd.heap_blocks); hash_group.reset(); } else { - heap = make_uniq(buffer_manager, Storage::BLOCK_SIZE, 1U, true); + heap = make_uniq(buffer_manager, (idx_t)Storage::BLOCK_SIZE, 1U, true); } heap->count = std::accumulate(heap->blocks.begin(), heap->blocks.end(), idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); diff --git a/src/optimizer/move_constants.cpp b/src/optimizer/move_constants.cpp new file mode 100644 index 000000000000..636265ff9131 --- /dev/null +++ b/src/optimizer/move_constants.cpp @@ -0,0 +1,165 @@ +#include "duckdb/optimizer/rule/move_constants.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include 
"duckdb/optimizer/expression_rewriter.hpp" + +namespace duckdb { + +MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto op = make_uniq(); + op->matchers.push_back(make_uniq()); + op->policy = SetMatcher::Policy::UNORDERED; + + auto arithmetic = make_uniq(); + // we handle multiplication, addition and subtraction because those are "easy" + // integer division makes the division case difficult + // e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules + arithmetic->function = make_uniq(unordered_set {"+", "-", "*"}); + // we match only on integral numeric types + arithmetic->type = make_uniq(); + auto child_constant_matcher = make_uniq(); + auto child_expression_matcher = make_uniq(); + child_constant_matcher->type = make_uniq(); + child_expression_matcher->type = make_uniq(); + arithmetic->matchers.push_back(std::move(child_constant_matcher)); + arithmetic->matchers.push_back(std::move(child_expression_matcher)); + arithmetic->policy = SetMatcher::Policy::SOME; + op->matchers.push_back(std::move(arithmetic)); + root = std::move(op); +} + +unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &comparison = bindings[0].get().Cast(); + auto &outer_constant = bindings[1].get().Cast(); + auto &arithmetic = bindings[2].get().Cast(); + auto &inner_constant = bindings[3].get().Cast(); + D_ASSERT(arithmetic.return_type.IsIntegral()); + D_ASSERT(arithmetic.children[0]->return_type.IsIntegral()); + if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) { + return make_uniq(Value(comparison.return_type)); + } + auto &constant_type = outer_constant.return_type; + hugeint_t outer_value = IntegralValue::Get(outer_constant.value); + hugeint_t inner_value = IntegralValue::Get(inner_constant.value); + + idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 
1 : 0;
+    auto &op_type = arithmetic.function.name;
+    if (op_type == "+") {
+        // [x + 1 COMP 10] OR [1 + x COMP 10]
+        // order does not matter in addition:
+        // simply change right side to 10-1 (outer_constant - inner_constant)
+        if (!Hugeint::TrySubtractInPlace(outer_value, inner_value)) {
+            return nullptr;
+        }
+        auto result_value = Value::HUGEINT(outer_value);
+        if (!result_value.DefaultTryCastAs(constant_type)) {
+            if (comparison.type != ExpressionType::COMPARE_EQUAL) {
+                return nullptr;
+            }
+            // if the cast is not possible then the comparison is not possible
+            // for example, if we have x + 5 = 3, where x is an unsigned number, we will get x = -2
+            // since this is not possible we can remove the entire branch here
+            return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]),
+                                                      Value::BOOLEAN(false));
+        }
+        outer_constant.value = std::move(result_value);
+    } else if (op_type == "-") {
+        // [x - 1 COMP 10] OR [1 - x COMP 10]
+        // order matters in subtraction:
+        if (arithmetic_child_index == 0) {
+            // [x - 1 COMP 10]
+            // change right side to 10+1 (outer_constant + inner_constant)
+            if (!Hugeint::TryAddInPlace(outer_value, inner_value)) {
+                return nullptr;
+            }
+            auto result_value = Value::HUGEINT(outer_value);
+            if (!result_value.DefaultTryCastAs(constant_type)) {
+                // if the cast is not possible then an equality comparison is not possible
+                if (comparison.type != ExpressionType::COMPARE_EQUAL) {
+                    return nullptr;
+                }
+                return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]),
+                                                          Value::BOOLEAN(false));
+            }
+            outer_constant.value = std::move(result_value);
+        } else {
+            // [1 - x COMP 10]
+            // change right side to 1-10=-9
+            if (!Hugeint::TrySubtractInPlace(inner_value, outer_value)) {
+                return nullptr;
+            }
+            auto result_value = Value::HUGEINT(inner_value);
+            if (!result_value.DefaultTryCastAs(constant_type)) {
+                // if the cast is not possible then an equality comparison is not possible
+                if (comparison.type != ExpressionType::COMPARE_EQUAL) {
+                    return nullptr;
+                }
+                return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]),
+                                                          Value::BOOLEAN(false));
+            }
+            outer_constant.value = std::move(result_value);
+            // in this case, we should also flip the comparison
+            // e.g. if we have [4 - x < 2] then we should have [x > 2]
+            comparison.type = FlipComparisonExpression(comparison.type);
+        }
+    } else {
+        D_ASSERT(op_type == "*");
+        // [x * 2 COMP 10] OR [2 * x COMP 10]
+        // order does not matter in multiplication:
+        // change right side to 10/2 (outer_constant / inner_constant)
+        // but ONLY if outer_constant is cleanly divisible by the inner_constant
+        if (inner_value == 0) {
+            // x * 0, the result is either 0 or NULL
+            // we let the arithmetic_simplification rule take care of simplifying this first
+            return nullptr;
+        }
+        // check out of range for HUGEINT or not cleanly divisible
+        // HUGEINT is not cleanly divisible when outer_value == minimum and inner_value == -1.
(modulo overflow) + if ((outer_value == NumericLimits::Minimum() && inner_value == -1) || + outer_value % inner_value != 0) { + bool is_equality = comparison.type == ExpressionType::COMPARE_EQUAL; + bool is_inequality = comparison.type == ExpressionType::COMPARE_NOTEQUAL; + if (is_equality || is_inequality) { + // we know the values are not equal + // the result will be either FALSE or NULL (if COMPARE_EQUAL) + // or TRUE or NULL (if COMPARE_NOTEQUAL) + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(is_inequality)); + } else { + // not cleanly divisible and we are doing > >= < <=, skip the simplification for now + return nullptr; + } + } + if (inner_value < 0) { + // multiply by negative value, need to flip expression + comparison.type = FlipComparisonExpression(comparison.type); + } + // else divide the RHS by the LHS + // we need to do a range check on the cast even though we do a division + // because e.g. -128 / -1 = 128, which is out of range + auto result_value = Value::HUGEINT(outer_value / inner_value); + if (!result_value.DefaultTryCastAs(constant_type)) { + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(false)); + } + outer_constant.value = std::move(result_value); + } + // replace left side with x + // first extract x from the arithmetic expression + auto arithmetic_child = std::move(arithmetic.children[arithmetic_child_index]); + // then place in the comparison + if (comparison.left.get() == &outer_constant) { + comparison.right = std::move(arithmetic_child); + } else { + comparison.left = std::move(arithmetic_child); + } + changes_made = true; + return nullptr; +} + +} // namespace duckdb diff --git a/src/optimizer/rule/conjunction_simplification.cpp b/src/optimizer/rule/conjunction_simplification.cpp index 8b137891791f..14fd80a9423b 100644 --- a/src/optimizer/rule/conjunction_simplification.cpp +++ b/src/optimizer/rule/conjunction_simplification.cpp @@ -1 +1,71 @@ +#include "duckdb/optimizer/rule/conjunction_simplification.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" + +namespace duckdb { + +ConjunctionSimplificationRule::ConjunctionSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on a ComparisonExpression that has a ConstantExpression as a check + auto op = make_uniq(); + op->matchers.push_back(make_uniq()); + op->policy = SetMatcher::Policy::SOME; + root = std::move(op); +} + +unique_ptr ConjunctionSimplificationRule::RemoveExpression(BoundConjunctionExpression &conj, + const Expression &expr) { + for (idx_t i = 0; i < conj.children.size(); i++) { + if (conj.children[i].get() == &expr) { + // erase the expression + conj.children.erase_at(i); + break; + } + } + if (conj.children.size() == 1) { + // one expression remaining: simply return that expression and erase the conjunction + return std::move(conj.children[0]); + } + return nullptr; +} + +unique_ptr ConjunctionSimplificationRule::Apply(LogicalOperator &op, + vector> &bindings, bool &changes_made, + bool is_root) { + auto &conjunction = bindings[0].get().Cast(); + auto &constant_expr = bindings[1].get(); + // the constant_expr is a scalar expression that we have to fold + // use an ExpressionExecutor to execute the expression + D_ASSERT(constant_expr.IsFoldable()); + Value constant_value; + if 
(!ExpressionExecutor::TryEvaluateScalar(GetContext(), constant_expr, constant_value)) { + return nullptr; + } + constant_value = constant_value.DefaultCastAs(LogicalType::BOOLEAN); + if (constant_value.IsNull()) { + // we can't simplify conjunctions with a constant NULL + return nullptr; + } + if (conjunction.type == ExpressionType::CONJUNCTION_AND) { + if (!BooleanValue::Get(constant_value)) { + // FALSE in AND, result of expression is false + return make_uniq(Value::BOOLEAN(false)); + } else { + // TRUE in AND, remove the expression from the set + return RemoveExpression(conjunction, constant_expr); + } + } else { + D_ASSERT(conjunction.type == ExpressionType::CONJUNCTION_OR); + if (!BooleanValue::Get(constant_value)) { + // FALSE in OR, remove the expression from the set + return RemoveExpression(conjunction, constant_expr); + } else { + // TRUE in OR, result of expression is true + return make_uniq(Value::BOOLEAN(true)); + } + } +} + +} // namespace duckdb diff --git a/src/optimizer/rule/constant_folding.cpp b/src/optimizer/rule/constant_folding.cpp index 8b137891791f..7702b46040f7 100644 --- a/src/optimizer/rule/constant_folding.cpp +++ b/src/optimizer/rule/constant_folding.cpp @@ -1 +1,44 @@ +#include "duckdb/optimizer/rule/constant_folding.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" + +namespace duckdb { + +//! The ConstantFoldingExpressionMatcher matches on any scalar expression (i.e. Expression::IsFoldable is true) +class ConstantFoldingExpressionMatcher : public FoldableConstantMatcher { +public: + bool Match(Expression &expr, vector> &bindings) override { + // we also do not match on ConstantExpressions, because we cannot fold those any further + if (expr.type == ExpressionType::VALUE_CONSTANT) { + return false; + } + return FoldableConstantMatcher::Match(expr, bindings); + } +}; + +ConstantFoldingRule::ConstantFoldingRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto op = make_uniq(); + root = std::move(op); +} + +unique_ptr ConstantFoldingRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings[0].get(); + // the root is a scalar expression that we have to fold + D_ASSERT(root.IsFoldable() && root.type != ExpressionType::VALUE_CONSTANT); + + // use an ExpressionExecutor to execute the expression + Value result_value; + if (!ExpressionExecutor::TryEvaluateScalar(GetContext(), root, result_value)) { + return nullptr; + } + D_ASSERT(result_value.type().InternalType() == root.return_type.InternalType()); + // now get the value from the result vector and insert it back into the plan as a constant expression + return make_uniq(result_value); +} + +} // namespace duckdb diff --git a/src/optimizer/rule/date_part_simplification.cpp b/src/optimizer/rule/date_part_simplification.cpp index 8b137891791f..6737e576c7dd 100644 --- a/src/optimizer/rule/date_part_simplification.cpp +++ b/src/optimizer/rule/date_part_simplification.cpp @@ -1 +1,104 @@ +#include "duckdb/optimizer/rule/date_part_simplification.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" 
+#include "duckdb/function/function.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +DatePartSimplificationRule::DatePartSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto func = make_uniq(); + func->function = make_uniq("date_part"); + func->matchers.push_back(make_uniq()); + func->matchers.push_back(make_uniq()); + func->policy = SetMatcher::Policy::ORDERED; + root = std::move(func); +} + +unique_ptr DatePartSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &date_part = bindings[0].get().Cast(); + auto &constant_expr = bindings[1].get().Cast(); + auto &constant = constant_expr.value; + + if (constant.IsNull()) { + // NULL specifier: return constant NULL + return make_uniq(Value(date_part.return_type)); + } + // otherwise check the specifier + auto specifier = GetDatePartSpecifier(StringValue::Get(constant)); + string new_function_name; + switch (specifier) { + case DatePartSpecifier::YEAR: + new_function_name = "year"; + break; + case DatePartSpecifier::MONTH: + new_function_name = "month"; + break; + case DatePartSpecifier::DAY: + new_function_name = "day"; + break; + case DatePartSpecifier::DECADE: + new_function_name = "decade"; + break; + case DatePartSpecifier::CENTURY: + new_function_name = "century"; + break; + case DatePartSpecifier::MILLENNIUM: + new_function_name = "millennium"; + break; + case DatePartSpecifier::QUARTER: + new_function_name = "quarter"; + break; + case DatePartSpecifier::WEEK: + new_function_name = "week"; + break; + case DatePartSpecifier::YEARWEEK: + new_function_name = "yearweek"; + break; + case DatePartSpecifier::DOW: + new_function_name = "dayofweek"; + break; + case DatePartSpecifier::ISODOW: + new_function_name = "isodow"; + break; + case DatePartSpecifier::DOY: + new_function_name = "dayofyear"; + break; + case DatePartSpecifier::MICROSECONDS: + new_function_name = "microsecond"; + break; + case DatePartSpecifier::MILLISECONDS: + new_function_name = "millisecond"; + break; + case DatePartSpecifier::SECOND: + new_function_name = "second"; + break; + case DatePartSpecifier::MINUTE: + new_function_name = "minute"; + break; + case DatePartSpecifier::HOUR: + new_function_name = "hour"; + break; + default: + return nullptr; + } + // found a replacement function: bind it + vector> children; + children.push_back(std::move(date_part.children[1])); + + ErrorData error; + FunctionBinder binder(rewriter.context); + auto function = binder.BindScalarFunction(DEFAULT_SCHEMA, new_function_name, std::move(children), error, false); + if (!function) { + error.Throw(); + } + return function; +} + +} // namespace duckdb diff --git a/src/optimizer/rule/empty_needle_removal.cpp b/src/optimizer/rule/empty_needle_removal.cpp index 8b137891791f..a8985838b1d5 100644 --- a/src/optimizer/rule/empty_needle_removal.cpp +++ b/src/optimizer/rule/empty_needle_removal.cpp @@ -1 +1,55 @@ +#include "duckdb/optimizer/rule/empty_needle_removal.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression/bound_case_expression.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" + +namespace duckdb { + +EmptyNeedleRemovalRule::EmptyNeedleRemovalRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on a FunctionExpression that 
has a foldable ConstantExpression + auto func = make_uniq(); + func->matchers.push_back(make_uniq()); + func->matchers.push_back(make_uniq()); + func->policy = SetMatcher::Policy::SOME; + + unordered_set functions = {"prefix", "contains", "suffix"}; + func->function = make_uniq(functions); + root = std::move(func); +} + +unique_ptr EmptyNeedleRemovalRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings[0].get().Cast(); + D_ASSERT(root.children.size() == 2); + auto &prefix_expr = bindings[2].get(); + + // the constant_expr is a scalar expression that we have to fold + if (!prefix_expr.IsFoldable()) { + return nullptr; + } + D_ASSERT(root.return_type.id() == LogicalTypeId::BOOLEAN); + + auto prefix_value = ExpressionExecutor::EvaluateScalar(GetContext(), prefix_expr); + + if (prefix_value.IsNull()) { + return make_uniq(Value(LogicalType::BOOLEAN)); + } + + D_ASSERT(prefix_value.type() == prefix_expr.return_type); + auto &needle_string = StringValue::Get(prefix_value); + + // PREFIX('xyz', '') is TRUE + // PREFIX(NULL, '') is NULL + // so rewrite PREFIX(x, '') to TRUE_OR_NULL(x) + if (needle_string.empty()) { + return ExpressionRewriter::ConstantOrNull(std::move(root.children[0]), Value::BOOLEAN(true)); + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/optimizer/rule/enum_comparison.cpp b/src/optimizer/rule/enum_comparison.cpp index 8b137891791f..aeb5e224800e 100644 --- a/src/optimizer/rule/enum_comparison.cpp +++ b/src/optimizer/rule/enum_comparison.cpp @@ -1 +1,71 @@ +#include "duckdb/optimizer/rule/enum_comparison.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/optimizer/matcher/type_matcher_id.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +EnumComparisonRule::EnumComparisonRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on a ComparisonExpression that is an Equality and has a VARCHAR and ENUM as its children + auto op = make_uniq(); + // Enum requires expression to be root + op->expr_type = make_uniq(ExpressionType::COMPARE_EQUAL); + for (idx_t i = 0; i < 2; i++) { + auto child = make_uniq(); + child->type = make_uniq(LogicalTypeId::VARCHAR); + child->matcher = make_uniq(); + child->matcher->type = make_uniq(LogicalTypeId::ENUM); + op->matchers.push_back(std::move(child)); + } + root = std::move(op); +} + +bool AreMatchesPossible(LogicalType &left, LogicalType &right) { + LogicalType *small_enum, *big_enum; + if (EnumType::GetSize(left) < EnumType::GetSize(right)) { + small_enum = &left; + big_enum = &right; + } else { + small_enum = &right; + big_enum = &left; + } + auto &string_vec = EnumType::GetValuesInsertOrder(*small_enum); + auto string_vec_ptr = FlatVector::GetData(string_vec); + auto size = EnumType::GetSize(*small_enum); + for (idx_t i = 0; i < size; i++) { + auto key = string_vec_ptr[i].GetString(); + if (EnumType::GetPos(*big_enum, key) != -1) { + return true; + } + } + return false; +} +unique_ptr EnumComparisonRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + + auto &root = bindings[0].get().Cast(); + auto &left_child = bindings[1].get().Cast(); + auto &right_child = bindings[3].get().Cast(); + + if (!AreMatchesPossible(left_child.child->return_type, right_child.child->return_type)) { + vector> children; + 
children.push_back(std::move(root.left)); + children.push_back(std::move(root.right)); + return ExpressionRewriter::ConstantOrNull(std::move(children), Value::BOOLEAN(false)); + } + + if (!is_root || op.type != LogicalOperatorType::LOGICAL_FILTER) { + return nullptr; + } + + auto cast_left_to_right = + BoundCastExpression::AddDefaultCastToType(std::move(left_child.child), right_child.child->return_type, true); + return make_uniq(root.type, std::move(cast_left_to_right), std::move(right_child.child)); +} + +} // namespace duckdb diff --git a/src/optimizer/rule/in_clause_simplification_rule.cpp b/src/optimizer/rule/in_clause_simplification_rule.cpp index 8b137891791f..07e433773796 100644 --- a/src/optimizer/rule/in_clause_simplification_rule.cpp +++ b/src/optimizer/rule/in_clause_simplification_rule.cpp @@ -1 +1,58 @@ +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/optimizer/rule/in_clause_simplification.hpp" +#include "duckdb/planner/expression/list.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" + +namespace duckdb { + +InClauseSimplificationRule::InClauseSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on InClauseExpression that has a ConstantExpression as a check + auto op = make_uniq(); + op->policy = SetMatcher::Policy::SOME; + root = std::move(op); +} + +unique_ptr InClauseSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &expr = bindings[0].get().Cast(); + if (expr.children[0]->expression_class != ExpressionClass::BOUND_CAST) { + return nullptr; + } + auto &cast_expression = expr.children[0]->Cast(); + if (cast_expression.child->expression_class != ExpressionClass::BOUND_COLUMN_REF) { + return nullptr; + } + //! Here we check if we can apply the expression on the constant side + auto target_type = cast_expression.source_type(); + if (!BoundCastExpression::CastIsInvertible(cast_expression.return_type, target_type)) { + return nullptr; + } + vector> cast_list; + //! First check if we can cast all children + for (size_t i = 1; i < expr.children.size(); i++) { + if (expr.children[i]->expression_class != ExpressionClass::BOUND_CONSTANT) { + return nullptr; + } + D_ASSERT(expr.children[i]->IsFoldable()); + auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), *expr.children[i]); + auto new_constant = constant_value.DefaultTryCastAs(target_type); + if (!new_constant) { + return nullptr; + } else { + auto new_constant_expr = make_uniq(constant_value); + cast_list.push_back(std::move(new_constant_expr)); + } + } + //! We can cast, so we move the new constant + for (size_t i = 1; i < expr.children.size(); i++) { + expr.children[i] = std::move(cast_list[i - 1]); + + // expr->children[i] = std::move(new_constant_expr); + } + //! 
We can cast the full list, so we move the column + expr.children[0] = std::move(cast_expression.child); + return nullptr; +} + +} // namespace duckdb diff --git a/src/optimizer/rule/like_optimizations.cpp b/src/optimizer/rule/like_optimizations.cpp index 8b137891791f..96f7b1501e8a 100644 --- a/src/optimizer/rule/like_optimizations.cpp +++ b/src/optimizer/rule/like_optimizations.cpp @@ -1 +1,162 @@ +#include "duckdb/optimizer/rule/like_optimizations.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" + +namespace duckdb { + +LikeOptimizationRule::LikeOptimizationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match on a FunctionExpression that has a foldable ConstantExpression + auto func = make_uniq(); + func->matchers.push_back(make_uniq()); + func->matchers.push_back(make_uniq()); + func->policy = SetMatcher::Policy::ORDERED; + // we match on LIKE ("~~") and NOT LIKE ("!~~") + func->function = make_uniq(unordered_set {"!~~", "~~"}); + root = std::move(func); +} + +static bool PatternIsConstant(const string &pattern) { + for (idx_t i = 0; i < pattern.size(); i++) { + if (pattern[i] == '%' || pattern[i] == '_') { + return false; + } + } + return true; +} + +static bool PatternIsPrefix(const string &pattern) { + idx_t i; + for (i = pattern.size(); i > 0; i--) { + if (pattern[i - 1] != '%') { + break; + } + } + if (i == pattern.size()) { + // no trailing % + // cannot be a prefix + return false; + } + // continue to look in the string + // if there is a % or _ in the string (besides at the very end) this is not a prefix match + for (; i > 0; i--) { + if (pattern[i - 1] == '%' || pattern[i - 1] == '_') { + return false; + } + } + return true; +} + +static bool PatternIsSuffix(const string &pattern) { + idx_t i; + for (i = 0; i < pattern.size(); i++) { + if (pattern[i] != '%') { + break; + } + } + if (i == 0) { + // no leading % + // cannot be a suffix + return false; + } + // continue to look in the string + // if there is a % or _ in the string (besides at the beginning) this is not a suffix match + for (; i < pattern.size(); i++) { + if (pattern[i] == '%' || pattern[i] == '_') { + return false; + } + } + return true; +} + +static bool PatternIsContains(const string &pattern) { + idx_t start; + idx_t end; + for (start = 0; start < pattern.size(); start++) { + if (pattern[start] != '%') { + break; + } + } + for (end = pattern.size(); end > 0; end--) { + if (pattern[end - 1] != '%') { + break; + } + } + if (start == 0 || end == pattern.size()) { + // contains requires both a leading AND a trailing % + return false; + } + // check if there are any other special characters in the string + // if there is a % or _ in the string (besides at the beginning/end) this is not a contains match + for (idx_t i = start; i < end; i++) { + if (pattern[i] == '%' || pattern[i] == '_') { + return false; + } + } + return true; +} + +unique_ptr LikeOptimizationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &root = bindings[0].get().Cast(); + auto &constant_expr = bindings[2].get().Cast(); + D_ASSERT(root.children.size() == 2); + + if (constant_expr.value.IsNull()) { + return make_uniq(Value(root.return_type)); + } + + // the constant_expr is a scalar expression that we have to 
fold + if (!constant_expr.IsFoldable()) { + return nullptr; + } + + auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), constant_expr); + D_ASSERT(constant_value.type() == constant_expr.return_type); + auto &patt_str = StringValue::Get(constant_value); + + bool is_not_like = root.function.name == "!~~"; + if (PatternIsConstant(patt_str)) { + // Pattern is constant + return make_uniq(is_not_like ? ExpressionType::COMPARE_NOTEQUAL + : ExpressionType::COMPARE_EQUAL, + std::move(root.children[0]), std::move(root.children[1])); + } else if (PatternIsPrefix(patt_str)) { + // Prefix LIKE pattern : [^%_]*[%]+, ignoring underscore + return ApplyRule(root, PrefixFun::GetFunction(), patt_str, is_not_like); + } else if (PatternIsSuffix(patt_str)) { + // Suffix LIKE pattern: [%]+[^%_]*, ignoring underscore + return ApplyRule(root, SuffixFun::GetFunction(), patt_str, is_not_like); + } else if (PatternIsContains(patt_str)) { + // Contains LIKE pattern: [%]+[^%_]*[%]+, ignoring underscore + return ApplyRule(root, ContainsFun::GetFunction(), patt_str, is_not_like); + } + return nullptr; +} + +unique_ptr LikeOptimizationRule::ApplyRule(BoundFunctionExpression &expr, ScalarFunction function, + string pattern, bool is_not_like) { + // replace LIKE by an optimized function + unique_ptr result; + auto new_function = + make_uniq(expr.return_type, std::move(function), std::move(expr.children), nullptr); + + // removing "%" from the pattern + pattern.erase(std::remove(pattern.begin(), pattern.end(), '%'), pattern.end()); + + new_function->children[1] = make_uniq(Value(std::move(pattern))); + + result = std::move(new_function); + if (is_not_like) { + auto negation = make_uniq(ExpressionType::OPERATOR_NOT, LogicalType::BOOLEAN); + negation->children.push_back(std::move(result)); + result = std::move(negation); + } + + return result; +} + +} // namespace duckdb diff --git a/src/optimizer/rule/move_constants.cpp b/src/optimizer/rule/move_constants.cpp index 8b137891791f..636265ff9131 100644 --- a/src/optimizer/rule/move_constants.cpp +++ b/src/optimizer/rule/move_constants.cpp @@ -1 +1,165 @@ +#include "duckdb/optimizer/rule/move_constants.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" + +namespace duckdb { + +MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto op = make_uniq(); + op->matchers.push_back(make_uniq()); + op->policy = SetMatcher::Policy::UNORDERED; + + auto arithmetic = make_uniq(); + // we handle multiplication, addition and subtraction because those are "easy" + // integer division makes the division case difficult + // e.g. 
[x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules + arithmetic->function = make_uniq(unordered_set {"+", "-", "*"}); + // we match only on integral numeric types + arithmetic->type = make_uniq(); + auto child_constant_matcher = make_uniq(); + auto child_expression_matcher = make_uniq(); + child_constant_matcher->type = make_uniq(); + child_expression_matcher->type = make_uniq(); + arithmetic->matchers.push_back(std::move(child_constant_matcher)); + arithmetic->matchers.push_back(std::move(child_expression_matcher)); + arithmetic->policy = SetMatcher::Policy::SOME; + op->matchers.push_back(std::move(arithmetic)); + root = std::move(op); +} + +unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &comparison = bindings[0].get().Cast(); + auto &outer_constant = bindings[1].get().Cast(); + auto &arithmetic = bindings[2].get().Cast(); + auto &inner_constant = bindings[3].get().Cast(); + D_ASSERT(arithmetic.return_type.IsIntegral()); + D_ASSERT(arithmetic.children[0]->return_type.IsIntegral()); + if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) { + return make_uniq(Value(comparison.return_type)); + } + auto &constant_type = outer_constant.return_type; + hugeint_t outer_value = IntegralValue::Get(outer_constant.value); + hugeint_t inner_value = IntegralValue::Get(inner_constant.value); + + idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 1 : 0; + auto &op_type = arithmetic.function.name; + if (op_type == "+") { + // [x + 1 COMP 10] OR [1 + x COMP 10] + // order does not matter in addition: + // simply change right side to 10-1 (outer_constant - inner_constant) + if (!Hugeint::TrySubtractInPlace(outer_value, inner_value)) { + return nullptr; + } + auto result_value = Value::HUGEINT(outer_value); + if (!result_value.DefaultTryCastAs(constant_type)) { + if (comparison.type != ExpressionType::COMPARE_EQUAL) { + return nullptr; + } + // if the cast is not possible then the comparison is not possible + // for example, if we have x + 5 = 3, where x is an unsigned number, we will get x = -2 + // since this is not possible we can remove the entire branch here + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(false)); + } + outer_constant.value = std::move(result_value); + } else if (op_type == "-") { + // [x - 1 COMP 10] O R [1 - x COMP 10] + // order matters in subtraction: + if (arithmetic_child_index == 0) { + // [x - 1 COMP 10] + // change right side to 10+1 (outer_constant + inner_constant) + if (!Hugeint::TryAddInPlace(outer_value, inner_value)) { + return nullptr; + } + auto result_value = Value::HUGEINT(outer_value); + if (!result_value.DefaultTryCastAs(constant_type)) { + // if the cast is not possible then an equality comparison is not possible + if (comparison.type != ExpressionType::COMPARE_EQUAL) { + return nullptr; + } + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(false)); + } + outer_constant.value = std::move(result_value); + } else { + // [1 - x COMP 10] + // change right side to 1-10=-9 + if (!Hugeint::TrySubtractInPlace(inner_value, outer_value)) { + return nullptr; + } + auto result_value = Value::HUGEINT(inner_value); + if (!result_value.DefaultTryCastAs(constant_type)) { + // if the cast is not possible then an equality comparison is not possible + if (comparison.type != 
ExpressionType::COMPARE_EQUAL) { + return nullptr; + } + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(false)); + } + outer_constant.value = std::move(result_value); + // in this case, we should also flip the comparison + // e.g. if we have [4 - x < 2] then we should have [x > 2] + comparison.type = FlipComparisonExpression(comparison.type); + } + } else { + D_ASSERT(op_type == "*"); + // [x * 2 COMP 10] OR [2 * x COMP 10] + // order does not matter in multiplication: + // change right side to 10/2 (outer_constant / inner_constant) + // but ONLY if outer_constant is cleanly divisible by the inner_constant + if (inner_value == 0) { + // x * 0, the result is either 0 or NULL + // we let the arithmetic_simplification rule take care of simplifying this first + return nullptr; + } + // check out of range for HUGEINT or not cleanly divisible + // HUGEINT is not cleanly divisible when outer_value == minimum and inner value == -1. (modulo overflow) + if ((outer_value == NumericLimits::Minimum() && inner_value == -1) || + outer_value % inner_value != 0) { + bool is_equality = comparison.type == ExpressionType::COMPARE_EQUAL; + bool is_inequality = comparison.type == ExpressionType::COMPARE_NOTEQUAL; + if (is_equality || is_inequality) { + // we know the values are not equal + // the result will be either FALSE or NULL (if COMPARE_EQUAL) + // or TRUE or NULL (if COMPARE_NOTEQUAL) + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(is_inequality)); + } else { + // not cleanly divisible and we are doing > >= < <=, skip the simplification for now + return nullptr; + } + } + if (inner_value < 0) { + // multiply by negative value, need to flip expression + comparison.type = FlipComparisonExpression(comparison.type); + } + // else divide the RHS by the LHS + // we need to do a range check on the cast even though we do a division + // because e.g. 
-128 / -1 = 128, which is out of range + auto result_value = Value::HUGEINT(outer_value / inner_value); + if (!result_value.DefaultTryCastAs(constant_type)) { + return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), + Value::BOOLEAN(false)); + } + outer_constant.value = std::move(result_value); + } + // replace left side with x + // first extract x from the arithmetic expression + auto arithmetic_child = std::move(arithmetic.children[arithmetic_child_index]); + // then place in the comparison + if (comparison.left.get() == &outer_constant) { + comparison.right = std::move(arithmetic_child); + } else { + comparison.left = std::move(arithmetic_child); + } + changes_made = true; + return nullptr; +} + +} // namespace duckdb diff --git a/src/planner/CMakeLists.txt b/src/planner/CMakeLists.txt index 8b137891791f..19f4c28a0758 100644 --- a/src/planner/CMakeLists.txt +++ b/src/planner/CMakeLists.txt @@ -1 +1,27 @@ +add_subdirectory(expression) +add_subdirectory(binder) +add_subdirectory(expression_binder) +add_subdirectory(filter) +add_subdirectory(operator) +add_subdirectory(subquery) +add_library_unity( + duckdb_planner + OBJECT + bound_result_modifier.cpp + bound_parameter_map.cpp + expression_iterator.cpp + expression.cpp + table_binding.cpp + expression_binder.cpp + joinside.cpp + logical_operator.cpp + binder.cpp + bind_context.cpp + planner.cpp + pragma_handler.cpp + logical_operator_visitor.cpp + table_filter.cpp) +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) From c14e21b77bbb1249b39c630428955b83a43f1b43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Fri, 5 Apr 2024 13:24:06 +0200 Subject: [PATCH 051/201] unfunking tests --- src/common/operator/cast_operators.cpp | 4 +- src/common/row_operations/row_aggregate.cpp | 4 +- src/common/string_util.cpp | 4 +- src/common/types/conflict_manager.cpp | 2 +- src/common/types/hugeint.cpp | 8 +-- .../numeric_inplace_operators.cpp | 2 +- src/core_functions/scalar/bit/bitstring.cpp | 27 +++++--- .../scalar/blob/create_sort_key.cpp | 4 +- src/core_functions/scalar/date/date_part.cpp | 4 +- .../scalar/list/array_slice.cpp | 22 +++--- src/core_functions/scalar/string/hex.cpp | 29 ++++---- .../operator/join/physical_iejoin.cpp | 2 +- src/execution/perfect_aggregate_hashtable.cpp | 8 +-- src/function/scalar/list/list_extract.cpp | 4 +- src/function/scalar/list/list_select.cpp | 3 +- src/function/scalar/strftime_format.cpp | 67 +++++++++---------- .../duckdb/common/sort/duckdb_pdqsort.hpp | 2 +- src/include/duckdb/common/types/bit.hpp | 2 +- .../duckdb/common/types/cast_helpers.hpp | 4 +- .../duckdb/storage/string_uncompressed.hpp | 2 +- src/main/capi/result-c.cpp | 2 +- .../transform/helpers/transform_typename.cpp | 4 +- .../buffer/buffer_pool_reservation.cpp | 4 +- 23 files changed, 106 insertions(+), 108 deletions(-) diff --git a/src/common/operator/cast_operators.cpp b/src/common/operator/cast_operators.cpp index be59aabbe296..769ff78cdaf7 100644 --- a/src/common/operator/cast_operators.cpp +++ b/src/common/operator/cast_operators.cpp @@ -1647,7 +1647,7 @@ struct HugeIntegerCastOperation { template static bool HandleDigit(T &state, uint8_t digit) { if (NEGATIVE) { - if (DUCKDB_UNLIKELY(UnsafeNumericCast(state.intermediate) < + if (DUCKDB_UNLIKELY(static_cast(state.intermediate) < (NumericLimits::Minimum() + digit) / 10)) { // intermediate is full: need to flush it if (!state.Flush()) { @@ -1702,7 +1702,7 @@ struct HugeIntegerCastOperation { remainder = negate_result; } 
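// A minimal standalone sketch of the underflow guard used in HandleDigit
// above; the helper name is hypothetical and only <cstdint>/<limits> are
// assumed. Digits accumulate into a negative int64_t intermediate, which
// must be flushed into the hugeint state before intermediate * 10 - digit
// can pass below INT64_MIN.
#include <cstdint>
#include <limits>
static bool CanAcceptDigitWithoutFlush(int64_t intermediate, uint8_t digit) {
	// intermediate * 10 - digit stays >= INT64_MIN exactly when
	// intermediate >= (INT64_MIN + digit) / 10; C++ division truncates
	// toward zero, which rounds this negative bound up as required.
	return intermediate >= (std::numeric_limits<int64_t>::min() + digit) / 10;
}
// For digit 9 the bound is -922337203685477579: one step below that would
// produce -9223372036854775809, which no longer fits in int64_t.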
state.decimal = remainder; - state.decimal_total_digits = UnsafeNumericCast(-e); + state.decimal_total_digits = static_cast(-e); state.decimal_intermediate = 0; state.decimal_intermediate_digits = 0; return Finalize(state); diff --git a/src/common/row_operations/row_aggregate.cpp b/src/common/row_operations/row_aggregate.cpp index 8be0ada4fa20..0ea80035e08d 100644 --- a/src/common/row_operations/row_aggregate.cpp +++ b/src/common/row_operations/row_aggregate.cpp @@ -95,8 +95,8 @@ void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &la } // Now subtract the offset to get back to the original position - VectorOperations::AddInPlace(sources, UnsafeNumericCast(-offset), count); - VectorOperations::AddInPlace(targets, UnsafeNumericCast(-offset), count); + VectorOperations::AddInPlace(sources, -UnsafeNumericCast(offset), count); + VectorOperations::AddInPlace(targets, -UnsafeNumericCast(offset), count); } void RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, diff --git a/src/common/string_util.cpp b/src/common/string_util.cpp index 7d33b1cd80f7..2e9eb676e527 100644 --- a/src/common/string_util.cpp +++ b/src/common/string_util.cpp @@ -204,7 +204,7 @@ string StringUtil::Upper(const string &str) { string StringUtil::Lower(const string &str) { string copy(str); transform(copy.begin(), copy.end(), copy.begin(), - [](unsigned char c) { return StringUtil::CharacterToLower(UnsafeNumericCast(c)); }); + [](unsigned char c) { return StringUtil::CharacterToLower(static_cast(c)); }); return (copy); } @@ -216,7 +216,7 @@ bool StringUtil::IsLower(const string &str) { uint64_t StringUtil::CIHash(const string &str) { uint32_t hash = 0; for (auto c : str) { - hash += UnsafeNumericCast(StringUtil::CharacterToLower(UnsafeNumericCast(c))); + hash += static_cast(StringUtil::CharacterToLower(static_cast(c))); hash += hash << 10; hash ^= hash >> 6; } diff --git a/src/common/types/conflict_manager.cpp b/src/common/types/conflict_manager.cpp index 171d45115fb4..8e7ce0b9bb93 100644 --- a/src/common/types/conflict_manager.cpp +++ b/src/common/types/conflict_manager.cpp @@ -159,7 +159,7 @@ bool ConflictManager::AddNull(idx_t chunk_index) { if (!IsConflict(LookupResultType::LOOKUP_NULL)) { return false; } - return AddHit(chunk_index, UnsafeNumericCast(DConstants::INVALID_INDEX)); + return AddHit(chunk_index, static_cast(DConstants::INVALID_INDEX)); } bool ConflictManager::SingleIndexTarget() const { diff --git a/src/common/types/hugeint.cpp b/src/common/types/hugeint.cpp index c99abaa9ee1e..d83d81caa13b 100644 --- a/src/common/types/hugeint.cpp +++ b/src/common/types/hugeint.cpp @@ -830,14 +830,14 @@ hugeint_t hugeint_t::operator>>(const hugeint_t &rhs) const { return *this; } else if (shift == 64) { result.upper = (upper < 0) ? -1 : 0; - result.lower = UnsafeNumericCast(upper); + result.lower = uint64_t(upper); } else if (shift < 64) { // perform lower shift in unsigned integer, and mask away the most significant bit result.lower = (uint64_t(upper) << (64 - shift)) | (lower >> shift); result.upper = upper >> shift; } else { D_ASSERT(shift < 128); - result.lower = UnsafeNumericCast(upper >> (shift - 64)); + result.lower = uint64_t(upper >> (shift - 64)); result.upper = (upper < 0) ? 
-1 : 0; } return result; @@ -852,7 +852,7 @@ hugeint_t hugeint_t::operator<<(const hugeint_t &rhs) const { if (rhs.upper != 0 || shift >= 128) { return hugeint_t(0); } else if (shift == 64) { - result.upper = UnsafeNumericCast(lower); + result.upper = int64_t(lower); result.lower = 0; } else if (shift == 0) { return *this; @@ -860,7 +860,7 @@ hugeint_t hugeint_t::operator<<(const hugeint_t &rhs) const { // perform upper shift in unsigned integer, and mask away the most significant bit uint64_t upper_shift = ((uint64_t(upper) << shift) + (lower >> (64 - shift))) & 0x7FFFFFFFFFFFFFFF; result.lower = lower << shift; - result.upper = UnsafeNumericCast(upper_shift); + result.upper = int64_t(upper_shift); } else { D_ASSERT(shift < 128); result.lower = 0; diff --git a/src/common/vector_operations/numeric_inplace_operators.cpp b/src/common/vector_operations/numeric_inplace_operators.cpp index 86b507a3ce9a..863f3ba8dd6f 100644 --- a/src/common/vector_operations/numeric_inplace_operators.cpp +++ b/src/common/vector_operations/numeric_inplace_operators.cpp @@ -30,7 +30,7 @@ void VectorOperations::AddInPlace(Vector &input, int64_t right, idx_t count) { D_ASSERT(input.GetVectorType() == VectorType::FLAT_VECTOR); auto data = FlatVector::GetData(input); for (idx_t i = 0; i < count; i++) { - data[i] += UnsafeNumericCast(right); + data[i] = UnsafeNumericCast(UnsafeNumericCast(data[i]) + right); } break; } diff --git a/src/core_functions/scalar/bit/bitstring.cpp b/src/core_functions/scalar/bit/bitstring.cpp index 8a3074250139..babfadfe01e7 100644 --- a/src/core_functions/scalar/bit/bitstring.cpp +++ b/src/core_functions/scalar/bit/bitstring.cpp @@ -8,13 +8,17 @@ namespace duckdb { // BitStringFunction //===--------------------------------------------------------------------===// static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t input, idx_t n) { - if (n < input.GetSize()) { + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t n) { + if (n < 0) { + throw InvalidInputException("The bitstring length cannot be negative"); + } + if (idx_t(n) < input.GetSize()) { throw InvalidInputException("Length must be equal or larger than input string"); } idx_t len; Bit::TryGetBitStringSize(input, len, nullptr); // string verification + len = Bit::ComputeBitstringLen(n); string_t target = StringVector::EmptyString(result, len); Bit::BitString(input, n, target); @@ -24,7 +28,7 @@ static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &r } ScalarFunction BitStringFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::UBIGINT}, LogicalType::BIT, BitStringFunction); + return ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction); } //===--------------------------------------------------------------------===// @@ -33,7 +37,7 @@ ScalarFunction BitStringFun::GetFunction() { struct GetBitOperator { template static inline TR Operation(TA input, TB n) { - if (n > Bit::BitLength(input) - 1) { + if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), NumericHelper::ToString(Bit::BitLength(input) - 1)); } @@ -42,20 +46,21 @@ struct GetBitOperator { }; ScalarFunction GetBitFun::GetFunction() { - return ScalarFunction({LogicalType::BIT, LogicalType::UBIGINT}, 
LogicalType::INTEGER, - ScalarFunction::BinaryFunction); + return ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::INTEGER, + ScalarFunction::BinaryFunction); } //===--------------------------------------------------------------------===// // set_bit //===--------------------------------------------------------------------===// static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &result) { - TernaryExecutor::Execute( - args.data[0], args.data[1], args.data[2], result, args.size(), [&](string_t input, idx_t n, idx_t new_value) { + TernaryExecutor::Execute( + args.data[0], args.data[1], args.data[2], result, args.size(), + [&](string_t input, int32_t n, int32_t new_value) { if (new_value != 0 && new_value != 1) { throw InvalidInputException("The new bit must be 1 or 0"); } - if (n > Bit::BitLength(input) - 1) { + if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), NumericHelper::ToString(Bit::BitLength(input) - 1)); } @@ -67,7 +72,7 @@ static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &res } ScalarFunction SetBitFun::GetFunction() { - return ScalarFunction({LogicalType::BIT, LogicalType::UBIGINT, LogicalType::UBIGINT}, LogicalType::BIT, + return ScalarFunction({LogicalType::BIT, LogicalType::INTEGER, LogicalType::INTEGER}, LogicalType::BIT, SetBitOperation); } diff --git a/src/core_functions/scalar/blob/create_sort_key.cpp b/src/core_functions/scalar/blob/create_sort_key.cpp index 09ee36f5167e..e47986674271 100644 --- a/src/core_functions/scalar/blob/create_sort_key.cpp +++ b/src/core_functions/scalar/blob/create_sort_key.cpp @@ -519,8 +519,8 @@ void ConstructSortKeyList(SortKeyVectorData &vector_data, SortKeyChunk chunk, So } // write the end-of-list delimiter - result_ptr[offset++] = UnsafeNumericCast(info.flip_bytes ? ~SortKeyVectorData::LIST_DELIMITER - : SortKeyVectorData::LIST_DELIMITER); + result_ptr[offset++] = static_cast(info.flip_bytes ? ~SortKeyVectorData::LIST_DELIMITER + : SortKeyVectorData::LIST_DELIMITER); } } diff --git a/src/core_functions/scalar/date/date_part.cpp b/src/core_functions/scalar/date/date_part.cpp index ebbe158b51fd..3d0b75775ceb 100644 --- a/src/core_functions/scalar/date/date_part.cpp +++ b/src/core_functions/scalar/date/date_part.cpp @@ -1432,8 +1432,8 @@ void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec & // Both define epoch, and the correct value is the sum. // So mask it out and compute it separately. 
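// A minimal standalone sketch of the epoch arithmetic that motivates the
// masking above: the date and the time component each contribute part of
// EPOCH, and the correct result is their sum, so it is computed once rather
// than letting both Operation calls write it. This assumes DuckDB's encodings
// (date_t: days since 1970-01-01, dtime_t: microseconds since midnight);
// the helper name is hypothetical.
#include <cstdint>
static double EpochSeconds(int32_t days_since_epoch, int64_t micros_since_midnight) {
	return static_cast<double>(days_since_epoch) * 86400.0 +
	       static_cast<double>(micros_since_midnight) / 1e6;
}
// EpochSeconds(1, 0) == 86400.0: midnight exactly one day after the epoch.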
- Operation(bigint_values, double_values, d, idx, mask & UnsafeNumericCast(~EPOCH)); - Operation(bigint_values, double_values, t, idx, mask & UnsafeNumericCast(~EPOCH)); + Operation(bigint_values, double_values, d, idx, mask & ~UnsafeNumericCast(EPOCH)); + Operation(bigint_values, double_values, t, idx, mask & ~UnsafeNumericCast(EPOCH)); if (mask & EPOCH) { auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); diff --git a/src/core_functions/scalar/list/array_slice.cpp b/src/core_functions/scalar/list/array_slice.cpp index c91247ae9f69..7651f410d023 100644 --- a/src/core_functions/scalar/list/array_slice.cpp +++ b/src/core_functions/scalar/list/array_slice.cpp @@ -40,7 +40,7 @@ unique_ptr ListSliceBindData::Copy() const { } template -static idx_t CalculateSliceLength(INDEX_TYPE begin, INDEX_TYPE end, INDEX_TYPE step, bool svalid) { +static idx_t CalculateSliceLength(idx_t begin, idx_t end, INDEX_TYPE step, bool svalid) { if (step < 0) { step = abs(step); } @@ -48,14 +48,14 @@ static idx_t CalculateSliceLength(INDEX_TYPE begin, INDEX_TYPE end, INDEX_TYPE s throw InvalidInputException("Slice step cannot be zero"); } if (step == 1) { - return UnsafeNumericCast(end - begin); - } else if (step >= (end - begin)) { + return NumericCast(end - begin); + } else if (static_cast(step) >= (end - begin)) { return 1; } if ((end - begin) % step != 0) { - return UnsafeNumericCast((end - begin) / step + 1); + return (end - begin) / step + 1; } - return UnsafeNumericCast((end - begin) / step); + return (end - begin) / step; } template @@ -64,7 +64,7 @@ INDEX_TYPE ValueLength(const INPUT_TYPE &value) { } template <> -idx_t ValueLength(const list_entry_t &value) { +int64_t ValueLength(const list_entry_t &value) { return value.length; } @@ -119,8 +119,8 @@ INPUT_TYPE SliceValue(Vector &result, INPUT_TYPE input, INDEX_TYPE begin, INDEX_ template <> list_entry_t SliceValue(Vector &result, list_entry_t input, int64_t begin, int64_t end) { - input.offset = UnsafeNumericCast(UnsafeNumericCast(input.offset) + begin); - input.length = UnsafeNumericCast(end - begin); + input.offset += begin; + input.length = end - begin; return input; } @@ -145,14 +145,14 @@ list_entry_t SliceValueWithSteps(Vector &result, SelectionVector &sel, list_entr return input; } input.length = CalculateSliceLength(begin, end, step, true); - auto child_idx = UnsafeNumericCast(UnsafeNumericCast(input.offset) + begin); + idx_t child_idx = input.offset + begin; if (step < 0) { - child_idx = UnsafeNumericCast(UnsafeNumericCast(input.offset) + end - 1); + child_idx = input.offset + end - 1; } input.offset = sel_idx; for (idx_t i = 0; i < input.length; i++) { sel.set_index(sel_idx, child_idx); - child_idx = UnsafeNumericCast(UnsafeNumericCast(child_idx) + step); + child_idx += step; sel_idx++; } return input; diff --git a/src/core_functions/scalar/string/hex.cpp b/src/core_functions/scalar/string/hex.cpp index 6afaeadb56a4..dffbae70d030 100644 --- a/src/core_functions/scalar/string/hex.cpp +++ b/src/core_functions/scalar/string/hex.cpp @@ -90,8 +90,7 @@ struct HexIntegralOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = - UnsafeNumericCast(CountZeros::Leading(UnsafeNumericCast(input))); + idx_t num_leading_zero = CountZeros::Leading(input); idx_t num_bits_to_check = 64 - num_leading_zero; D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); @@ -110,7 +109,7 @@ struct HexIntegralOperator { auto target = StringVector::EmptyString(result, buffer_size); auto output = 
target.GetDataWriteable(); - WriteHexBytes(UnsafeNumericCast(input), output, buffer_size); + WriteHexBytes(input, output, buffer_size); target.Finalize(); return target; @@ -121,7 +120,7 @@ struct HexHugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = UnsafeNumericCast(CountZeros::Leading(input)); + idx_t num_leading_zero = CountZeros::Leading(input); idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); // Special case: All bits are zero @@ -148,7 +147,7 @@ struct HexUhugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = UnsafeNumericCast(CountZeros::Leading(input)); + idx_t num_leading_zero = CountZeros::Leading(input); idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); // Special case: All bits are zero @@ -190,7 +189,7 @@ struct BinaryStrOperator { auto output = target.GetDataWriteable(); for (idx_t i = 0; i < size; ++i) { - auto byte = UnsafeNumericCast(data[i]); + uint8_t byte = data[i]; for (idx_t i = 8; i >= 1; --i) { *output = ((byte >> (i - 1)) & 0x01) + '0'; output++; @@ -206,8 +205,7 @@ struct BinaryIntegralOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = - UnsafeNumericCast(CountZeros::Leading(UnsafeNumericCast(input))); + idx_t num_leading_zero = CountZeros::Leading(input); idx_t num_bits_to_check = 64 - num_leading_zero; D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); @@ -226,7 +224,7 @@ struct BinaryIntegralOperator { auto target = StringVector::EmptyString(result, buffer_size); auto output = target.GetDataWriteable(); - WriteBinBytes(UnsafeNumericCast(input), output, buffer_size); + WriteBinBytes(input, output, buffer_size); target.Finalize(); return target; @@ -236,7 +234,7 @@ struct BinaryIntegralOperator { struct BinaryHugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = UnsafeNumericCast(CountZeros::Leading(input)); + idx_t num_leading_zero = CountZeros::Leading(input); idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; // Special case: All bits are zero @@ -261,7 +259,7 @@ struct BinaryHugeIntOperator { struct BinaryUhugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = UnsafeNumericCast(CountZeros::Leading(input)); + idx_t num_leading_zero = CountZeros::Leading(input); idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; // Special case: All bits are zero @@ -303,7 +301,7 @@ struct FromHexOperator { // Treated as a single byte idx_t i = 0; if (size % 2 != 0) { - *output = UnsafeNumericCast(StringUtil::GetHexValue(data[i])); + *output = StringUtil::GetHexValue(data[i]); i++; output++; } @@ -311,7 +309,7 @@ struct FromHexOperator { for (; i < size; i += 2) { uint8_t major = StringUtil::GetHexValue(data[i]); uint8_t minor = StringUtil::GetHexValue(data[i + 1]); - *output = UnsafeNumericCast((major << 4) | minor); + *output = UnsafeNumericCast((major << 4) | minor); output++; } @@ -345,7 +343,7 @@ struct FromBinaryOperator { byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); i++; } - *output = UnsafeNumericCast(byte); + *output = byte; output++; } @@ -355,8 +353,7 @@ struct FromBinaryOperator { byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); i++; } - *output = UnsafeNumericCast(byte); - ; + *output = byte; output++; } diff --git a/src/execution/operator/join/physical_iejoin.cpp 
b/src/execution/operator/join/physical_iejoin.cpp index bc99e9fbf0a8..8c89195b1989 100644 --- a/src/execution/operator/join/physical_iejoin.cpp +++ b/src/execution/operator/join/physical_iejoin.cpp @@ -325,7 +325,7 @@ idx_t IEJoinUnion::AppendKey(SortedTable &table, ExpressionExecutor &executor, S payload.data[0].Sequence(rid, increment, scan_count); payload.SetCardinality(scan_count); keys.Fuse(payload); - rid += UnsafeNumericCast(increment) * scan_count; + rid += increment * UnsafeNumericCast(scan_count); // Sort on the sort columns (which will no longer be needed) keys.Split(payload, payload_idx); diff --git a/src/execution/perfect_aggregate_hashtable.cpp b/src/execution/perfect_aggregate_hashtable.cpp index a46e9499b5a3..da7c20192452 100644 --- a/src/execution/perfect_aggregate_hashtable.cpp +++ b/src/execution/perfect_aggregate_hashtable.cpp @@ -64,7 +64,7 @@ static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value // we only need to handle non-null values here if (group_data.validity.RowIsValid(index)) { D_ASSERT(data[index] >= min_val); - auto adjusted_value = UnsafeNumericCast((data[index] - min_val) + 1); + uintptr_t adjusted_value = (data[index] - min_val) + 1; address_data[i] += adjusted_value << current_shift; } } @@ -72,7 +72,7 @@ static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value // no null values: we can directly compute the addresses for (idx_t i = 0; i < count; i++) { auto index = group_data.sel->get_index(i); - auto adjusted_value = UnsafeNumericCast((data[index] - min_val) + 1); + uintptr_t adjusted_value = (data[index] - min_val) + 1; address_data[i] += adjusted_value << current_shift; } } @@ -149,7 +149,7 @@ void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) } // move to the next aggregate payload_idx += input_count; - VectorOperations::AddInPlace(addresses, NumericCast(aggregate.payload_size), payload.size()); + VectorOperations::AddInPlace(addresses, aggregate.payload_size, payload.size()); } } @@ -199,7 +199,7 @@ static void ReconstructGroupVectorTemplated(uint32_t group_values[], Value &min, auto min_data = min.GetValueUnsafe(); for (idx_t i = 0; i < entry_count; i++) { // extract the value of this group from the total group index - auto group_index = UnsafeNumericCast((group_values[i] >> shift) & mask); + auto group_index = UnsafeNumericCast((group_values[i] >> shift) & mask); if (group_index == 0) { // if it is 0, the value is NULL validity_mask.SetInvalid(i); diff --git a/src/function/scalar/list/list_extract.cpp b/src/function/scalar/list/list_extract.cpp index e1566641aafb..822b9079cdc9 100644 --- a/src/function/scalar/list/list_extract.cpp +++ b/src/function/scalar/list/list_extract.cpp @@ -62,13 +62,13 @@ void ListExtractTemplate(idx_t count, UnifiedVectorFormat &list_data, UnifiedVec result_mask.SetInvalid(i); continue; } - child_offset = list_entry.offset + list_entry.length + UnsafeNumericCast(offsets_entry); + child_offset = list_entry.offset + list_entry.length + offsets_entry; } else { if ((idx_t)offsets_entry >= list_entry.length) { result_mask.SetInvalid(i); continue; } - child_offset = list_entry.offset + UnsafeNumericCast(offsets_entry); + child_offset = list_entry.offset + offsets_entry; } auto child_index = child_format.sel->get_index(child_offset); if (child_format.validity.RowIsValid(child_index)) { diff --git a/src/function/scalar/list/list_select.cpp b/src/function/scalar/list/list_select.cpp index 9d45de5a1ecc..70ae15194fcf 100644 --- 
a/src/function/scalar/list/list_select.cpp +++ b/src/function/scalar/list/list_select.cpp @@ -11,8 +11,7 @@ struct SetSelectionVectorSelect { ValidityMask &input_validity, Vector &selection_entry, idx_t child_idx, idx_t &target_offset, idx_t selection_offset, idx_t input_offset, idx_t target_length) { - auto sel_idx = - UnsafeNumericCast(selection_entry.GetValue(selection_offset + child_idx).GetValue() - 1); + auto sel_idx = selection_entry.GetValue(selection_offset + child_idx).GetValue() - 1; if (sel_idx < target_length) { selection_vector.set_index(target_offset, input_offset + sel_idx); if (!input_validity.RowIsValid(input_offset + sel_idx)) { diff --git a/src/function/scalar/strftime_format.cpp b/src/function/scalar/strftime_format.cpp index e8d31f522b6f..5181b005ed45 100644 --- a/src/function/scalar/strftime_format.cpp +++ b/src/function/scalar/strftime_format.cpp @@ -80,7 +80,7 @@ idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date if (0 <= year && year <= 9999) { return 4; } else { - return UnsafeNumericCast(NumericHelper::SignedLength(year)); + return NumericHelper::SignedLength(year); } } case StrTimeSpecifier::MONTH_DECIMAL: { @@ -129,14 +129,11 @@ idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date return len; } case StrTimeSpecifier::DAY_OF_MONTH: - return UnsafeNumericCast( - NumericHelper::UnsignedLength(UnsafeNumericCast(Date::ExtractDay(date)))); + return NumericHelper::UnsignedLength(Date::ExtractDay(date)); case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: - return UnsafeNumericCast( - NumericHelper::UnsignedLength(UnsafeNumericCast(Date::ExtractDayOfTheYear(date)))); + return NumericHelper::UnsignedLength(Date::ExtractDayOfTheYear(date)); case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - return UnsafeNumericCast(NumericHelper::UnsignedLength( - AbsValue(UnsafeNumericCast(Date::ExtractYear(date)) % 100))); + return NumericHelper::UnsignedLength(AbsValue(Date::ExtractYear(date)) % 100); default: throw InternalException("Unimplemented specifier for GetSpecifierLength"); } @@ -198,13 +195,13 @@ char *StrfTimeFormat::WritePadded(char *target, uint32_t value, size_t padding) D_ASSERT(padding > 1); if (padding % 2) { int decimals = value % 1000; - WritePadded3(target + padding - 3, UnsafeNumericCast(decimals)); + WritePadded3(target + padding - 3, decimals); value /= 1000; padding -= 3; } for (size_t i = 0; i < padding / 2; i++) { int decimals = value % 100; - WritePadded2(target + padding - 2 * (i + 1), UnsafeNumericCast(decimals)); + WritePadded2(target + padding - 2 * (i + 1), decimals); value /= 100; } return target + padding; @@ -248,26 +245,26 @@ char *StrfTimeFormat::WriteDateSpecifier(StrTimeSpecifier specifier, date_t date } case StrTimeSpecifier::DAY_OF_YEAR_PADDED: { int32_t doy = Date::ExtractDayOfTheYear(date); - target = WritePadded3(target, UnsafeNumericCast(doy)); + target = WritePadded3(target, doy); break; } case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - target = WritePadded2(target, UnsafeNumericCast(Date::ExtractWeekNumberRegular(date, true))); + target = WritePadded2(target, Date::ExtractWeekNumberRegular(date, true)); break; case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - target = WritePadded2(target, UnsafeNumericCast(Date::ExtractWeekNumberRegular(date, false))); + target = WritePadded2(target, Date::ExtractWeekNumberRegular(date, false)); break; case StrTimeSpecifier::WEEK_NUMBER_ISO: - target = WritePadded2(target, UnsafeNumericCast(Date::ExtractISOWeekNumber(date))); + target = 
WritePadded2(target, Date::ExtractISOWeekNumber(date)); break; case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: { - auto doy = UnsafeNumericCast(Date::ExtractDayOfTheYear(date)); + uint32_t doy = Date::ExtractDayOfTheYear(date); target += NumericHelper::UnsignedLength(doy); NumericHelper::FormatUnsigned(doy, target); break; } case StrTimeSpecifier::YEAR_ISO: - target = WritePadded(target, UnsafeNumericCast(Date::ExtractISOYearNumber(date)), 4); + target = WritePadded(target, Date::ExtractISOYearNumber(date), 4); break; case StrTimeSpecifier::WEEKDAY_ISO: *target = char('0' + uint8_t(Date::ExtractISODayOfTheWeek(date))); @@ -284,7 +281,7 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t // data contains [0] year, [1] month, [2] day, [3] hour, [4] minute, [5] second, [6] msec, [7] utc switch (specifier) { case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - target = WritePadded2(target, UnsafeNumericCast(data[2])); + target = WritePadded2(target, data[2]); break; case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: { auto &month_name = Date::MONTH_NAMES_ABBREVIATED[data[1] - 1]; @@ -295,14 +292,14 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t return WriteString(target, month_name); } case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - target = WritePadded2(target, UnsafeNumericCast(data[1])); + target = WritePadded2(target, data[1]); break; case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - target = WritePadded2(target, UnsafeNumericCast(AbsValue(data[0]) % 100)); + target = WritePadded2(target, AbsValue(data[0]) % 100); break; case StrTimeSpecifier::YEAR_DECIMAL: if (data[0] >= 0 && data[0] <= 9999) { - target = WritePadded(target, UnsafeNumericCast(data[0]), 4); + target = WritePadded(target, data[0], 4); } else { int32_t year = data[0]; if (data[0] < 0) { @@ -310,13 +307,13 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t year = -year; target++; } - auto len = NumericHelper::UnsignedLength(UnsafeNumericCast(year)); + auto len = NumericHelper::UnsignedLength(year); NumericHelper::FormatUnsigned(year, target + len); target += len; } break; case StrTimeSpecifier::HOUR_24_PADDED: { - target = WritePadded2(target, UnsafeNumericCast(data[3])); + target = WritePadded2(target, data[3]); break; } case StrTimeSpecifier::HOUR_12_PADDED: { @@ -324,7 +321,7 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t if (hour == 0) { hour = 12; } - target = WritePadded2(target, UnsafeNumericCast(hour)); + target = WritePadded2(target, hour); break; } case StrTimeSpecifier::AM_PM: @@ -332,20 +329,20 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t *target++ = 'M'; break; case StrTimeSpecifier::MINUTE_PADDED: { - target = WritePadded2(target, UnsafeNumericCast(data[4])); + target = WritePadded2(target, data[4]); break; } case StrTimeSpecifier::SECOND_PADDED: - target = WritePadded2(target, UnsafeNumericCast(data[5])); + target = WritePadded2(target, data[5]); break; case StrTimeSpecifier::NANOSECOND_PADDED: - target = WritePadded(target, UnsafeNumericCast(data[6] * Interval::NANOS_PER_MICRO), 9); + target = WritePadded(target, data[6] * Interval::NANOS_PER_MICRO, 9); break; case StrTimeSpecifier::MICROSECOND_PADDED: - target = WritePadded(target, UnsafeNumericCast(data[6]), 6); + target = WritePadded(target, data[6], 6); break; case StrTimeSpecifier::MILLISECOND_PADDED: - target = WritePadded3(target, UnsafeNumericCast(data[6] / Interval::MICROS_PER_MSEC)); + 
target = WritePadded3(target, data[6] / Interval::MICROS_PER_MSEC); break; case StrTimeSpecifier::UTC_OFFSET: { *target++ = (data[7] < 0) ? '-' : '+'; @@ -353,10 +350,10 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t auto offset = abs(data[7]); auto offset_hours = offset / Interval::MINS_PER_HOUR; auto offset_minutes = offset % Interval::MINS_PER_HOUR; - target = WritePadded2(target, UnsafeNumericCast(offset_hours)); + target = WritePadded2(target, offset_hours); if (offset_minutes) { *target++ = ':'; - target = WritePadded2(target, UnsafeNumericCast(offset_minutes)); + target = WritePadded2(target, offset_minutes); } break; } @@ -367,7 +364,7 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t } break; case StrTimeSpecifier::DAY_OF_MONTH: { - target = Write2(target, UnsafeNumericCast(data[2] % 100)); + target = Write2(target, data[2] % 100); break; } case StrTimeSpecifier::MONTH_DECIMAL: { @@ -375,7 +372,7 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t break; } case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: { - target = Write2(target, UnsafeNumericCast(AbsValue(data[0]) % 100)); + target = Write2(target, AbsValue(data[0]) % 100); break; } case StrTimeSpecifier::HOUR_24_DECIMAL: { @@ -848,9 +845,9 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const { // numeric specifier: parse a number uint64_t number = 0; size_t start_pos = pos; - size_t end_pos = start_pos + UnsafeNumericCast(numeric_width[i]); + size_t end_pos = start_pos + numeric_width[i]; while (pos < size && pos < end_pos && StringUtil::CharacterIsDigit(data[pos])) { - number = number * 10 + UnsafeNumericCast(data[pos]) - '0'; + number = number * 10 + data[pos] - '0'; pos++; } if (pos == start_pos) { @@ -1232,7 +1229,7 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const { // But tz must not be empty. 
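// A minimal standalone sketch of the fixed-width numeric writes performed by
// WritePadded2/WritePadded3 in the strftime changes above; plain per-digit
// arithmetic stands in for whatever digit-table tricks the real code uses,
// and the helper name is hypothetical.
#include <cstddef>
#include <cstdint>
static char *WritePaddedDigits(char *target, uint32_t value, std::size_t width) {
	for (std::size_t i = 0; i < width; i++) {
		// fill the slot from its least-significant position backwards
		target[width - 1 - i] = static_cast<char>('0' + value % 10);
		value /= 10;
	}
	return target + width; // caller keeps appending after the padded field
}
// WritePaddedDigits(buf, 7, 2) writes "07", matching %M-style zero padding.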
if (tz_end == tz_begin) { error_message = "Empty Time Zone name"; - error_position = UnsafeNumericCast(tz_begin - data); + error_position = tz_begin - data; return false; } result.tz.assign(tz_begin, tz_end); @@ -1291,7 +1288,7 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const { case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: { // Adjust weekday to be 0-based for the week type if (has_weekday) { - weekday = (weekday + 7 - uint64_t(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST)) % 7; + weekday = (weekday + 7 - int(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST)) % 7; } // Get the start of week 1, move back 7 days and then weekno * 7 + weekday gives the date const auto jan1 = Date::FromDate(result_data[0], 1, 1); diff --git a/src/include/duckdb/common/sort/duckdb_pdqsort.hpp b/src/include/duckdb/common/sort/duckdb_pdqsort.hpp index cae339b20f2e..a71788ebf4b4 100644 --- a/src/include/duckdb/common/sort/duckdb_pdqsort.hpp +++ b/src/include/duckdb/common/sort/duckdb_pdqsort.hpp @@ -320,7 +320,7 @@ inline T *align_cacheline(T *p) { #else std::size_t ip = reinterpret_cast(p); #endif - ip = (ip + cacheline_size - 1) & duckdb::UnsafeNumericCast(-cacheline_size); + ip = (ip + cacheline_size - 1) & -duckdb::UnsafeNumericCast(cacheline_size); return reinterpret_cast(ip); } diff --git a/src/include/duckdb/common/types/bit.hpp b/src/include/duckdb/common/types/bit.hpp index 903b3cc75ff9..12c64d8b37f2 100644 --- a/src/include/duckdb/common/types/bit.hpp +++ b/src/include/duckdb/common/types/bit.hpp @@ -104,7 +104,7 @@ void Bit::NumericToBit(T numeric, string_t &output_str) { *output = 0; // set padding to 0 ++output; for (idx_t idx = 0; idx < sizeof(T); ++idx) { - output[idx] = UnsafeNumericCast(data[sizeof(T) - idx - 1]); + output[idx] = static_cast(data[sizeof(T) - idx - 1]); } Bit::Finalize(output_str); } diff --git a/src/include/duckdb/common/types/cast_helpers.hpp b/src/include/duckdb/common/types/cast_helpers.hpp index 03c45b5a09df..358438da3c17 100644 --- a/src/include/duckdb/common/types/cast_helpers.hpp +++ b/src/include/duckdb/common/types/cast_helpers.hpp @@ -62,9 +62,9 @@ class NumericHelper { template static string_t FormatSigned(SIGNED value, Vector &vector) { int sign = -(value < 0); - UNSIGNED unsigned_value = UNSIGNED(value) ^ UNSIGNED(sign) - UNSIGNED(sign); + UNSIGNED unsigned_value = UnsafeNumericCast(UNSIGNED(value ^ sign) - sign); int length = UnsignedLength(unsigned_value) - sign; - string_t result = StringVector::EmptyString(vector, NumericCast(length)); + string_t result = StringVector::EmptyString(vector, NumericCast(length)); auto dataptr = result.GetDataWriteable(); auto endptr = dataptr + length; endptr = FormatUnsigned(unsigned_value, endptr); diff --git a/src/include/duckdb/storage/string_uncompressed.hpp b/src/include/duckdb/storage/string_uncompressed.hpp index dacad03a0edb..0561efe099a0 100644 --- a/src/include/duckdb/storage/string_uncompressed.hpp +++ b/src/include/duckdb/storage/string_uncompressed.hpp @@ -148,7 +148,7 @@ struct UncompressedStringStorage { // place the dictionary offset into the set of vectors // note: for overflow strings we write negative value - result_data[target_idx] = NumericCast(-(*dictionary_size)); + result_data[target_idx] = -NumericCast((*dictionary_size)); } else { // string fits in block, append to dictionary and increment dictionary position D_ASSERT(string_length < NumericLimits::Maximum()); diff --git a/src/main/capi/result-c.cpp b/src/main/capi/result-c.cpp index 
5b17c0132751..fae2b0cb2c24 100644 --- a/src/main/capi/result-c.cpp +++ b/src/main/capi/result-c.cpp @@ -117,7 +117,7 @@ struct CDecimalConverter : public CBaseConverter { template static DST Convert(SRC input) { duckdb_hugeint result; - result.lower = NumericCast(input); + result.lower = static_cast(input); result.upper = 0; return result; } diff --git a/src/parser/transform/helpers/transform_typename.cpp b/src/parser/transform/helpers/transform_typename.cpp index edea1222fba0..01c1697107e1 100644 --- a/src/parser/transform/helpers/transform_typename.cpp +++ b/src/parser/transform/helpers/transform_typename.cpp @@ -245,7 +245,7 @@ LogicalType Transformer::TransformTypeName(duckdb_libpgquery::PGTypeName &type_n if (val->type != duckdb_libpgquery::T_PGInteger) { throw ParserException("Expected integer value as array bound"); } - auto array_size = NumericCast(val->val.ival); + auto array_size = val->val.ival; if (array_size < 0) { // -1 if bounds are empty result_type = LogicalType::LIST(result_type); @@ -255,7 +255,7 @@ LogicalType Transformer::TransformTypeName(duckdb_libpgquery::PGTypeName &type_n } else if (array_size > static_cast(ArrayType::MAX_ARRAY_SIZE)) { throw ParserException("Arrays must have a size of at most %d", ArrayType::MAX_ARRAY_SIZE); } else { - result_type = LogicalType::ARRAY(result_type, array_size); + result_type = LogicalType::ARRAY(result_type, NumericCast(array_size)); } } } diff --git a/src/storage/buffer/buffer_pool_reservation.cpp b/src/storage/buffer/buffer_pool_reservation.cpp index 60a7bc882e61..f22a96ffa58f 100644 --- a/src/storage/buffer/buffer_pool_reservation.cpp +++ b/src/storage/buffer/buffer_pool_reservation.cpp @@ -23,8 +23,8 @@ BufferPoolReservation::~BufferPoolReservation() { } void BufferPoolReservation::Resize(idx_t new_size) { - int64_t delta = NumericCast(new_size) - NumericCast(size); - pool.IncreaseUsedMemory(tag, NumericCast(delta)); + int64_t delta = (int64_t)new_size - size; + pool.IncreaseUsedMemory(tag, delta); size = new_size; } From 5752be9f5ce0d259b876b8a2c624895ae257c958 Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 5 Apr 2024 15:07:45 +0200 Subject: [PATCH 052/201] make wrapper implementation for weak_ptr and shared_ptr --- .../catalog_entry/duck_table_entry.cpp | 2 +- src/common/http_state.cpp | 2 +- src/common/re2_regex.cpp | 2 +- src/common/symbols.cpp | 26 ++--- src/execution/physical_plan/plan_cte.cpp | 2 +- .../physical_plan/plan_recursive_cte.cpp | 2 +- .../catalog_entry/duck_table_entry.hpp | 4 +- src/include/duckdb/common/helper.hpp | 8 ++ src/include/duckdb/common/re2_regex.hpp | 2 +- src/include/duckdb/common/shared_ptr.hpp | 109 +++++++++++++++++- src/include/duckdb/common/types.hpp | 1 + src/include/duckdb/common/weak_ptr.hpp | 88 ++++++++++++++ .../execution/operator/set/physical_cte.hpp | 2 +- .../operator/set/physical_recursive_cte.hpp | 2 +- .../execution/physical_plan_generator.hpp | 2 +- .../main/buffered_data/buffered_data.hpp | 1 + src/include/duckdb/main/client_context.hpp | 2 +- src/include/duckdb/main/relation.hpp | 2 +- .../duckdb/main/relation/query_relation.hpp | 2 +- .../main/relation/table_function_relation.hpp | 4 +- .../duckdb/main/relation/table_relation.hpp | 2 +- .../duckdb/main/relation/value_relation.hpp | 6 +- .../duckdb/main/relation/view_relation.hpp | 2 +- src/include/duckdb/planner/bind_context.hpp | 8 +- .../duckdb/transaction/transaction.hpp | 1 + src/main/relation.cpp | 4 +- src/main/relation/query_relation.cpp | 2 +- src/main/relation/read_csv_relation.cpp | 2 +- 
src/main/relation/table_relation.cpp | 2 +- src/main/relation/value_relation.cpp | 4 +- src/main/relation/view_relation.cpp | 2 +- src/planner/bind_context.cpp | 2 +- src/planner/binder.cpp | 1 + 33 files changed, 253 insertions(+), 50 deletions(-) create mode 100644 src/include/duckdb/common/weak_ptr.hpp diff --git a/src/catalog/catalog_entry/duck_table_entry.cpp b/src/catalog/catalog_entry/duck_table_entry.cpp index eb7be8d7a451..05dd9410e220 100644 --- a/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/catalog/catalog_entry/duck_table_entry.cpp @@ -72,7 +72,7 @@ IndexStorageInfo GetIndexInfo(const IndexConstraintType &constraint_type, unique } DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, BoundCreateTableInfo &info, - std::shared_ptr inherited_storage) + shared_ptr inherited_storage) : TableCatalogEntry(catalog, schema, info.Base()), storage(std::move(inherited_storage)), bound_constraints(std::move(info.bound_constraints)), column_dependency_manager(std::move(info.column_dependency_manager)) { diff --git a/src/common/http_state.cpp b/src/common/http_state.cpp index b07c0d4b31f2..9d583fd8f5f2 100644 --- a/src/common/http_state.cpp +++ b/src/common/http_state.cpp @@ -26,7 +26,7 @@ void CachedFileHandle::AllocateBuffer(idx_t size) { if (file->initialized) { throw InternalException("Cannot allocate a buffer for a cached file that was already initialized"); } - file->data = std::shared_ptr(new char[size], std::default_delete()); + file->data = shared_ptr(new char[size], std::default_delete()); file->capacity = size; } diff --git a/src/common/re2_regex.cpp b/src/common/re2_regex.cpp index f22b681acfe6..da239cd4a213 100644 --- a/src/common/re2_regex.cpp +++ b/src/common/re2_regex.cpp @@ -10,7 +10,7 @@ namespace duckdb_re2 { Regex::Regex(const std::string &pattern, RegexOptions options) { RE2::Options o; o.set_case_sensitive(options == RegexOptions::CASE_INSENSITIVE); - regex = std::make_shared(StringPiece(pattern), o); + regex = make_shared(StringPiece(pattern), o); } bool RegexSearchInternal(const char *input, Match &match, const Regex &r, RE2::Anchor anchor, size_t start, diff --git a/src/common/symbols.cpp b/src/common/symbols.cpp index 8b550f2c93f5..816a40b0387f 100644 --- a/src/common/symbols.cpp +++ b/src/common/symbols.cpp @@ -182,24 +182,24 @@ INSTANTIATE_VECTOR(vector>) INSTANTIATE_VECTOR(vector>) INSTANTIATE_VECTOR(vector>) INSTANTIATE_VECTOR(vector>) -INSTANTIATE_VECTOR(vector>) +INSTANTIATE_VECTOR(vector>) INSTANTIATE_VECTOR(vector>) -INSTANTIATE_VECTOR(vector>) -INSTANTIATE_VECTOR(vector>) -INSTANTIATE_VECTOR(vector>) +INSTANTIATE_VECTOR(vector>) +INSTANTIATE_VECTOR(vector>) +INSTANTIATE_VECTOR(vector>) INSTANTIATE_VECTOR(vector>) INSTANTIATE_VECTOR(vector>) -INSTANTIATE_VECTOR(vector>) +INSTANTIATE_VECTOR(vector>) INSTANTIATE_VECTOR(vector>) -template class std::shared_ptr; -template class std::shared_ptr; -template class std::shared_ptr; -template class std::shared_ptr; -template class std::shared_ptr; -template class std::shared_ptr; -template class std::shared_ptr; -template class std::weak_ptr; +template class shared_ptr; +template class shared_ptr; +template class shared_ptr; +template class shared_ptr; +template class shared_ptr; +template class shared_ptr; +template class shared_ptr; +template class weak_ptr; #if !defined(__clang__) template struct std::atomic; diff --git a/src/execution/physical_plan/plan_cte.cpp b/src/execution/physical_plan/plan_cte.cpp index f323aed463b4..7a306c3a54ca 100644 --- 
a/src/execution/physical_plan/plan_cte.cpp +++ b/src/execution/physical_plan/plan_cte.cpp @@ -12,7 +12,7 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalMaterializ D_ASSERT(op.children.size() == 2); // Create the working_table that the PhysicalCTE will use for evaluation. - auto working_table = std::make_shared(context, op.children[0]->types); + auto working_table = make_shared(context, op.children[0]->types); // Add the ColumnDataCollection to the context of this PhysicalPlanGenerator recursive_cte_tables[op.table_index] = working_table; diff --git a/src/execution/physical_plan/plan_recursive_cte.cpp b/src/execution/physical_plan/plan_recursive_cte.cpp index 82ddd9c21d3e..89da0cd0706e 100644 --- a/src/execution/physical_plan/plan_recursive_cte.cpp +++ b/src/execution/physical_plan/plan_recursive_cte.cpp @@ -12,7 +12,7 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalRecursiveC D_ASSERT(op.children.size() == 2); // Create the working_table that the PhysicalRecursiveCTE will use for evaluation. - auto working_table = std::make_shared(context, op.types); + auto working_table = make_shared(context, op.types); // Add the ColumnDataCollection to the context of this PhysicalPlanGenerator recursive_cte_tables[op.table_index] = working_table; diff --git a/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp b/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp index 0890ce829ab4..d8154f3a3615 100644 --- a/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp +++ b/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp @@ -17,7 +17,7 @@ class DuckTableEntry : public TableCatalogEntry { public: //! Create a TableCatalogEntry and initialize storage for it DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, BoundCreateTableInfo &info, - std::shared_ptr inherited_storage = nullptr); + shared_ptr inherited_storage = nullptr); public: unique_ptr AlterEntry(ClientContext &context, AlterInfo &info) override; @@ -64,7 +64,7 @@ class DuckTableEntry : public TableCatalogEntry { private: //! A reference to the underlying storage unit used for this table - std::shared_ptr storage; + shared_ptr storage; //! A list of constraints that are part of this table vector> bound_constraints; //! Manages dependencies of the individual columns of the table diff --git a/src/include/duckdb/common/helper.hpp b/src/include/duckdb/common/helper.hpp index b19b85f6d851..4c57e51a272b 100644 --- a/src/include/duckdb/common/helper.hpp +++ b/src/include/duckdb/common/helper.hpp @@ -65,6 +65,14 @@ make_uniq(ARGS&&... args) // NOLINT: mimic std style return unique_ptr, true>(new DATA_TYPE(std::forward(args)...)); } +template +inline +shared_ptr +make_shared(ARGS&&... 
args) // NOLINT: mimic std style +{ + return shared_ptr(new DATA_TYPE(std::forward(args)...)); +} + template inline typename TemplatedUniqueIf::templated_unique_single_t diff --git a/src/include/duckdb/common/re2_regex.hpp b/src/include/duckdb/common/re2_regex.hpp index de4b3313e362..ae5c48fd51f7 100644 --- a/src/include/duckdb/common/re2_regex.hpp +++ b/src/include/duckdb/common/re2_regex.hpp @@ -22,7 +22,7 @@ class Regex { } private: - std::shared_ptr regex; + shared_ptr regex; }; struct GroupMatch { diff --git a/src/include/duckdb/common/shared_ptr.hpp b/src/include/duckdb/common/shared_ptr.hpp index 4d97075eb8d3..615273d7c27e 100644 --- a/src/include/duckdb/common/shared_ptr.hpp +++ b/src/include/duckdb/common/shared_ptr.hpp @@ -9,11 +9,114 @@ #pragma once #include +#include + +template +class weak_ptr; namespace duckdb { -using std::make_shared; -using std::shared_ptr; -using std::weak_ptr; +template +class shared_ptr { +private: + template + friend class weak_ptr; + std::shared_ptr internal; + +public: + // Constructors + shared_ptr() : internal() { + } + shared_ptr(std::nullptr_t) : internal(nullptr) { + } // Implicit conversion + template + explicit shared_ptr(U *ptr) : internal(ptr) { + } + shared_ptr(const shared_ptr &other) : internal(other.internal) { + } + shared_ptr(std::shared_ptr other) : internal(std::move(other)) { + } + + // Destructor + ~shared_ptr() = default; + + // Assignment operators + shared_ptr &operator=(const shared_ptr &other) { + internal = other.internal; + return *this; + } + + // Modifiers + void reset() { + internal.reset(); + } + + template + void reset(U *ptr) { + internal.reset(ptr); + } + + template + void reset(U *ptr, Deleter deleter) { + internal.reset(ptr, deleter); + } + + // Observers + T *get() const { + return internal.get(); + } + + long use_count() const { + return internal.use_count(); + } + + explicit operator bool() const noexcept { + return internal.operator bool(); + } + + // Element access + std::__add_lvalue_reference_t operator*() const { + return *internal; + } + + T *operator->() const { + return internal.operator->(); + } + + // Relational operators + template + bool operator==(const shared_ptr &other) const noexcept { + return internal == other.internal; + } + + bool operator==(std::nullptr_t) const noexcept { + return internal == nullptr; + } + + template + bool operator!=(const shared_ptr &other) const noexcept { + return internal != other.internal; + } + + template + bool operator<(const shared_ptr &other) const noexcept { + return internal < other.internal; + } + + template + bool operator<=(const shared_ptr &other) const noexcept { + return internal <= other.internal; + } + + template + bool operator>(const shared_ptr &other) const noexcept { + return internal > other.internal; + } + + template + bool operator>=(const shared_ptr &other) const noexcept { + return internal >= other.internal; + } +}; } // namespace duckdb diff --git a/src/include/duckdb/common/types.hpp b/src/include/duckdb/common/types.hpp index bd5778ec402a..0151bc31a85f 100644 --- a/src/include/duckdb/common/types.hpp +++ b/src/include/duckdb/common/types.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/constants.hpp" #include "duckdb/common/optional_ptr.hpp" #include "duckdb/common/vector.hpp" +#include "duckdb/common/helper.hpp" #include diff --git a/src/include/duckdb/common/weak_ptr.hpp b/src/include/duckdb/common/weak_ptr.hpp new file mode 100644 index 000000000000..bf442e02ad6a --- /dev/null +++ b/src/include/duckdb/common/weak_ptr.hpp @@ -0,0 +1,88 @@ 
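// A minimal behavioral sketch (separate from the new header below) of the
// std::weak_ptr semantics the wrapper forwards to: expired() reports whether
// the observed object has been destroyed, while lock() either promotes to an
// owning pointer or yields null. Only the standard library is assumed.
#include <memory>
static bool ObservedAlive(const std::weak_ptr<int> &w) {
	// expired() can be stale under concurrency; lock() is the authoritative test
	auto locked = w.lock();
	return locked != nullptr;
}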
+#pragma once + +#include "duckdb/common/shared_ptr.hpp" +#include + +namespace duckdb { + +template +class weak_ptr { +private: + std::weak_ptr internal; + +public: + // Constructors + weak_ptr() : internal() { + } + template + weak_ptr(const shared_ptr &ptr) : internal(ptr.internal) { + } + weak_ptr(const weak_ptr &other) : internal(other.internal) { + } + + // Destructor + ~weak_ptr() = default; + + // Assignment operators + weak_ptr &operator=(const weak_ptr &other) { + internal = other.internal; + return *this; + } + + template + weak_ptr &operator=(const shared_ptr &ptr) { + internal = ptr; + return *this; + } + + // Modifiers + void reset() { + internal.reset(); + } + + // Observers + long use_count() const { + return internal.use_count(); + } + + bool expired() const { + return internal.expired(); + } + + shared_ptr lock() const { + return internal.lock(); + } + + // Relational operators + template + bool operator==(const weak_ptr &other) const noexcept { + return internal == other.internal; + } + + template + bool operator!=(const weak_ptr &other) const noexcept { + return internal != other.internal; + } + + template + bool operator<(const weak_ptr &other) const noexcept { + return internal < other.internal; + } + + template + bool operator<=(const weak_ptr &other) const noexcept { + return internal <= other.internal; + } + + template + bool operator>(const weak_ptr &other) const noexcept { + return internal > other.internal; + } + + template + bool operator>=(const weak_ptr &other) const noexcept { + return internal >= other.internal; + } +}; + +} // namespace duckdb diff --git a/src/include/duckdb/execution/operator/set/physical_cte.hpp b/src/include/duckdb/execution/operator/set/physical_cte.hpp index 3ff5fb8cf3e1..fc006d4699d6 100644 --- a/src/include/duckdb/execution/operator/set/physical_cte.hpp +++ b/src/include/duckdb/execution/operator/set/physical_cte.hpp @@ -26,7 +26,7 @@ class PhysicalCTE : public PhysicalOperator { vector> cte_scans; - std::shared_ptr working_table; + shared_ptr working_table; idx_t table_index; string ctename; diff --git a/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp b/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp index 88071368eb47..61d82fa09f97 100644 --- a/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp +++ b/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp @@ -29,7 +29,7 @@ class PhysicalRecursiveCTE : public PhysicalOperator { idx_t table_index; bool union_all; - std::shared_ptr working_table; + shared_ptr working_table; shared_ptr recursive_meta_pipeline; public: diff --git a/src/include/duckdb/execution/physical_plan_generator.hpp b/src/include/duckdb/execution/physical_plan_generator.hpp index 8024bc50ed09..24715b143a40 100644 --- a/src/include/duckdb/execution/physical_plan_generator.hpp +++ b/src/include/duckdb/execution/physical_plan_generator.hpp @@ -31,7 +31,7 @@ class PhysicalPlanGenerator { LogicalDependencyList dependencies; //! Recursive CTEs require at least one ChunkScan, referencing the working_table. //! This data structure is used to establish it. - unordered_map> recursive_cte_tables; + unordered_map> recursive_cte_tables; //! Materialized CTE ids must be collected. 
diff --git a/src/include/duckdb/execution/physical_plan_generator.hpp b/src/include/duckdb/execution/physical_plan_generator.hpp
index 8024bc50ed09..24715b143a40 100644
--- a/src/include/duckdb/execution/physical_plan_generator.hpp
+++ b/src/include/duckdb/execution/physical_plan_generator.hpp
@@ -31,7 +31,7 @@ class PhysicalPlanGenerator {
 	LogicalDependencyList dependencies;
 	//! Recursive CTEs require at least one ChunkScan, referencing the working_table.
 	//! This data structure is used to establish it.
-	unordered_map<idx_t, std::shared_ptr<ColumnDataCollection>> recursive_cte_tables;
+	unordered_map<idx_t, shared_ptr<ColumnDataCollection>> recursive_cte_tables;
 	//! Materialized CTE ids must be collected.
 	unordered_map<idx_t, vector<const_reference<LogicalOperator>>> materialized_ctes;
diff --git a/src/include/duckdb/main/buffered_data/buffered_data.hpp b/src/include/duckdb/main/buffered_data/buffered_data.hpp
index d831736d69ac..8065fbee2c73 100644
--- a/src/include/duckdb/main/buffered_data/buffered_data.hpp
+++ b/src/include/duckdb/main/buffered_data/buffered_data.hpp
@@ -15,6 +15,7 @@
 #include "duckdb/common/optional_idx.hpp"
 #include "duckdb/execution/physical_operator_states.hpp"
 #include "duckdb/common/enums/pending_execution_result.hpp"
+#include "duckdb/common/weak_ptr.hpp"
 
 namespace duckdb {
 
diff --git a/src/include/duckdb/main/client_context.hpp b/src/include/duckdb/main/client_context.hpp
index 91aad1307c80..9f608273b668 100644
--- a/src/include/duckdb/main/client_context.hpp
+++ b/src/include/duckdb/main/client_context.hpp
@@ -304,7 +304,7 @@ class ClientContextWrapper {
 	}
 
 private:
-	std::weak_ptr<ClientContext> client_context;
+	weak_ptr<ClientContext> client_context;
 };
 
 } // namespace duckdb
diff --git a/src/include/duckdb/main/relation.hpp b/src/include/duckdb/main/relation.hpp
index 7d1798712975..c494366208cb 100644
--- a/src/include/duckdb/main/relation.hpp
+++ b/src/include/duckdb/main/relation.hpp
@@ -36,7 +36,7 @@ class TableRef;
 
 class Relation : public std::enable_shared_from_this<Relation> {
 public:
-	Relation(const std::shared_ptr<ClientContext> &context, RelationType type) : context(context), type(type) {
+	Relation(const shared_ptr<ClientContext> &context, RelationType type) : context(context), type(type) {
 	}
 	Relation(ClientContextWrapper &context, RelationType type) : context(context.GetContext()), type(type) {
 	}
diff --git a/src/include/duckdb/main/relation/query_relation.hpp b/src/include/duckdb/main/relation/query_relation.hpp
index 67cfcc063b27..f35a8133cd73 100644
--- a/src/include/duckdb/main/relation/query_relation.hpp
+++ b/src/include/duckdb/main/relation/query_relation.hpp
@@ -16,7 +16,7 @@ class SelectStatement;
 
 class QueryRelation : public Relation {
 public:
-	QueryRelation(const std::shared_ptr<ClientContext> &context, unique_ptr<SelectStatement> select_stmt, string alias);
+	QueryRelation(const shared_ptr<ClientContext> &context, unique_ptr<SelectStatement> select_stmt, string alias);
 	~QueryRelation() override;
 
 	unique_ptr<SelectStatement> select_stmt;
diff --git a/src/include/duckdb/main/relation/table_function_relation.hpp b/src/include/duckdb/main/relation/table_function_relation.hpp
index 394e92ba716b..9d605c68f74f 100644
--- a/src/include/duckdb/main/relation/table_function_relation.hpp
+++ b/src/include/duckdb/main/relation/table_function_relation.hpp
@@ -14,11 +14,11 @@ namespace duckdb {
 
 class TableFunctionRelation : public Relation {
 public:
-	TableFunctionRelation(const std::shared_ptr<ClientContext> &context, string name, vector<Value> parameters,
+	TableFunctionRelation(const shared_ptr<ClientContext> &context, string name, vector<Value> parameters,
 	                      named_parameter_map_t named_parameters,
 	                      shared_ptr<Relation> input_relation_p = nullptr, bool auto_init = true);
-	TableFunctionRelation(const std::shared_ptr<ClientContext> &context, string name, vector<Value> parameters,
+	TableFunctionRelation(const shared_ptr<ClientContext> &context, string name, vector<Value> parameters,
 	                      shared_ptr<Relation> input_relation_p = nullptr, bool auto_init = true);
 	~TableFunctionRelation() override {
 	}
diff --git a/src/include/duckdb/main/relation/table_relation.hpp b/src/include/duckdb/main/relation/table_relation.hpp
index 77a950cebe89..a14ce054ba6d 100644
--- a/src/include/duckdb/main/relation/table_relation.hpp
+++ b/src/include/duckdb/main/relation/table_relation.hpp
@@ -15,7 +15,7 @@ namespace duckdb {
 
 class TableRelation : public Relation {
 public:
-	TableRelation(const std::shared_ptr<ClientContext> &context, unique_ptr<TableDescription> description);
+	TableRelation(const shared_ptr<ClientContext> &context, unique_ptr<TableDescription> description);
 
 	unique_ptr<TableDescription> description;
 
diff --git a/src/include/duckdb/main/relation/value_relation.hpp b/src/include/duckdb/main/relation/value_relation.hpp
index b8aa47c01a48..81fb8c4e1ce7 100644
--- a/src/include/duckdb/main/relation/value_relation.hpp
+++ b/src/include/duckdb/main/relation/value_relation.hpp
@@ -15,9 +15,9 @@ namespace duckdb {
 
 class ValueRelation : public Relation {
 public:
-	ValueRelation(const std::shared_ptr<ClientContext> &context, const vector<vector<Value>> &values,
-	              vector<string> names, string alias = "values");
-	ValueRelation(const std::shared_ptr<ClientContext> &context, const string &values, vector<string> names,
+	ValueRelation(const shared_ptr<ClientContext> &context, const vector<vector<Value>> &values, vector<string> names,
+	              string alias = "values");
+	ValueRelation(const shared_ptr<ClientContext> &context, const string &values, vector<string> names,
 	              string alias = "values");
 
 	vector<vector<unique_ptr<ParsedExpression>>> expressions;
diff --git a/src/include/duckdb/main/relation/view_relation.hpp b/src/include/duckdb/main/relation/view_relation.hpp
index 8a8afa26071e..75f63910a520 100644
--- a/src/include/duckdb/main/relation/view_relation.hpp
+++ b/src/include/duckdb/main/relation/view_relation.hpp
@@ -14,7 +14,7 @@ namespace duckdb {
 
 class ViewRelation : public Relation {
 public:
-	ViewRelation(const std::shared_ptr<ClientContext> &context, string schema_name, string view_name);
+	ViewRelation(const shared_ptr<ClientContext> &context, string schema_name, string view_name);
 
 	string schema_name;
 	string view_name;
diff --git a/src/include/duckdb/planner/bind_context.hpp b/src/include/duckdb/planner/bind_context.hpp
index cab1479db700..e4b63f830747 100644
--- a/src/include/duckdb/planner/bind_context.hpp
+++ b/src/include/duckdb/planner/bind_context.hpp
@@ -41,7 +41,7 @@ class BindContext {
 	explicit BindContext(Binder &binder);
 
 	//! Keep track of recursive CTE references
-	case_insensitive_map_t<std::shared_ptr<idx_t>> cte_references;
+	case_insensitive_map_t<shared_ptr<idx_t>> cte_references;
 
 public:
 	//! Given a column name, find the matching table it belongs to. Throws an
@@ -129,10 +129,10 @@ class BindContext {
 	//! (e.g. "column_name" might return "COLUMN_NAME")
 	string GetActualColumnName(const string &binding, const string &column_name);
 
-	case_insensitive_map_t<std::shared_ptr<Binding>> GetCTEBindings() {
+	case_insensitive_map_t<shared_ptr<Binding>> GetCTEBindings() {
 		return cte_bindings;
 	}
-	void SetCTEBindings(case_insensitive_map_t<std::shared_ptr<Binding>> bindings) {
+	void SetCTEBindings(case_insensitive_map_t<shared_ptr<Binding>> bindings) {
 		cte_bindings = std::move(bindings);
 	}
 
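Why `shared_ptr<idx_t>` for `cte_references`: several binder scopes can hold the same counter, so a reference registered through a copy of the map is visible to every other holder. A minimal sketch of that sharing, using a plain `std::unordered_map` and a local `idx_t` alias as stand-ins for DuckDB's `case_insensitive_map_t` and `idx_t` (the real code, shown later in this patch, populates the map with `make_shared<idx_t>(0)`):

```cpp
#include "duckdb/common/shared_ptr.hpp"

#include <cstdio>
#include <string>
#include <unordered_map>

using idx_t = unsigned long long; // stand-in for duckdb's idx_t

int main() {
	std::unordered_map<std::string, duckdb::shared_ptr<idx_t>> cte_references;
	cte_references["my_cte"] = duckdb::shared_ptr<idx_t>(new idx_t(0));

	// A copied entry (e.g. handed to a nested binder) still points at the
	// same counter, so increments are visible to both holders.
	auto alias_count = cte_references["my_cte"];
	(*alias_count)++;

	std::printf("references seen: %llu\n", *cte_references["my_cte"]); // 1
	return 0;
}
```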
@@ -165,6 +165,6 @@ class BindContext {
 	vector<unique_ptr<UsingColumnSet>> using_column_sets;
 	//! The set of CTE bindings
-	case_insensitive_map_t<std::shared_ptr<Binding>> cte_bindings;
+	case_insensitive_map_t<shared_ptr<Binding>> cte_bindings;
 };
 } // namespace duckdb
diff --git a/src/include/duckdb/transaction/transaction.hpp b/src/include/duckdb/transaction/transaction.hpp
index 1c47725c10dd..dff31db5701c 100644
--- a/src/include/duckdb/transaction/transaction.hpp
+++ b/src/include/duckdb/transaction/transaction.hpp
@@ -13,6 +13,7 @@
 #include "duckdb/transaction/undo_buffer.hpp"
 #include "duckdb/common/atomic.hpp"
 #include "duckdb/transaction/transaction_data.hpp"
+#include "duckdb/common/weak_ptr.hpp"
 
 namespace duckdb {
 class SequenceCatalogEntry;
diff --git a/src/main/relation.cpp b/src/main/relation.cpp
index 970ab4a58add..87be11809ff4 100644
--- a/src/main/relation.cpp
+++ b/src/main/relation.cpp
@@ -277,7 +277,7 @@ void Relation::Create(const string &schema_name, const string &table_name) {
 }
 
 shared_ptr<Relation> Relation::WriteCSVRel(const string &csv_file, case_insensitive_map_t<vector<Value>> options) {
-	return std::make_shared<WriteCSVRelation>(shared_from_this(), csv_file, std::move(options));
+	return make_shared<WriteCSVRelation>(shared_from_this(), csv_file, std::move(options));
 }
 
 void Relation::WriteCSV(const string &csv_file, case_insensitive_map_t<vector<Value>> options) {
@@ -292,7 +292,7 @@ void Relation::WriteCSV(const string &csv_file, case_insensitive_map_t<vector<Va
 shared_ptr<Relation> Relation::WriteParquetRel(const string &parquet_file,
                                                case_insensitive_map_t<vector<Value>> options) {
 	auto write_parquet =
-	    std::make_shared<WriteParquetRelation>(shared_from_this(), parquet_file, std::move(options));
+	    make_shared<WriteParquetRelation>(shared_from_this(), parquet_file, std::move(options));
 	return std::move(write_parquet);
 }
 
diff --git a/src/main/relation/query_relation.cpp b/src/main/relation/query_relation.cpp
index f6421601d4a0..0ce867f51574 100644
--- a/src/main/relation/query_relation.cpp
+++ b/src/main/relation/query_relation.cpp
@@ -6,7 +6,7 @@
 
 namespace duckdb {
 
-QueryRelation::QueryRelation(const std::shared_ptr<ClientContext> &context, unique_ptr<SelectStatement> select_stmt_p,
+QueryRelation::QueryRelation(const shared_ptr<ClientContext> &context, unique_ptr<SelectStatement> select_stmt_p,
                              string alias_p)
     : Relation(context, RelationType::QUERY_RELATION), select_stmt(std::move(select_stmt_p)),
       alias(std::move(alias_p)) {
diff --git a/src/main/relation/read_csv_relation.cpp b/src/main/relation/read_csv_relation.cpp
index 1500720e0069..1529de9a7637 100644
--- a/src/main/relation/read_csv_relation.cpp
+++ b/src/main/relation/read_csv_relation.cpp
@@ -30,7 +30,7 @@ static Value CreateValueFromFileList(const vector<string> &file_list) {
 	return Value::LIST(std::move(files));
 }
 
-ReadCSVRelation::ReadCSVRelation(const std::shared_ptr<ClientContext> &context, const vector<string> &input,
+ReadCSVRelation::ReadCSVRelation(const shared_ptr<ClientContext> &context, const vector<string> &input,
                                  named_parameter_map_t &&options, string alias_p)
     : TableFunctionRelation(context, "read_csv_auto", {CreateValueFromFileList(input)}, nullptr, false),
       alias(std::move(alias_p)) {
diff --git a/src/main/relation/table_relation.cpp b/src/main/relation/table_relation.cpp
index b4954e3d4fc2..2cdc0d9d945a 100644
--- a/src/main/relation/table_relation.cpp
+++ b/src/main/relation/table_relation.cpp
@@ -9,7 +9,7 @@
 
 namespace duckdb {
 
-TableRelation::TableRelation(const std::shared_ptr<ClientContext> &context, unique_ptr<TableDescription> description)
+TableRelation::TableRelation(const shared_ptr<ClientContext> &context, unique_ptr<TableDescription> description)
     : Relation(context, RelationType::TABLE_RELATION), description(std::move(description)) {
 }
 
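The `make_shared<WriteCSVRelation>(shared_from_this(), ...)` calls above only work because `Relation` keeps deriving from `std::enable_shared_from_this<Relation>` while the surrounding code now traffics in `duckdb::shared_ptr`; a later patch in this series adds `src/include/duckdb/common/enable_shared_from_this.ipp` to close that gap. That file's contents are not shown here, so the following is only one plausible shape of such a shim, not the patch's actual code: delegate to the std machinery and let the implicit converting constructor rewrap the result.

```cpp
#include "duckdb/common/shared_ptr.hpp"
#include "duckdb/common/weak_ptr.hpp"

#include <memory>

namespace duckdb {

// Hypothetical sketch: works because duckdb::shared_ptr wraps a
// std::shared_ptr internally, so adopting a raw pointer still triggers the
// std::enable_shared_from_this bookkeeping in the base class.
template <typename T>
class enable_shared_from_this : public std::enable_shared_from_this<T> {
public:
	shared_ptr<T> shared_from_this() {
		// std::shared_ptr<T> converts implicitly to duckdb::shared_ptr<T>.
		return std::enable_shared_from_this<T>::shared_from_this();
	}
	weak_ptr<T> weak_from_this() {
		return weak_ptr<T>(shared_from_this());
	}
};

} // namespace duckdb
```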
diff --git a/src/main/relation/value_relation.cpp b/src/main/relation/value_relation.cpp
index fe611700b42a..3e11ed6bbcdb 100644
--- a/src/main/relation/value_relation.cpp
+++ b/src/main/relation/value_relation.cpp
@@ -8,7 +8,7 @@
 
 namespace duckdb {
 
-ValueRelation::ValueRelation(const std::shared_ptr<ClientContext> &context, const vector<vector<Value>> &values,
+ValueRelation::ValueRelation(const shared_ptr<ClientContext> &context, const vector<vector<Value>> &values,
                              vector<string> names_p, string alias_p)
     : Relation(context, RelationType::VALUE_LIST_RELATION), names(std::move(names_p)), alias(std::move(alias_p)) {
 	// create constant expressions for the values
@@ -24,7 +24,7 @@ ValueRelation::ValueRelation(const std::shared_ptr<ClientContext> &context, cons
 	context->TryBindRelation(*this, this->columns);
 }
 
-ValueRelation::ValueRelation(const std::shared_ptr<ClientContext> &context, const string &values_list,
+ValueRelation::ValueRelation(const shared_ptr<ClientContext> &context, const string &values_list,
                              vector<string> names_p, string alias_p)
     : Relation(context, RelationType::VALUE_LIST_RELATION), names(std::move(names_p)), alias(std::move(alias_p)) {
 	this->expressions = Parser::ParseValuesList(values_list, context->GetParserOptions());
diff --git a/src/main/relation/view_relation.cpp b/src/main/relation/view_relation.cpp
index b432f5d3770a..f6f1806727a8 100644
--- a/src/main/relation/view_relation.cpp
+++ b/src/main/relation/view_relation.cpp
@@ -7,7 +7,7 @@
 
 namespace duckdb {
 
-ViewRelation::ViewRelation(const std::shared_ptr<ClientContext> &context, string schema_name_p, string view_name_p)
+ViewRelation::ViewRelation(const shared_ptr<ClientContext> &context, string schema_name_p, string view_name_p)
     : Relation(context, RelationType::VIEW_RELATION), schema_name(std::move(schema_name_p)),
       view_name(std::move(view_name_p)) {
 	context->TryBindRelation(*this, this->columns);
diff --git a/src/planner/bind_context.cpp b/src/planner/bind_context.cpp
index 6323ebad9b9e..eac1d69c2bd6 100644
--- a/src/planner/bind_context.cpp
+++ b/src/planner/bind_context.cpp
@@ -520,7 +520,7 @@ void BindContext::AddCTEBinding(idx_t index, const string &alias, const vector<string> &names,
-	cte_references[alias] = std::make_shared<idx_t>(0);
+	cte_references[alias] = make_shared<idx_t>(0);
 }
 
 void BindContext::AddContext(BindContext other) {
diff --git a/src/planner/binder.cpp b/src/planner/binder.cpp
index 8239fd292664..75e2a0482364 100644
--- a/src/planner/binder.cpp
+++ b/src/planner/binder.cpp
@@ -16,6 +16,7 @@
 #include "duckdb/planner/operator/logical_projection.hpp"
 #include "duckdb/planner/operator/logical_sample.hpp"
 #include "duckdb/parser/query_node/list.hpp"
+#include "duckdb/common/helper.hpp"
 
 #include 

From 0cbaa36b84fffed565372c1faac03d229d885c23 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 5 Apr 2024 17:38:35 +0200
Subject: [PATCH 053/201] changes to get duckdb::shared_ptr working, almost in
 a functional state

---
 scripts/format.py                               |   2 +
 .../catalog_entry/duck_table_entry.cpp          |  15 +-
 src/common/allocator.cpp                        |   2 +-
 src/common/arrow/arrow_wrapper.cpp              |   2 +-
 src/common/extra_type_info.cpp                  |   8 +-
 src/common/http_state.cpp                       |   6 +-
 src/common/re2_regex.cpp                        |   2 +-
 src/common/symbols.cpp                          |  34 ++--
 src/common/types.cpp                            |  28 ++--
 .../types/column/column_data_collection.cpp     |  14 +-
 .../column/column_data_collection_segment.cpp   |   2 +-
 .../types/column/partitioned_column_data.cpp    |   4 +-
 .../types/row/partitioned_tuple_data.cpp        |   4 +-
 .../types/row/tuple_data_collection.cpp         |   4 +-
 src/common/types/value.cpp                      |  36 ++---
 src/common/types/vector_cache.cpp               |   6 +-
 src/execution/aggregate_hashtable.cpp           |   2 +-
 src/execution/index/art/art.cpp                 |   3 +-
 .../operator/aggregate/aggregate_object.cpp     |   2 +-
 .../aggregate/physical_hash_aggregate.cpp       |   4 +-
 .../physical_ungrouped_aggregate.cpp            |   2 +-
 .../operator/aggregate/physical_window.cpp      |   2 +-
 .../csv_scanner/buffer_manager/csv_buffer.cpp   |   8 +-
 .../buffer_manager/csv_buffer_manager.cpp       |   2 +-
.../scanner/string_value_scanner.cpp | 15 +- .../csv_scanner/sniffer/csv_sniffer.cpp | 4 +- .../table_function/csv_file_scanner.cpp | 20 +-- .../table_function/global_csv_state.cpp | 18 +-- .../helper/physical_buffered_collector.cpp | 2 +- .../operator/join/physical_asof_join.cpp | 2 +- .../operator/join/physical_hash_join.cpp | 4 +- .../operator/join/physical_range_join.cpp | 2 +- .../operator/order/physical_order.cpp | 2 +- .../physical_batch_copy_to_file.cpp | 2 +- .../persistent/physical_copy_to_file.cpp | 2 +- .../schema/physical_create_art_index.cpp | 2 +- .../operator/set/physical_recursive_cte.cpp | 2 +- src/execution/physical_plan/plan_cte.cpp | 2 +- .../physical_plan/plan_recursive_cte.cpp | 2 +- src/function/table/copy_csv.cpp | 2 +- src/function/table/read_csv.cpp | 2 +- src/function/table/sniff_csv.cpp | 2 +- .../duckdb/common/enable_shared_from_this.ipp | 40 +++++ src/include/duckdb/common/exception.hpp | 1 - src/include/duckdb/common/helper.hpp | 9 +- src/include/duckdb/common/http_state.hpp | 2 +- .../duckdb/common/multi_file_reader.hpp | 2 +- src/include/duckdb/common/re2_regex.hpp | 5 +- src/include/duckdb/common/shared_ptr.hpp | 121 ++------------ src/include/duckdb/common/shared_ptr.ipp | 150 ++++++++++++++++++ src/include/duckdb/common/types.hpp | 2 +- .../duckdb/common/types/selection_vector.hpp | 2 +- src/include/duckdb/common/unique_ptr.hpp | 12 +- .../common/{weak_ptr.hpp => weak_ptr.ipp} | 12 +- .../main/buffered_data/buffered_data.hpp | 2 +- src/include/duckdb/main/client_context.hpp | 2 +- src/include/duckdb/main/database.hpp | 2 +- src/include/duckdb/main/relation.hpp | 2 +- src/include/duckdb/parallel/event.hpp | 2 +- src/include/duckdb/parallel/interrupt.hpp | 1 + src/include/duckdb/parallel/meta_pipeline.hpp | 2 +- src/include/duckdb/parallel/pipeline.hpp | 2 +- src/include/duckdb/parallel/task.hpp | 2 +- src/include/duckdb/planner/binder.hpp | 4 +- src/include/duckdb/storage/object_cache.hpp | 6 +- .../duckdb/storage/serialization/types.json | 2 +- .../duckdb/transaction/local_storage.hpp | 2 +- .../duckdb/transaction/transaction.hpp | 2 +- src/main/capi/table_function-c.cpp | 2 +- src/main/client_context.cpp | 2 +- src/main/client_data.cpp | 4 +- src/main/connection.cpp | 21 +-- src/main/database.cpp | 4 +- src/main/db_instance_cache.cpp | 2 +- src/main/relation.cpp | 58 +++---- src/main/relation/read_csv_relation.cpp | 2 +- src/main/relation/table_relation.cpp | 6 +- src/parallel/executor.cpp | 17 +- src/parallel/meta_pipeline.cpp | 4 +- src/planner/bind_context.cpp | 4 +- src/planner/binder.cpp | 2 +- src/planner/bound_parameter_map.cpp | 2 +- src/planner/planner.cpp | 2 +- src/storage/buffer/block_manager.cpp | 2 +- src/storage/checkpoint_manager.cpp | 2 +- src/storage/data_table.cpp | 4 +- src/storage/local_storage.cpp | 14 +- src/storage/serialization/serialize_types.cpp | 2 +- src/storage/standard_buffer_manager.cpp | 8 +- src/storage/statistics/column_statistics.cpp | 6 +- src/storage/table/row_group.cpp | 2 +- src/storage/table/row_group_collection.cpp | 8 +- src/storage/table/row_version_manager.cpp | 2 +- src/storage/wal_replay.cpp | 2 +- 94 files changed, 483 insertions(+), 373 deletions(-) create mode 100644 src/include/duckdb/common/enable_shared_from_this.ipp create mode 100644 src/include/duckdb/common/shared_ptr.ipp rename src/include/duckdb/common/{weak_ptr.hpp => weak_ptr.ipp} (87%) diff --git a/scripts/format.py b/scripts/format.py index 40cfbb95903b..e053a021f0e4 100644 --- a/scripts/format.py +++ b/scripts/format.py @@ -41,6 +41,7 @@ 
extensions = [ '.cpp', + '.ipp', '.c', '.hpp', '.h', @@ -240,6 +241,7 @@ def get_changed_files(revision): format_commands = { '.cpp': cpp_format_command, + '.ipp': cpp_format_command, '.c': cpp_format_command, '.hpp': cpp_format_command, '.h': cpp_format_command, diff --git a/src/catalog/catalog_entry/duck_table_entry.cpp b/src/catalog/catalog_entry/duck_table_entry.cpp index 05dd9410e220..53ca0cbba0c3 100644 --- a/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/catalog/catalog_entry/duck_table_entry.cpp @@ -83,8 +83,9 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou for (auto &col_def : columns.Physical()) { storage_columns.push_back(col_def.Copy()); } - storage = make_shared(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), - schema.name, name, std::move(storage_columns), std::move(info.data)); + storage = + make_refcounted(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), + schema.name, name, std::move(storage_columns), std::move(info.data)); // create the unique indexes for the UNIQUE and PRIMARY KEY and FOREIGN KEY constraints idx_t indexes_idx = 0; @@ -344,7 +345,7 @@ unique_ptr DuckTableEntry::AddColumn(ClientContext &context, AddCo auto binder = Binder::CreateBinder(context); auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); auto new_storage = - make_shared(context, *storage, info.new_column, *bound_create_info->bound_defaults.back()); + make_refcounted(context, *storage, info.new_column, *bound_create_info->bound_defaults.back()); return make_uniq(catalog, schema, *bound_create_info, new_storage); } @@ -480,7 +481,7 @@ unique_ptr DuckTableEntry::RemoveColumn(ClientContext &context, Re return make_uniq(catalog, schema, *bound_create_info, storage); } auto new_storage = - make_shared(context, *storage, columns.LogicalToPhysical(LogicalIndex(removed_index)).index); + make_refcounted(context, *storage, columns.LogicalToPhysical(LogicalIndex(removed_index)).index); return make_uniq(catalog, schema, *bound_create_info, new_storage); } @@ -548,7 +549,7 @@ unique_ptr DuckTableEntry::SetNotNull(ClientContext &context, SetN } // Return with new storage info. Note that we need the bound column index here. 
- auto new_storage = make_shared( + auto new_storage = make_refcounted( context, *storage, make_uniq(columns.LogicalToPhysical(LogicalIndex(not_null_idx)))); return make_uniq(catalog, schema, *bound_create_info, new_storage); } @@ -659,8 +660,8 @@ unique_ptr DuckTableEntry::ChangeColumnType(ClientContext &context } auto new_storage = - make_shared(context, *storage, columns.LogicalToPhysical(LogicalIndex(change_idx)).index, - info.target_type, std::move(storage_oids), *bound_expression); + make_refcounted(context, *storage, columns.LogicalToPhysical(LogicalIndex(change_idx)).index, + info.target_type, std::move(storage_oids), *bound_expression); auto result = make_uniq(catalog, schema, *bound_create_info, new_storage); return std::move(result); } diff --git a/src/common/allocator.cpp b/src/common/allocator.cpp index 5487578f0258..835aba386ff2 100644 --- a/src/common/allocator.cpp +++ b/src/common/allocator.cpp @@ -195,7 +195,7 @@ data_ptr_t Allocator::DefaultReallocate(PrivateAllocatorData *private_data, data } shared_ptr &Allocator::DefaultAllocatorReference() { - static shared_ptr DEFAULT_ALLOCATOR = make_shared(); + static shared_ptr DEFAULT_ALLOCATOR = make_refcounted(); return DEFAULT_ALLOCATOR; } diff --git a/src/common/arrow/arrow_wrapper.cpp b/src/common/arrow/arrow_wrapper.cpp index d439d99079b1..1bcc48ce91b6 100644 --- a/src/common/arrow/arrow_wrapper.cpp +++ b/src/common/arrow/arrow_wrapper.cpp @@ -50,7 +50,7 @@ void ArrowArrayStreamWrapper::GetSchema(ArrowSchemaWrapper &schema) { } shared_ptr ArrowArrayStreamWrapper::GetNextChunk() { - auto current_chunk = make_shared(); + auto current_chunk = make_refcounted(); if (arrow_array_stream.get_next(&arrow_array_stream, ¤t_chunk->arrow_array)) { // LCOV_EXCL_START throw InvalidInputException("arrow_scan: get_next failed(): %s", string(GetError())); } // LCOV_EXCL_STOP diff --git a/src/common/extra_type_info.cpp b/src/common/extra_type_info.cpp index f8d27d86e8f0..c5c3ca1b6b8f 100644 --- a/src/common/extra_type_info.cpp +++ b/src/common/extra_type_info.cpp @@ -190,7 +190,7 @@ struct EnumTypeInfoTemplated : public EnumTypeInfo { deserializer.ReadList(201, "values", [&](Deserializer::List &list, idx_t i) { strings[i] = StringVector::AddStringOrBlob(values_insert_order, list.ReadElement()); }); - return make_shared(values_insert_order, size); + return make_refcounted(values_insert_order, size); } const string_map_t &GetValues() const { @@ -227,13 +227,13 @@ LogicalType EnumTypeInfo::CreateType(Vector &ordered_data, idx_t size) { auto enum_internal_type = EnumTypeInfo::DictType(size); switch (enum_internal_type) { case PhysicalType::UINT8: - info = make_shared>(ordered_data, size); + info = make_refcounted>(ordered_data, size); break; case PhysicalType::UINT16: - info = make_shared>(ordered_data, size); + info = make_refcounted>(ordered_data, size); break; case PhysicalType::UINT32: - info = make_shared>(ordered_data, size); + info = make_refcounted>(ordered_data, size); break; default: throw InternalException("Invalid Physical Type for ENUMs"); diff --git a/src/common/http_state.cpp b/src/common/http_state.cpp index 9d583fd8f5f2..880454b87a75 100644 --- a/src/common/http_state.cpp +++ b/src/common/http_state.cpp @@ -62,14 +62,14 @@ shared_ptr HTTPState::TryGetState(ClientContext &context, bool create auto lookup = context.registered_state.find("http_state"); if (lookup != context.registered_state.end()) { - return std::static_pointer_cast(lookup->second); + return shared_ptr_cast(lookup->second); } if (!create_on_missing) { return 
nullptr; } - auto http_state = make_shared(); + auto http_state = make_refcounted(); context.registered_state["http_state"] = http_state; return http_state; } @@ -87,7 +87,7 @@ shared_ptr &HTTPState::GetCachedFile(const string &path) { lock_guard lock(cached_files_mutex); auto &cache_entry_ref = cached_files[path]; if (!cache_entry_ref) { - cache_entry_ref = make_shared(); + cache_entry_ref = make_refcounted(); } return cache_entry_ref; } diff --git a/src/common/re2_regex.cpp b/src/common/re2_regex.cpp index da239cd4a213..4b3e2fb8e87b 100644 --- a/src/common/re2_regex.cpp +++ b/src/common/re2_regex.cpp @@ -10,7 +10,7 @@ namespace duckdb_re2 { Regex::Regex(const std::string &pattern, RegexOptions options) { RE2::Options o; o.set_case_sensitive(options == RegexOptions::CASE_INSENSITIVE); - regex = make_shared(StringPiece(pattern), o); + regex = duckdb::make_refcounted(StringPiece(pattern), o); } bool RegexSearchInternal(const char *input, Match &match, const Regex &r, RE2::Anchor anchor, size_t start, diff --git a/src/common/symbols.cpp b/src/common/symbols.cpp index 816a40b0387f..ee16effeac43 100644 --- a/src/common/symbols.cpp +++ b/src/common/symbols.cpp @@ -147,6 +147,15 @@ template class unique_ptr; template class unique_ptr; template class unique_ptr; +template class shared_ptr; +template class shared_ptr; +template class shared_ptr; +template class shared_ptr; +template class shared_ptr; +template class shared_ptr; +template class shared_ptr; +template class weak_ptr; + } // namespace duckdb #define INSTANTIATE_VECTOR(VECTOR_DEFINITION) \ @@ -160,15 +169,6 @@ template class unique_ptr; template std::VECTOR_DEFINITION::const_reference std::VECTOR_DEFINITION::front() const; \ template std::VECTOR_DEFINITION::reference std::VECTOR_DEFINITION::front(); -template class duckdb::vector; -template class duckdb::vector; -template class duckdb::vector; -template class duckdb::vector; -template class duckdb::vector; -template class duckdb::vector; -template class duckdb::vector>; -template class duckdb::vector; - INSTANTIATE_VECTOR(vector) INSTANTIATE_VECTOR(vector) INSTANTIATE_VECTOR(vector) @@ -192,14 +192,14 @@ INSTANTIATE_VECTOR(vector>) INSTANTIATE_VECTOR(vector>) INSTANTIATE_VECTOR(vector>) -template class shared_ptr; -template class shared_ptr; -template class shared_ptr; -template class shared_ptr; -template class shared_ptr; -template class shared_ptr; -template class shared_ptr; -template class weak_ptr; +template class duckdb::vector; +template class duckdb::vector; +template class duckdb::vector; +template class duckdb::vector; +template class duckdb::vector; +template class duckdb::vector; +template class duckdb::vector>; +template class duckdb::vector; #if !defined(__clang__) template struct std::atomic; diff --git a/src/common/types.cpp b/src/common/types.cpp index e4318c26ed22..f70d69d8102a 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -1127,7 +1127,7 @@ bool ApproxEqual(double ldecimal, double rdecimal) { //===--------------------------------------------------------------------===// void LogicalType::SetAlias(string alias) { if (!type_info_) { - type_info_ = make_shared(ExtraTypeInfoType::GENERIC_TYPE_INFO, std::move(alias)); + type_info_ = make_refcounted(ExtraTypeInfoType::GENERIC_TYPE_INFO, std::move(alias)); } else { type_info_->alias = std::move(alias); } @@ -1176,7 +1176,7 @@ uint8_t DecimalType::MaxWidth() { LogicalType LogicalType::DECIMAL(uint8_t width, uint8_t scale) { D_ASSERT(width >= scale); - auto type_info = make_shared(width, scale); + auto 
type_info = make_refcounted(width, scale); return LogicalType(LogicalTypeId::DECIMAL, std::move(type_info)); } @@ -1198,7 +1198,7 @@ string StringType::GetCollation(const LogicalType &type) { } LogicalType LogicalType::VARCHAR_COLLATION(string collation) { // NOLINT - auto string_info = make_shared(std::move(collation)); + auto string_info = make_refcounted(std::move(collation)); return LogicalType(LogicalTypeId::VARCHAR, std::move(string_info)); } @@ -1213,7 +1213,7 @@ const LogicalType &ListType::GetChildType(const LogicalType &type) { } LogicalType LogicalType::LIST(const LogicalType &child) { - auto info = make_shared(child); + auto info = make_refcounted(child); return LogicalType(LogicalTypeId::LIST, std::move(info)); } @@ -1285,12 +1285,12 @@ bool StructType::IsUnnamed(const LogicalType &type) { } LogicalType LogicalType::STRUCT(child_list_t children) { - auto info = make_shared(std::move(children)); + auto info = make_refcounted(std::move(children)); return LogicalType(LogicalTypeId::STRUCT, std::move(info)); } LogicalType LogicalType::AGGREGATE_STATE(aggregate_state_t state_type) { // NOLINT - auto info = make_shared(std::move(state_type)); + auto info = make_refcounted(std::move(state_type)); return LogicalType(LogicalTypeId::AGGREGATE_STATE, std::move(info)); } @@ -1315,7 +1315,7 @@ LogicalType LogicalType::MAP(const LogicalType &child_p) { new_children[1].first = "value"; auto child = LogicalType::STRUCT(std::move(new_children)); - auto info = make_shared(child); + auto info = make_refcounted(child); return LogicalType(LogicalTypeId::MAP, std::move(info)); } @@ -1344,7 +1344,7 @@ LogicalType LogicalType::UNION(child_list_t members) { D_ASSERT(members.size() <= UnionType::MAX_UNION_MEMBERS); // union types always have a hidden "tag" field in front members.insert(members.begin(), {"", LogicalType::UTINYINT}); - auto info = make_shared(std::move(members)); + auto info = make_refcounted(std::move(members)); return LogicalType(LogicalTypeId::UNION, std::move(info)); } @@ -1397,12 +1397,12 @@ const string &UserType::GetTypeName(const LogicalType &type) { } LogicalType LogicalType::USER(const string &user_type_name) { - auto info = make_shared(user_type_name); + auto info = make_refcounted(user_type_name); return LogicalType(LogicalTypeId::USER, std::move(info)); } LogicalType LogicalType::USER(string catalog, string schema, string name) { - auto info = make_shared(std::move(catalog), std::move(schema), std::move(name)); + auto info = make_refcounted(std::move(catalog), std::move(schema), std::move(name)); return LogicalType(LogicalTypeId::USER, std::move(info)); } @@ -1518,12 +1518,12 @@ LogicalType ArrayType::ConvertToList(const LogicalType &type) { LogicalType LogicalType::ARRAY(const LogicalType &child, idx_t size) { D_ASSERT(size > 0); D_ASSERT(size < ArrayType::MAX_ARRAY_SIZE); - auto info = make_shared(child, size); + auto info = make_refcounted(child, size); return LogicalType(LogicalTypeId::ARRAY, std::move(info)); } LogicalType LogicalType::ARRAY(const LogicalType &child) { - auto info = make_shared(child, 0); + auto info = make_refcounted(child, 0); return LogicalType(LogicalTypeId::ARRAY, std::move(info)); } @@ -1531,7 +1531,7 @@ LogicalType LogicalType::ARRAY(const LogicalType &child) { // Any Type //===--------------------------------------------------------------------===// LogicalType LogicalType::ANY_PARAMS(LogicalType target, idx_t cast_score) { // NOLINT - auto type_info = make_shared(std::move(target), cast_score); + auto type_info = 
make_refcounted(std::move(target), cast_score); return LogicalType(LogicalTypeId::ANY, std::move(type_info)); } @@ -1584,7 +1584,7 @@ LogicalType LogicalType::INTEGER_LITERAL(const Value &constant) { // NOLINT if (!constant.type().IsIntegral()) { throw InternalException("INTEGER_LITERAL can only be made from literals of integer types"); } - auto type_info = make_shared(constant); + auto type_info = make_refcounted(constant); return LogicalType(LogicalTypeId::INTEGER_LITERAL, std::move(type_info)); } diff --git a/src/common/types/column/column_data_collection.cpp b/src/common/types/column/column_data_collection.cpp index 8552332c2dc8..0f931d7cf909 100644 --- a/src/common/types/column/column_data_collection.cpp +++ b/src/common/types/column/column_data_collection.cpp @@ -51,17 +51,17 @@ ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p) { types.clear(); count = 0; this->finished_append = false; - allocator = make_shared(allocator_p); + allocator = make_refcounted(allocator_p); } ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p, vector types_p) { Initialize(std::move(types_p)); - allocator = make_shared(allocator_p); + allocator = make_refcounted(allocator_p); } ColumnDataCollection::ColumnDataCollection(BufferManager &buffer_manager, vector types_p) { Initialize(std::move(types_p)); - allocator = make_shared(buffer_manager); + allocator = make_refcounted(buffer_manager); } ColumnDataCollection::ColumnDataCollection(shared_ptr allocator_p, vector types_p) { @@ -71,7 +71,7 @@ ColumnDataCollection::ColumnDataCollection(shared_ptr alloc ColumnDataCollection::ColumnDataCollection(ClientContext &context, vector types_p, ColumnDataAllocatorType type) - : ColumnDataCollection(make_shared(context, type), std::move(types_p)) { + : ColumnDataCollection(make_refcounted(context, type), std::move(types_p)) { D_ASSERT(!types.empty()); } @@ -199,7 +199,7 @@ ColumnDataChunkIterationHelper::ColumnDataChunkIterationHelper(const ColumnDataC ColumnDataChunkIterationHelper::ColumnDataChunkIterator::ColumnDataChunkIterator( const ColumnDataCollection *collection_p, vector column_ids_p) - : collection(collection_p), scan_chunk(make_shared()), row_index(0) { + : collection(collection_p), scan_chunk(make_refcounted()), row_index(0) { if (!collection) { return; } @@ -246,7 +246,7 @@ ColumnDataRowIterationHelper::ColumnDataRowIterationHelper(const ColumnDataColle } ColumnDataRowIterationHelper::ColumnDataRowIterator::ColumnDataRowIterator(const ColumnDataCollection *collection_p) - : collection(collection_p), scan_chunk(make_shared()), current_row(*scan_chunk, 0, 0) { + : collection(collection_p), scan_chunk(make_refcounted()), current_row(*scan_chunk, 0, 0) { if (!collection) { return; } @@ -1041,7 +1041,7 @@ void ColumnDataCollection::Reset() { segments.clear(); // Refreshes the ColumnDataAllocator to prevent holding on to allocated data unnecessarily - allocator = make_shared(*allocator); + allocator = make_refcounted(*allocator); } struct ValueResultEquals { diff --git a/src/common/types/column/column_data_collection_segment.cpp b/src/common/types/column/column_data_collection_segment.cpp index 9713b66af09e..1f815d521974 100644 --- a/src/common/types/column/column_data_collection_segment.cpp +++ b/src/common/types/column/column_data_collection_segment.cpp @@ -7,7 +7,7 @@ namespace duckdb { ColumnDataCollectionSegment::ColumnDataCollectionSegment(shared_ptr allocator_p, vector types_p) : allocator(std::move(allocator_p)), types(std::move(types_p)), count(0), - 
heap(make_shared(allocator->GetAllocator())) { + heap(make_refcounted(allocator->GetAllocator())) { } idx_t ColumnDataCollectionSegment::GetDataSize(idx_t type_size) { diff --git a/src/common/types/column/partitioned_column_data.cpp b/src/common/types/column/partitioned_column_data.cpp index c785f346e869..7d47e129f26b 100644 --- a/src/common/types/column/partitioned_column_data.cpp +++ b/src/common/types/column/partitioned_column_data.cpp @@ -9,7 +9,7 @@ namespace duckdb { PartitionedColumnData::PartitionedColumnData(PartitionedColumnDataType type_p, ClientContext &context_p, vector types_p) : type(type_p), context(context_p), types(std::move(types_p)), - allocators(make_shared()) { + allocators(make_refcounted()) { } PartitionedColumnData::PartitionedColumnData(const PartitionedColumnData &other) @@ -165,7 +165,7 @@ vector> &PartitionedColumnData::GetPartitions() } void PartitionedColumnData::CreateAllocator() { - allocators->allocators.emplace_back(make_shared(BufferManager::GetBufferManager(context))); + allocators->allocators.emplace_back(make_refcounted(BufferManager::GetBufferManager(context))); allocators->allocators.back()->MakeShared(); } diff --git a/src/common/types/row/partitioned_tuple_data.cpp b/src/common/types/row/partitioned_tuple_data.cpp index 979b292294e7..cd67c32abb0d 100644 --- a/src/common/types/row/partitioned_tuple_data.cpp +++ b/src/common/types/row/partitioned_tuple_data.cpp @@ -9,7 +9,7 @@ namespace duckdb { PartitionedTupleData::PartitionedTupleData(PartitionedTupleDataType type_p, BufferManager &buffer_manager_p, const TupleDataLayout &layout_p) : type(type_p), buffer_manager(buffer_manager_p), layout(layout_p.Copy()), count(0), data_size(0), - allocators(make_shared()) { + allocators(make_refcounted()) { } PartitionedTupleData::PartitionedTupleData(const PartitionedTupleData &other) @@ -434,7 +434,7 @@ void PartitionedTupleData::Print() { // LCOV_EXCL_STOP void PartitionedTupleData::CreateAllocator() { - allocators->allocators.emplace_back(make_shared(buffer_manager, layout)); + allocators->allocators.emplace_back(make_refcounted(buffer_manager, layout)); } } // namespace duckdb diff --git a/src/common/types/row/tuple_data_collection.cpp b/src/common/types/row/tuple_data_collection.cpp index 8e548f9f1e28..7ffcac79abce 100644 --- a/src/common/types/row/tuple_data_collection.cpp +++ b/src/common/types/row/tuple_data_collection.cpp @@ -12,7 +12,7 @@ namespace duckdb { using ValidityBytes = TupleDataLayout::ValidityBytes; TupleDataCollection::TupleDataCollection(BufferManager &buffer_manager, const TupleDataLayout &layout_p) - : layout(layout_p.Copy()), allocator(make_shared(buffer_manager, layout)) { + : layout(layout_p.Copy()), allocator(make_refcounted(buffer_manager, layout)) { Initialize(); } @@ -377,7 +377,7 @@ void TupleDataCollection::Reset() { segments.clear(); // Refreshes the TupleDataAllocator to prevent holding on to allocated data unnecessarily - allocator = make_shared(*allocator); + allocator = make_refcounted(*allocator); } void TupleDataCollection::InitializeChunk(DataChunk &chunk) const { diff --git a/src/common/types/value.cpp b/src/common/types/value.cpp index fb3ee10cea4d..a2558d6cea42 100644 --- a/src/common/types/value.cpp +++ b/src/common/types/value.cpp @@ -162,7 +162,7 @@ Value::Value(string val) : type_(LogicalType::VARCHAR), is_null(false) { if (!Value::StringIsValid(val.c_str(), val.size())) { throw ErrorManager::InvalidUnicodeError(val, "value construction"); } - value_info_ = make_shared(std::move(val)); + value_info_ = 
make_refcounted(std::move(val)); } Value::~Value() { @@ -668,7 +668,7 @@ Value Value::STRUCT(const LogicalType &type, vector struct_values) { for (size_t i = 0; i < struct_values.size(); i++) { struct_values[i] = struct_values[i].DefaultCastAs(child_types[i].second); } - result.value_info_ = make_shared(std::move(struct_values)); + result.value_info_ = make_refcounted(std::move(struct_values)); result.type_ = type; result.is_null = false; return result; @@ -711,7 +711,7 @@ Value Value::MAP(const LogicalType &key_type, const LogicalType &value_type, vec new_children.push_back(std::make_pair("value", std::move(values[i]))); values[i] = Value::STRUCT(std::move(new_children)); } - result.value_info_ = make_shared(std::move(values)); + result.value_info_ = make_refcounted(std::move(values)); return result; } @@ -735,7 +735,7 @@ Value Value::UNION(child_list_t members, uint8_t tag, Value value) } } union_values[tag + 1] = std::move(value); - result.value_info_ = make_shared(std::move(union_values)); + result.value_info_ = make_refcounted(std::move(union_values)); result.type_ = LogicalType::UNION(std::move(members)); return result; } @@ -752,7 +752,7 @@ Value Value::LIST(vector values) { #endif Value result; result.type_ = LogicalType::LIST(values[0].type()); - result.value_info_ = make_shared(std::move(values)); + result.value_info_ = make_refcounted(std::move(values)); result.is_null = false; return result; } @@ -770,7 +770,7 @@ Value Value::LIST(const LogicalType &child_type, vector values) { Value Value::EMPTYLIST(const LogicalType &child_type) { Value result; result.type_ = LogicalType::LIST(child_type); - result.value_info_ = make_shared(); + result.value_info_ = make_refcounted(); result.is_null = false; return result; } @@ -787,7 +787,7 @@ Value Value::ARRAY(vector values) { #endif Value result; result.type_ = LogicalType::ARRAY(values[0].type(), values.size()); - result.value_info_ = make_shared(std::move(values)); + result.value_info_ = make_refcounted(std::move(values)); result.is_null = false; return result; } @@ -805,7 +805,7 @@ Value Value::ARRAY(const LogicalType &child_type, vector values) { Value Value::EMPTYARRAY(const LogicalType &child_type, uint32_t size) { Value result; result.type_ = LogicalType::ARRAY(child_type, size); - result.value_info_ = make_shared(); + result.value_info_ = make_refcounted(); result.is_null = false; return result; } @@ -813,35 +813,35 @@ Value Value::EMPTYARRAY(const LogicalType &child_type, uint32_t size) { Value Value::BLOB(const_data_ptr_t data, idx_t len) { Value result(LogicalType::BLOB); result.is_null = false; - result.value_info_ = make_shared(string(const_char_ptr_cast(data), len)); + result.value_info_ = make_refcounted(string(const_char_ptr_cast(data), len)); return result; } Value Value::BLOB(const string &data) { Value result(LogicalType::BLOB); result.is_null = false; - result.value_info_ = make_shared(Blob::ToBlob(string_t(data))); + result.value_info_ = make_refcounted(Blob::ToBlob(string_t(data))); return result; } Value Value::AGGREGATE_STATE(const LogicalType &type, const_data_ptr_t data, idx_t len) { // NOLINT Value result(type); result.is_null = false; - result.value_info_ = make_shared(string(const_char_ptr_cast(data), len)); + result.value_info_ = make_refcounted(string(const_char_ptr_cast(data), len)); return result; } Value Value::BIT(const_data_ptr_t data, idx_t len) { Value result(LogicalType::BIT); result.is_null = false; - result.value_info_ = make_shared(string(const_char_ptr_cast(data), len)); + result.value_info_ = 
make_refcounted(string(const_char_ptr_cast(data), len)); return result; } Value Value::BIT(const string &data) { Value result(LogicalType::BIT); result.is_null = false; - result.value_info_ = make_shared(Bit::ToBit(string_t(data))); + result.value_info_ = make_refcounted(Bit::ToBit(string_t(data))); return result; } @@ -1936,27 +1936,27 @@ Value Value::Deserialize(Deserializer &deserializer) { case PhysicalType::VARCHAR: { auto str = deserializer.ReadProperty(102, "value"); if (type.id() == LogicalTypeId::BLOB) { - new_value.value_info_ = make_shared(Blob::ToBlob(str)); + new_value.value_info_ = make_refcounted(Blob::ToBlob(str)); } else { - new_value.value_info_ = make_shared(str); + new_value.value_info_ = make_refcounted(str); } } break; case PhysicalType::LIST: { deserializer.ReadObject(102, "value", [&](Deserializer &obj) { auto children = obj.ReadProperty>(100, "children"); - new_value.value_info_ = make_shared(children); + new_value.value_info_ = make_refcounted(children); }); } break; case PhysicalType::STRUCT: { deserializer.ReadObject(102, "value", [&](Deserializer &obj) { auto children = obj.ReadProperty>(100, "children"); - new_value.value_info_ = make_shared(children); + new_value.value_info_ = make_refcounted(children); }); } break; case PhysicalType::ARRAY: { deserializer.ReadObject(102, "value", [&](Deserializer &obj) { auto children = obj.ReadProperty>(100, "children"); - new_value.value_info_ = make_shared(children); + new_value.value_info_ = make_refcounted(children); }); } break; default: diff --git a/src/common/types/vector_cache.cpp b/src/common/types/vector_cache.cpp index c0ea6fa7cc3e..0c5075a54b45 100644 --- a/src/common/types/vector_cache.cpp +++ b/src/common/types/vector_cache.cpp @@ -18,7 +18,7 @@ class VectorCacheBuffer : public VectorBuffer { auto &child_type = ListType::GetChildType(type); child_caches.push_back(make_buffer(allocator, child_type, capacity)); auto child_vector = make_uniq(child_type, false, false); - auxiliary = make_shared(std::move(child_vector)); + auxiliary = make_refcounted(std::move(child_vector)); break; } case PhysicalType::ARRAY: { @@ -26,7 +26,7 @@ class VectorCacheBuffer : public VectorBuffer { auto array_size = ArrayType::GetSize(type); child_caches.push_back(make_buffer(allocator, child_type, array_size * capacity)); auto child_vector = make_uniq(child_type, true, false, array_size * capacity); - auxiliary = make_shared(std::move(child_vector), array_size, capacity); + auxiliary = make_refcounted(std::move(child_vector), array_size, capacity); break; } case PhysicalType::STRUCT: { @@ -34,7 +34,7 @@ class VectorCacheBuffer : public VectorBuffer { for (auto &child_type : child_types) { child_caches.push_back(make_buffer(allocator, child_type.second, capacity)); } - auto struct_buffer = make_shared(type); + auto struct_buffer = make_refcounted(type); auxiliary = std::move(struct_buffer); break; } diff --git a/src/execution/aggregate_hashtable.cpp b/src/execution/aggregate_hashtable.cpp index 90029b0a4b6c..95f826a35596 100644 --- a/src/execution/aggregate_hashtable.cpp +++ b/src/execution/aggregate_hashtable.cpp @@ -40,7 +40,7 @@ GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, All vector aggregate_objects_p, idx_t initial_capacity, idx_t radix_bits) : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)), - radix_bits(radix_bits), count(0), capacity(0), aggregate_allocator(make_shared(allocator)) { + radix_bits(radix_bits), count(0), capacity(0), 
aggregate_allocator(make_refcounted(allocator)) { // Append hash column to the end and initialise the row layout group_types_p.emplace_back(LogicalType::HASH); diff --git a/src/execution/index/art/art.cpp b/src/execution/index/art/art.cpp index 9fa44deafedf..7e3c7f3f38d3 100644 --- a/src/execution/index/art/art.cpp +++ b/src/execution/index/art/art.cpp @@ -54,7 +54,8 @@ ART::ART(const string &name, const IndexConstraintType index_constraint_type, co make_uniq(sizeof(Node16), block_manager), make_uniq(sizeof(Node48), block_manager), make_uniq(sizeof(Node256), block_manager)}; - allocators = make_shared, ALLOCATOR_COUNT>>(std::move(allocator_array)); + allocators = + make_refcounted, ALLOCATOR_COUNT>>(std::move(allocator_array)); } // deserialize lazily diff --git a/src/execution/operator/aggregate/aggregate_object.cpp b/src/execution/operator/aggregate/aggregate_object.cpp index 05af824d676f..79a524ea76b8 100644 --- a/src/execution/operator/aggregate/aggregate_object.cpp +++ b/src/execution/operator/aggregate/aggregate_object.cpp @@ -9,7 +9,7 @@ AggregateObject::AggregateObject(AggregateFunction function, FunctionData *bind_ idx_t payload_size, AggregateType aggr_type, PhysicalType return_type, Expression *filter) : function(std::move(function)), - bind_data_wrapper(bind_data ? make_shared(bind_data->Copy()) : nullptr), + bind_data_wrapper(bind_data ? make_refcounted(bind_data->Copy()) : nullptr), child_count(child_count), payload_size(payload_size), aggr_type(aggr_type), return_type(return_type), filter(filter) { } diff --git a/src/execution/operator/aggregate/physical_hash_aggregate.cpp b/src/execution/operator/aggregate/physical_hash_aggregate.cpp index 5217c110bd26..420a9aecb918 100644 --- a/src/execution/operator/aggregate/physical_hash_aggregate.cpp +++ b/src/execution/operator/aggregate/physical_hash_aggregate.cpp @@ -608,7 +608,7 @@ idx_t HashAggregateDistinctFinalizeEvent::CreateGlobalSources() { void HashAggregateDistinctFinalizeEvent::FinishEvent() { // Now that everything is added to the main ht, we can actually finalize - auto new_event = make_shared(context, pipeline.get(), op, gstate); + auto new_event = make_refcounted(context, pipeline.get(), op, gstate); this->InsertEvent(std::move(new_event)); } @@ -755,7 +755,7 @@ SinkFinalizeType PhysicalHashAggregate::FinalizeDistinct(Pipeline &pipeline, Eve radix_table->Finalize(context, radix_state); } } - auto new_event = make_shared(context, pipeline, *this, gstate); + auto new_event = make_refcounted(context, pipeline, *this, gstate); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; } diff --git a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp index f1ceeab4bc7c..97008097513f 100644 --- a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -586,7 +586,7 @@ SinkFinalizeType PhysicalUngroupedAggregate::FinalizeDistinct(Pipeline &pipeline auto &radix_state = *distinct_state.radix_states[table_idx]; radix_table_p->Finalize(context, radix_state); } - auto new_event = make_shared(context, *this, gstate, pipeline); + auto new_event = make_refcounted(context, *this, gstate, pipeline); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; } diff --git a/src/execution/operator/aggregate/physical_window.cpp b/src/execution/operator/aggregate/physical_window.cpp index bcfe0a56bd3b..b945615bab16 100644 --- 
a/src/execution/operator/aggregate/physical_window.cpp +++ b/src/execution/operator/aggregate/physical_window.cpp @@ -171,7 +171,7 @@ SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, Clie } // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared(*state.global_partition, pipeline); + auto new_event = make_refcounted(*state.global_partition, pipeline); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; diff --git a/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp b/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp index 8c29ae79fb43..e0f23f4eb8f6 100644 --- a/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp +++ b/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp @@ -39,8 +39,8 @@ shared_ptr CSVBuffer::Next(CSVFileHandle &file_handle, idx_t buffer_s file_handle.Seek(global_csv_start + actual_buffer_size); has_seaked = false; } - auto next_csv_buffer = make_shared(file_handle, context, buffer_size, - global_csv_start + actual_buffer_size, file_number_p, buffer_idx + 1); + auto next_csv_buffer = make_refcounted( + file_handle, context, buffer_size, global_csv_start + actual_buffer_size, file_number_p, buffer_idx + 1); if (next_csv_buffer->GetBufferSize() == 0) { // We are done reading return nullptr; @@ -73,8 +73,8 @@ shared_ptr CSVBuffer::Pin(CSVFileHandle &file_handle, bool &has Reload(file_handle); has_seeked = true; } - return make_shared(buffer_manager.Pin(block), actual_buffer_size, last_buffer, file_number, - buffer_idx); + return make_refcounted(buffer_manager.Pin(block), actual_buffer_size, last_buffer, file_number, + buffer_idx); } void CSVBuffer::Unpin() { diff --git a/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp b/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp index 2a13158b6081..f3bcd5c403b6 100644 --- a/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp +++ b/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp @@ -28,7 +28,7 @@ void CSVBufferManager::UnpinBuffer(const idx_t cache_idx) { void CSVBufferManager::Initialize() { if (cached_buffers.empty()) { cached_buffers.emplace_back( - make_shared(context, buffer_size, *file_handle, global_csv_pos, file_idx)); + make_refcounted(context, buffer_size, *file_handle, global_csv_pos, file_idx)); last_buffer = cached_buffers.front(); } } diff --git a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp index 9582e1c1af2f..ed52797b01ac 100644 --- a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp +++ b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp @@ -525,14 +525,14 @@ StringValueScanner::StringValueScanner(const shared_ptr &buffe } unique_ptr StringValueScanner::GetCSVScanner(ClientContext &context, CSVReaderOptions &options) { - auto state_machine = make_shared(options, options.dialect_options.state_machine_options, - CSVStateMachineCache::Get(context)); + auto state_machine = make_refcounted(options, options.dialect_options.state_machine_options, + CSVStateMachineCache::Get(context)); state_machine->dialect_options.num_cols = options.dialect_options.num_cols; state_machine->dialect_options.header = options.dialect_options.header; - auto buffer_manager = make_shared(context, options, options.file_path, 0); - auto scanner = make_uniq(buffer_manager, state_machine, make_shared()); 
- scanner->csv_file_scan = make_shared(context, options.file_path, options); + auto buffer_manager = make_refcounted(context, options, options.file_path, 0); + auto scanner = make_uniq(buffer_manager, state_machine, make_refcounted()); + scanner->csv_file_scan = make_refcounted(context, options.file_path, options); scanner->csv_file_scan->InitializeProjection(); return scanner; } @@ -1074,8 +1074,9 @@ void StringValueScanner::SetStart() { return; } - scan_finder = make_uniq( - 0, buffer_manager, state_machine, make_shared(true), csv_file_scan, false, iterator, 1); + scan_finder = + make_uniq(0, buffer_manager, state_machine, make_refcounted(true), + csv_file_scan, false, iterator, 1); auto &tuples = scan_finder->ParseChunk(); line_found = true; if (tuples.number_of_rows != 1) { diff --git a/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp b/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp index 3b60f247aa27..af8788a1ad9c 100644 --- a/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp +++ b/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp @@ -13,8 +13,8 @@ CSVSniffer::CSVSniffer(CSVReaderOptions &options_p, shared_ptr } // Initialize max columns found to either 0 or however many were set max_columns_found = set_columns.Size(); - error_handler = make_shared(options.ignore_errors); - detection_error_handler = make_shared(true); + error_handler = make_refcounted(options.ignore_errors); + detection_error_handler = make_refcounted(true); } bool SetColumns::IsSet() { diff --git a/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp b/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp index 0532fc678a41..4f5732b934f9 100644 --- a/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp +++ b/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp @@ -10,7 +10,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, shared_ptr bu vector &file_schema) : file_path(options_p.file_path), file_idx(0), buffer_manager(std::move(buffer_manager_p)), state_machine(std::move(state_machine_p)), file_size(buffer_manager->file_handle->FileSize()), - error_handler(make_shared(options_p.ignore_errors)), + error_handler(make_refcounted(options_p.ignore_errors)), on_disk_file(buffer_manager->file_handle->OnDiskFile()), options(options_p) { if (bind_data.initial_reader.get()) { auto &union_reader = *bind_data.initial_reader; @@ -43,7 +43,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons const idx_t file_idx_p, const ReadCSVData &bind_data, const vector &column_ids, const vector &file_schema) : file_path(file_path_p), file_idx(file_idx_p), - error_handler(make_shared(options_p.ignore_errors)), options(options_p) { + error_handler(make_refcounted(options_p.ignore_errors)), options(options_p) { if (file_idx < bind_data.union_readers.size()) { // we are doing UNION BY NAME - fetch the options from the union reader for this file optional_ptr union_reader_ptr; @@ -73,7 +73,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons } // Initialize Buffer Manager - buffer_manager = make_shared(context, options, file_path, file_idx); + buffer_manager = make_refcounted(context, options, file_path, file_idx); // Initialize On Disk and Size of file on_disk_file = buffer_manager->file_handle->OnDiskFile(); file_size = buffer_manager->file_handle->FileSize(); @@ -89,7 +89,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons CSVSniffer 
sniffer(options, buffer_manager, state_machine_cache); sniffer.SniffCSV(); } - state_machine = make_shared( + state_machine = make_refcounted( state_machine_cache.Get(options.dialect_options.state_machine_options), options); MultiFileReader::InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, @@ -120,8 +120,8 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons names = bind_data.csv_names; types = bind_data.csv_types; - state_machine = - make_shared(state_machine_cache.Get(options.dialect_options.state_machine_options), options); + state_machine = make_refcounted( + state_machine_cache.Get(options.dialect_options.state_machine_options), options); MultiFileReader::InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, bind_data.return_names, column_ids, nullptr, file_path, context); @@ -129,9 +129,9 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons } CSVFileScan::CSVFileScan(ClientContext &context, const string &file_name, CSVReaderOptions &options_p) - : file_path(file_name), file_idx(0), error_handler(make_shared(options_p.ignore_errors)), + : file_path(file_name), file_idx(0), error_handler(make_refcounted(options_p.ignore_errors)), options(options_p) { - buffer_manager = make_shared(context, options, file_path, file_idx); + buffer_manager = make_refcounted(context, options, file_path, file_idx); // Initialize On Disk and Size of file on_disk_file = buffer_manager->file_handle->OnDiskFile(); file_size = buffer_manager->file_handle->FileSize(); @@ -151,8 +151,8 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_name, CSVRea options.dialect_options.num_cols = options.sql_type_list.size(); } // Initialize State Machine - state_machine = - make_shared(state_machine_cache.Get(options.dialect_options.state_machine_options), options); + state_machine = make_refcounted( + state_machine_cache.Get(options.dialect_options.state_machine_options), options); } void CSVFileScan::InitializeFileNamesTypes() { diff --git a/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp b/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp index b5943a9f5b8b..bc65c6d13172 100644 --- a/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp +++ b/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp @@ -14,7 +14,7 @@ CSVGlobalState::CSVGlobalState(ClientContext &context_p, const shared_ptrGetFilePath() == files[0]) { - auto state_machine = make_shared( + auto state_machine = make_refcounted( CSVStateMachineCache::Get(context).Get(options.dialect_options.state_machine_options), options); // If we already have a buffer manager, we don't need to reconstruct it to the first file file_scans.emplace_back(make_uniq(context, buffer_manager, state_machine, options, bind_data, @@ -36,7 +36,7 @@ CSVGlobalState::CSVGlobalState(ClientContext &context_p, const shared_ptrbuffer_manager->GetBuffer(0)->actual_size; current_boundary = CSVIterator(0, 0, 0, 0, buffer_size); } - current_buffer_in_use = make_shared(*file_scans.back()->buffer_manager, 0); + current_buffer_in_use = make_refcounted(*file_scans.back()->buffer_manager, 0); } double CSVGlobalState::GetProgress(const ReadCSVData &bind_data_p) const { @@ -66,8 +66,8 @@ unique_ptr CSVGlobalState::Next() { if (cur_idx == 0) { current_file = file_scans.back(); } else { - current_file = make_shared(context, bind_data.files[cur_idx], bind_data.options, 
cur_idx, - bind_data, column_ids, file_schema); + current_file = make_refcounted(context, bind_data.files[cur_idx], bind_data.options, cur_idx, + bind_data, column_ids, file_schema); } auto csv_scanner = make_uniq(scanner_idx++, current_file->buffer_manager, current_file->state_machine, @@ -80,7 +80,7 @@ unique_ptr CSVGlobalState::Next() { } if (current_buffer_in_use->buffer_idx != current_boundary.GetBufferIdx()) { current_buffer_in_use = - make_shared(*file_scans.back()->buffer_manager, current_boundary.GetBufferIdx()); + make_refcounted(*file_scans.back()->buffer_manager, current_boundary.GetBufferIdx()); } // We first create the scanner for the current boundary auto ¤t_file = *file_scans.back(); @@ -96,13 +96,13 @@ unique_ptr CSVGlobalState::Next() { auto current_file_idx = current_file.file_idx + 1; if (current_file_idx < bind_data.files.size()) { // If we have a next file we have to construct the file scan for that - file_scans.emplace_back(make_shared(context, bind_data.files[current_file_idx], - bind_data.options, current_file_idx, bind_data, column_ids, - file_schema)); + file_scans.emplace_back(make_refcounted(context, bind_data.files[current_file_idx], + bind_data.options, current_file_idx, bind_data, + column_ids, file_schema)); // And re-start the boundary-iterator auto buffer_size = file_scans.back()->buffer_manager->GetBuffer(0)->actual_size; current_boundary = CSVIterator(current_file_idx, 0, 0, 0, buffer_size); - current_buffer_in_use = make_shared(*file_scans.back()->buffer_manager, 0); + current_buffer_in_use = make_refcounted(*file_scans.back()->buffer_manager, 0); } else { // If not we are done with this CSV Scanning finished = true; diff --git a/src/execution/operator/helper/physical_buffered_collector.cpp b/src/execution/operator/helper/physical_buffered_collector.cpp index fcf75496e6f7..90708953a643 100644 --- a/src/execution/operator/helper/physical_buffered_collector.cpp +++ b/src/execution/operator/helper/physical_buffered_collector.cpp @@ -55,7 +55,7 @@ SinkCombineResultType PhysicalBufferedCollector::Combine(ExecutionContext &conte unique_ptr PhysicalBufferedCollector::GetGlobalSinkState(ClientContext &context) const { auto state = make_uniq(); state->context = context.shared_from_this(); - state->buffered_data = make_shared(state->context); + state->buffered_data = make_refcounted(state->context); return std::move(state); } diff --git a/src/execution/operator/join/physical_asof_join.cpp b/src/execution/operator/join/physical_asof_join.cpp index 3996063315c5..05e45d7a455f 100644 --- a/src/execution/operator/join/physical_asof_join.cpp +++ b/src/execution/operator/join/physical_asof_join.cpp @@ -169,7 +169,7 @@ SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, Cl } // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared(gstate.rhs_sink, pipeline); + auto new_event = make_refcounted(gstate.rhs_sink, pipeline); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; diff --git a/src/execution/operator/join/physical_hash_join.cpp b/src/execution/operator/join/physical_hash_join.cpp index 1e1e9fd12384..09c336458300 100644 --- a/src/execution/operator/join/physical_hash_join.cpp +++ b/src/execution/operator/join/physical_hash_join.cpp @@ -359,7 +359,7 @@ void HashJoinGlobalSinkState::ScheduleFinalize(Pipeline &pipeline, Event &event) return; } hash_table->InitializePointerTable(); - auto new_event = make_shared(pipeline, *this); + auto new_event = make_refcounted(pipeline, *this); 
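A note on the join Finalize hunks above: they all share one scheduling idiom, an event object is allocated through the ref-counted factory and ownership is then moved into the executor's event queue, so the creating scope drops its handle immediately. A minimal self-contained sketch of that idiom, with std::shared_ptr standing in for duckdb::shared_ptr and hypothetical Event/Scheduler types (these are not DuckDB classes):

// Illustrative sketch only; Event/Scheduler are invented stand-ins.
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

struct Event {
	virtual ~Event() = default;
	virtual void Schedule() = 0;
};

struct MergeEvent : Event {
	void Schedule() override {
		std::cout << "merge scheduled\n";
	}
};

struct Scheduler {
	std::vector<std::shared_ptr<Event>> events;
	// Takes ownership by value; callers move their handle in, mirroring
	// event.InsertEvent(std::move(new_event)) in the hunks above.
	void InsertEvent(std::shared_ptr<Event> event) {
		events.push_back(std::move(event));
	}
};

int main() {
	Scheduler scheduler;
	auto new_event = std::make_shared<MergeEvent>();
	scheduler.InsertEvent(std::move(new_event)); // local handle released here
	scheduler.events.front()->Schedule();
}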
event.InsertEvent(std::move(new_event)); } @@ -474,7 +474,7 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl // We have to repartition ht.SetRepartitionRadixBits(sink.local_hash_tables, sink.temporary_memory_state->GetReservation(), max_partition_size, max_partition_count); - auto new_event = make_shared(pipeline, sink, sink.local_hash_tables); + auto new_event = make_refcounted(pipeline, sink, sink.local_hash_tables); event.InsertEvent(std::move(new_event)); } else { // No repartitioning! diff --git a/src/execution/operator/join/physical_range_join.cpp b/src/execution/operator/join/physical_range_join.cpp index 5d6d44f47bff..d89360236a3b 100644 --- a/src/execution/operator/join/physical_range_join.cpp +++ b/src/execution/operator/join/physical_range_join.cpp @@ -149,7 +149,7 @@ class RangeJoinMergeEvent : public BasePipelineEvent { void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) { // Initialize global sort state for a round of merging global_sort_state.InitializeMergeRound(); - auto new_event = make_shared(*this, pipeline); + auto new_event = make_refcounted(*this, pipeline); event.InsertEvent(std::move(new_event)); } diff --git a/src/execution/operator/order/physical_order.cpp b/src/execution/operator/order/physical_order.cpp index ac933f3fdc84..7aa77a7a30c0 100644 --- a/src/execution/operator/order/physical_order.cpp +++ b/src/execution/operator/order/physical_order.cpp @@ -186,7 +186,7 @@ SinkFinalizeType PhysicalOrder::Finalize(Pipeline &pipeline, Event &event, Clien void PhysicalOrder::ScheduleMergeTasks(Pipeline &pipeline, Event &event, OrderGlobalSinkState &state) { // Initialize global sort state for a round of merging state.global_sort_state.InitializeMergeRound(); - auto new_event = make_shared(state, pipeline); + auto new_event = make_refcounted(state, pipeline); event.InsertEvent(std::move(new_event)); } diff --git a/src/execution/operator/persistent/physical_batch_copy_to_file.cpp b/src/execution/operator/persistent/physical_batch_copy_to_file.cpp index 5447f433cb69..831ea35b6e49 100644 --- a/src/execution/operator/persistent/physical_batch_copy_to_file.cpp +++ b/src/execution/operator/persistent/physical_batch_copy_to_file.cpp @@ -308,7 +308,7 @@ SinkFinalizeType PhysicalBatchCopyToFile::Finalize(Pipeline &pipeline, Event &ev FinalFlush(context, input.global_state); } else { // we have multiple tasks remaining - launch an event to execute the tasks in parallel - auto new_event = make_shared(*this, gstate, pipeline, context); + auto new_event = make_refcounted(*this, gstate, pipeline, context); event.InsertEvent(std::move(new_event)); } return SinkFinalizeType::READY; diff --git a/src/execution/operator/persistent/physical_copy_to_file.cpp b/src/execution/operator/persistent/physical_copy_to_file.cpp index 5925e9f012eb..be9ca4dbbd61 100644 --- a/src/execution/operator/persistent/physical_copy_to_file.cpp +++ b/src/execution/operator/persistent/physical_copy_to_file.cpp @@ -284,7 +284,7 @@ unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext } if (partition_output) { - state->partition_state = make_shared(); + state->partition_state = make_refcounted(); } return std::move(state); diff --git a/src/execution/operator/schema/physical_create_art_index.cpp b/src/execution/operator/schema/physical_create_art_index.cpp index e7325405fafe..fe88d7c3bfa4 100644 --- a/src/execution/operator/schema/physical_create_art_index.cpp +++ b/src/execution/operator/schema/physical_create_art_index.cpp @@ 
-178,7 +178,7 @@ SinkFinalizeType PhysicalCreateARTIndex::Finalize(Pipeline &pipeline, Event &eve auto &index = index_entry->Cast(); index.initial_index_size = state.global_index->GetInMemorySize(); - index.info = make_shared(storage.info, index.name); + index.info = make_refcounted(storage.info, index.name); for (auto &parsed_expr : info->parsed_expressions) { index.parsed_expressions.push_back(parsed_expr->Copy()); } diff --git a/src/execution/operator/set/physical_recursive_cte.cpp b/src/execution/operator/set/physical_recursive_cte.cpp index 57a847dd0775..5210987325ab 100644 --- a/src/execution/operator/set/physical_recursive_cte.cpp +++ b/src/execution/operator/set/physical_recursive_cte.cpp @@ -200,7 +200,7 @@ void PhysicalRecursiveCTE::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_ initial_state_pipeline.Build(*children[0]); // the RHS is the recursive pipeline - recursive_meta_pipeline = make_shared(executor, state, this); + recursive_meta_pipeline = make_refcounted(executor, state, this); recursive_meta_pipeline->SetRecursiveCTE(); recursive_meta_pipeline->Build(*children[1]); diff --git a/src/execution/physical_plan/plan_cte.cpp b/src/execution/physical_plan/plan_cte.cpp index 7a306c3a54ca..c11286a8cf3f 100644 --- a/src/execution/physical_plan/plan_cte.cpp +++ b/src/execution/physical_plan/plan_cte.cpp @@ -12,7 +12,7 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalMaterializ D_ASSERT(op.children.size() == 2); // Create the working_table that the PhysicalCTE will use for evaluation. - auto working_table = make_shared(context, op.children[0]->types); + auto working_table = make_refcounted(context, op.children[0]->types); // Add the ColumnDataCollection to the context of this PhysicalPlanGenerator recursive_cte_tables[op.table_index] = working_table; diff --git a/src/execution/physical_plan/plan_recursive_cte.cpp b/src/execution/physical_plan/plan_recursive_cte.cpp index 89da0cd0706e..5ddb767612b6 100644 --- a/src/execution/physical_plan/plan_recursive_cte.cpp +++ b/src/execution/physical_plan/plan_recursive_cte.cpp @@ -12,7 +12,7 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalRecursiveC D_ASSERT(op.children.size() == 2); // Create the working_table that the PhysicalRecursiveCTE will use for evaluation. 
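The two CTE planning hunks here register one shared collection under the CTE's table index, so the operator that fills the working table and the scans that later read it observe the same data. A compact, self-contained sketch of that sharing pattern, with std::shared_ptr standing in for duckdb::shared_ptr and a plain vector standing in for ColumnDataCollection:

// Sketch only; WorkingTable and the registry are illustrative stand-ins.
#include <cassert>
#include <map>
#include <memory>
#include <vector>

using WorkingTable = std::vector<int>; // stand-in for ColumnDataCollection

int main() {
	std::map<size_t, std::shared_ptr<WorkingTable>> recursive_cte_tables;

	size_t table_index = 0;
	auto working_table = std::make_shared<WorkingTable>();
	recursive_cte_tables[table_index] = working_table; // registry co-owns it

	// The scan side sees the rows the recursive side appends, because
	// both sides hold the same pointer.
	auto scan_side = recursive_cte_tables[table_index];
	working_table->push_back(42);
	assert(scan_side->size() == 1);
}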
-	auto working_table = make_shared<ColumnDataCollection>(context, op.types);
+	auto working_table = make_refcounted<ColumnDataCollection>(context, op.types);
 
 	// Add the ColumnDataCollection to the context of this PhysicalPlanGenerator
 	recursive_cte_tables[op.table_index] = working_table;
diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp
index 1fa2e46e7694..e45a6a643120 100644
--- a/src/function/table/copy_csv.cpp
+++ b/src/function/table/copy_csv.cpp
@@ -156,7 +156,7 @@ static unique_ptr<FunctionData> ReadCSVBind(ClientContext &context, CopyInfo &in
 	}
 
 	if (options.auto_detect) {
-		auto buffer_manager = make_shared<CSVBufferManager>(context, options, bind_data->files[0], 0);
+		auto buffer_manager = make_refcounted<CSVBufferManager>(context, options, bind_data->files[0], 0);
 		CSVSniffer sniffer(options, buffer_manager, CSVStateMachineCache::Get(context),
 		                   {&expected_types, &expected_names});
 		sniffer.SniffCSV();
diff --git a/src/function/table/read_csv.cpp b/src/function/table/read_csv.cpp
index 8d2e1be0d780..5f462c8059f1 100644
--- a/src/function/table/read_csv.cpp
+++ b/src/function/table/read_csv.cpp
@@ -98,7 +98,7 @@ static unique_ptr<FunctionData> ReadCSVBind(ClientContext &context, TableFunctio
 	}
 	if (options.auto_detect && !options.file_options.union_by_name) {
 		options.file_path = result->files[0];
-		result->buffer_manager = make_shared<CSVBufferManager>(context, options, result->files[0], 0);
+		result->buffer_manager = make_refcounted<CSVBufferManager>(context, options, result->files[0], 0);
 		CSVSniffer sniffer(options, result->buffer_manager, CSVStateMachineCache::Get(context),
 		                   {&return_types, &names});
 		auto sniffer_result = sniffer.SniffCSV();
diff --git a/src/function/table/sniff_csv.cpp b/src/function/table/sniff_csv.cpp
index 3e859a65afe3..d29a82f1c583 100644
--- a/src/function/table/sniff_csv.cpp
+++ b/src/function/table/sniff_csv.cpp
@@ -120,7 +120,7 @@ static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p,
 	auto sniffer_options = data.options;
 	sniffer_options.file_path = data.path;
 
-	auto buffer_manager = make_shared<CSVBufferManager>(context, sniffer_options, sniffer_options.file_path, 0);
+	auto buffer_manager = make_refcounted<CSVBufferManager>(context, sniffer_options, sniffer_options.file_path, 0);
 	if (sniffer_options.name_list.empty()) {
 		sniffer_options.name_list = data.names_csv;
 	}
diff --git a/src/include/duckdb/common/enable_shared_from_this.ipp b/src/include/duckdb/common/enable_shared_from_this.ipp
new file mode 100644
index 000000000000..6472db9c2b12
--- /dev/null
+++ b/src/include/duckdb/common/enable_shared_from_this.ipp
@@ -0,0 +1,40 @@
+namespace duckdb {
+
+template <class _Tp>
+class enable_shared_from_this {
+	mutable weak_ptr<_Tp> __weak_this_;
+
+protected:
+	constexpr enable_shared_from_this() noexcept {
+	}
+	enable_shared_from_this(enable_shared_from_this const &) noexcept {
+	}
+	enable_shared_from_this &operator=(enable_shared_from_this const &) noexcept {
+		return *this;
+	}
+	~enable_shared_from_this() {
+	}
+
+public:
+	shared_ptr<_Tp> shared_from_this() {
+		return shared_ptr<_Tp>(__weak_this_);
+	}
+	shared_ptr<_Tp const> shared_from_this() const {
+		return shared_ptr<const _Tp>(__weak_this_);
+	}
+
+#if _LIBCPP_STD_VER >= 17
+	weak_ptr<_Tp> weak_from_this() noexcept {
+		return __weak_this_;
+	}
+
+	weak_ptr<const _Tp> weak_from_this() const noexcept {
+		return __weak_this_;
+	}
+#endif // _LIBCPP_STD_VER >= 17
+
+	template <class _Up>
+	friend class shared_ptr;
+};
+
+} // namespace duckdb
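The new enable_shared_from_this above closely mirrors the libc++ implementation. Its usage contract is worth spelling out: shared_from_this() is only valid once the object is already owned by a shared_ptr, because the hidden __weak_this_ field is seeded by the shared_ptr constructor (the friend declaration at the bottom is what allows that). A self-contained sketch using the std equivalents, which the duckdb types are meant to match behavior-for-behavior:

// Sketch of the contract the new header preserves, shown with std types.
#include <cassert>
#include <memory>

struct Node : std::enable_shared_from_this<Node> {
	std::shared_ptr<Node> Self() {
		// Valid only when *this is already managed by a shared_ptr;
		// calling it on a stack-allocated Node throws std::bad_weak_ptr.
		return shared_from_this();
	}
};

int main() {
	auto node = std::make_shared<Node>();
	auto same = node->Self();
	assert(same.get() == node.get());
	assert(node.use_count() == 2); // both handles share one control block
}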
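The helper.hpp hunk below renames the factory and introduces shared_ptr_cast. The observable behavior of the three building blocks involved, a forwarding factory, a static downcast that shares the control block, and shared_ptr adoption of a unique_ptr, can be sketched with the std equivalents; Base and Derived are made-up example types, and the *_like helpers are hypothetical stand-ins rather than the patch's own definitions:

// Behavioral sketch only, under the assumptions stated above.
#include <cassert>
#include <memory>
#include <utility>

struct Base {
	virtual ~Base() = default;
};
struct Derived : Base {
	int value = 7;
};

// Forwarding factory in the style of make_refcounted.
template <class T, class... ARGS>
std::shared_ptr<T> make_refcounted_like(ARGS &&...args) {
	return std::shared_ptr<T>(new T(std::forward<ARGS>(args)...));
}

// Static downcast in the style of shared_ptr_cast: no copy of the object,
// the result co-owns the same control block.
template <class S, class T>
std::shared_ptr<S> shared_ptr_cast_like(std::shared_ptr<T> src) {
	return std::static_pointer_cast<S, T>(std::move(src));
}

int main() {
	std::shared_ptr<Base> base = make_refcounted_like<Derived>();
	auto derived = shared_ptr_cast_like<Derived>(base);
	assert(derived->value == 7);
	assert(base.use_count() == 2); // cast shares ownership

	// Adoption of a unique_ptr, as the new converting constructor allows.
	std::unique_ptr<Derived> owned = std::make_unique<Derived>();
	std::shared_ptr<Base> adopted = std::move(owned);
	assert(!owned && adopted);
}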
diff --git a/src/include/duckdb/common/exception.hpp b/src/include/duckdb/common/exception.hpp
index 3765c6ba58c8..4f8dc1e16bbe 100644
--- a/src/include/duckdb/common/exception.hpp
+++ b/src/include/duckdb/common/exception.hpp
@@ -10,7 +10,6 @@
 
 #include "duckdb/common/assert.hpp"
 #include "duckdb/common/exception_format_value.hpp"
-#include "duckdb/common/shared_ptr.hpp"
 #include "duckdb/common/unordered_map.hpp"
 #include "duckdb/common/typedefs.hpp"
 
diff --git a/src/include/duckdb/common/helper.hpp b/src/include/duckdb/common/helper.hpp
index 4c57e51a272b..81da4bd79513 100644
--- a/src/include/duckdb/common/helper.hpp
+++ b/src/include/duckdb/common/helper.hpp
@@ -68,7 +68,7 @@ make_uniq(ARGS&&... args) // NOLINT: mimic std style
 
 template<class DATA_TYPE, class... ARGS>
 inline shared_ptr<DATA_TYPE>
-make_shared(ARGS&&... args) // NOLINT: mimic std style
+make_refcounted(ARGS&&... args) // NOLINT: mimic std style
 {
 	return shared_ptr<DATA_TYPE>(new DATA_TYPE(std::forward<ARGS>(args)...));
 }
@@ -117,10 +117,15 @@ unique_ptr<S> unique_ptr_cast(unique_ptr<T> src) { // NOLINT: mimic std style
 	return unique_ptr<S>(static_cast<S *>(src.release()));
 }
 
+template <class S, class T>
+shared_ptr<S> shared_ptr_cast(shared_ptr<T> src) {
+	return shared_ptr<S>(std::static_pointer_cast<S, T>(src.internal));
+}
+
 struct SharedConstructor {
 	template <class T, typename... ARGS>
 	static shared_ptr<T> Create(ARGS &&...args) {
-		return make_shared<T>(std::forward<ARGS>(args)...);
+		return make_refcounted<T>(std::forward<ARGS>(args)...);
 	}
 };
 
diff --git a/src/include/duckdb/common/http_state.hpp b/src/include/duckdb/common/http_state.hpp
index 1341b921147c..6fc1fa92e291 100644
--- a/src/include/duckdb/common/http_state.hpp
+++ b/src/include/duckdb/common/http_state.hpp
@@ -20,7 +20,7 @@ namespace duckdb {
 class CachedFileHandle;
 
 //! Represents a file that is intended to be fully downloaded, then used in parallel by multiple threads
-class CachedFile : public std::enable_shared_from_this<CachedFile> {
+class CachedFile : public enable_shared_from_this<CachedFile> {
 	friend class CachedFileHandle;
 
 public:
diff --git a/src/include/duckdb/common/multi_file_reader.hpp b/src/include/duckdb/common/multi_file_reader.hpp
index ca52810e8dfd..ec318dea76e0 100644
--- a/src/include/duckdb/common/multi_file_reader.hpp
+++ b/src/include/duckdb/common/multi_file_reader.hpp
@@ -151,7 +151,7 @@ struct MultiFileReader {
 			return BindUnionReader<READER_CLASS>(context, return_types, names, result, options);
 		} else {
 			shared_ptr<READER_CLASS> reader;
-			reader = make_shared<READER_CLASS>(context, result.files[0], options);
+			reader = make_refcounted<READER_CLASS>(context, result.files[0], options);
 			return_types = reader->return_types;
 			names = reader->names;
 			result.Initialize(std::move(reader));
diff --git a/src/include/duckdb/common/re2_regex.hpp b/src/include/duckdb/common/re2_regex.hpp
index ae5c48fd51f7..77b5b261b60d 100644
--- a/src/include/duckdb/common/re2_regex.hpp
+++ b/src/include/duckdb/common/re2_regex.hpp
@@ -4,7 +4,8 @@
 
 #include "duckdb/common/winapi.hpp"
 #include "duckdb/common/vector.hpp"
-#include <memory>
+#include "duckdb/common/shared_ptr.hpp"
+#include "duckdb/common/string.hpp"
 #include <vector>
 
 namespace duckdb_re2 {
@@ -22,7 +23,7 @@ class Regex {
 	}
 
 private:
-	shared_ptr<duckdb_re2::RE2> regex;
+	duckdb::shared_ptr<duckdb_re2::RE2> regex;
 };
 
 struct GroupMatch {
diff --git a/src/include/duckdb/common/shared_ptr.hpp b/src/include/duckdb/common/shared_ptr.hpp
index 615273d7c27e..fe9d31ee40c9 100644
--- a/src/include/duckdb/common/shared_ptr.hpp
+++ b/src/include/duckdb/common/shared_ptr.hpp
@@ -10,113 +10,22 @@
 
 #include <memory>
 #include <type_traits>
+#include "duckdb/common/unique_ptr.hpp"
 
-template <typename T>
-class weak_ptr;
+#if _LIBCPP_STD_VER >= 17
+template <class _Yp, class _Tp>
+struct __bounded_convertible_to_unbounded : false_type {};
 
-namespace duckdb {
+template <class _Up, std::size_t _Np, class _Tp>
+struct __bounded_convertible_to_unbounded<_Up[_Np], _Tp> : is_same<__remove_cv_t<_Tp>, _Up[]> {};
 
-template <typename T>
-class shared_ptr {
-private:
-	template <typename U>
-	friend class weak_ptr;
-	std::shared_ptr<T> internal;
+template <class _Yp, class _Tp>
+struct
__compatible_with : _Or, __bounded_convertible_to_unbounded<_Yp, _Tp>> {}; +#else +template +struct __compatible_with : std::is_convertible<_Yp *, _Tp *> {}; +#endif // _LIBCPP_STD_VER >= 17 -public: - // Constructors - shared_ptr() : internal() { - } - shared_ptr(std::nullptr_t) : internal(nullptr) { - } // Implicit conversion - template - explicit shared_ptr(U *ptr) : internal(ptr) { - } - shared_ptr(const shared_ptr &other) : internal(other.internal) { - } - shared_ptr(std::shared_ptr other) : internal(std::move(other)) { - } - - // Destructor - ~shared_ptr() = default; - - // Assignment operators - shared_ptr &operator=(const shared_ptr &other) { - internal = other.internal; - return *this; - } - - // Modifiers - void reset() { - internal.reset(); - } - - template - void reset(U *ptr) { - internal.reset(ptr); - } - - template - void reset(U *ptr, Deleter deleter) { - internal.reset(ptr, deleter); - } - - // Observers - T *get() const { - return internal.get(); - } - - long use_count() const { - return internal.use_count(); - } - - explicit operator bool() const noexcept { - return internal.operator bool(); - } - - // Element access - std::__add_lvalue_reference_t operator*() const { - return *internal; - } - - T *operator->() const { - return internal.operator->(); - } - - // Relational operators - template - bool operator==(const shared_ptr &other) const noexcept { - return internal == other.internal; - } - - bool operator==(std::nullptr_t) const noexcept { - return internal == nullptr; - } - - template - bool operator!=(const shared_ptr &other) const noexcept { - return internal != other.internal; - } - - template - bool operator<(const shared_ptr &other) const noexcept { - return internal < other.internal; - } - - template - bool operator<=(const shared_ptr &other) const noexcept { - return internal <= other.internal; - } - - template - bool operator>(const shared_ptr &other) const noexcept { - return internal > other.internal; - } - - template - bool operator>=(const shared_ptr &other) const noexcept { - return internal >= other.internal; - } -}; - -} // namespace duckdb +#include "duckdb/common/shared_ptr.ipp" +#include "duckdb/common/weak_ptr.ipp" +#include "duckdb/common/enable_shared_from_this.ipp" diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp new file mode 100644 index 000000000000..f95901521664 --- /dev/null +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -0,0 +1,150 @@ + +namespace duckdb { + +template +class weak_ptr; + +template +class shared_ptr { +private: + template + friend class weak_ptr; + std::shared_ptr internal; + +public: + // Constructors + shared_ptr() : internal() { + } + shared_ptr(std::nullptr_t) : internal(nullptr) { + } // Implicit conversion + template + explicit shared_ptr(U *ptr) : internal(ptr) { + } + // Constructor with custom deleter + template + shared_ptr(T *ptr, Deleter deleter) : internal(ptr, deleter) { + } + + shared_ptr(const shared_ptr &other) : internal(other.internal) { + } + + shared_ptr(std::shared_ptr other) : internal(std::move(other)) { + } + shared_ptr(shared_ptr &&other) : internal(std::move(other.internal)) { + } + + template + explicit shared_ptr(weak_ptr other) : internal(other.internal) { + } + + template ::value && __compatible_with::value && + std::is_convertible::pointer, T *>::value, + int> = 0> + shared_ptr(unique_ptr other) : internal(other.release()) { + } + + template ::value && __compatible_with::value && + std::is_convertible::pointer, T *>::value, + int> = 0> + 
shared_ptr(unique_ptr &&other) : internal(other.release()) { + } + + // Destructor + ~shared_ptr() = default; + + // Assignment operators + shared_ptr &operator=(const shared_ptr &other) { + internal = other.internal; + return *this; + } + + template + shared_ptr &operator=(unique_ptr &&__r) { + shared_ptr(std::move(__r)).swap(*this); + return *this; + } + + // Modifiers + void reset() { + internal.reset(); + } + + template + void reset(U *ptr) { + internal.reset(ptr); + } + + template + void reset(U *ptr, Deleter deleter) { + internal.reset(ptr, deleter); + } + + // Observers + T *get() const { + return internal.get(); + } + + long use_count() const { + return internal.use_count(); + } + + explicit operator bool() const noexcept { + return internal.operator bool(); + } + + template + operator shared_ptr() const noexcept { + return shared_ptr(internal); + } + + // Element access + std::__add_lvalue_reference_t operator*() const { + return *internal; + } + + T *operator->() const { + return internal.operator->(); + } + + // Relational operators + template + bool operator==(const shared_ptr &other) const noexcept { + return internal == other.internal; + } + + bool operator==(std::nullptr_t) const noexcept { + return internal == nullptr; + } + + template + bool operator!=(const shared_ptr &other) const noexcept { + return internal != other.internal; + } + + template + bool operator<(const shared_ptr &other) const noexcept { + return internal < other.internal; + } + + template + bool operator<=(const shared_ptr &other) const noexcept { + return internal <= other.internal; + } + + template + bool operator>(const shared_ptr &other) const noexcept { + return internal > other.internal; + } + + template + bool operator>=(const shared_ptr &other) const noexcept { + return internal >= other.internal; + } + + template + friend shared_ptr shared_ptr_cast(shared_ptr src); +}; + +} // namespace duckdb diff --git a/src/include/duckdb/common/types.hpp b/src/include/duckdb/common/types.hpp index 0151bc31a85f..e9e31b488ea5 100644 --- a/src/include/duckdb/common/types.hpp +++ b/src/include/duckdb/common/types.hpp @@ -35,7 +35,7 @@ using buffer_ptr = shared_ptr; template buffer_ptr make_buffer(ARGS &&...args) { // NOLINT: mimic std casing - return make_shared(std::forward(args)...); + return make_refcounted(std::forward(args)...); } struct list_entry_t { // NOLINT: mimic std casing diff --git a/src/include/duckdb/common/types/selection_vector.hpp b/src/include/duckdb/common/types/selection_vector.hpp index db6d6e9bdece..a0f0b185beae 100644 --- a/src/include/duckdb/common/types/selection_vector.hpp +++ b/src/include/duckdb/common/types/selection_vector.hpp @@ -71,7 +71,7 @@ struct SelectionVector { sel_vector = sel; } void Initialize(idx_t count = STANDARD_VECTOR_SIZE) { - selection_data = make_shared(count); + selection_data = make_refcounted(count); sel_vector = selection_data->owned_data.get(); } void Initialize(buffer_ptr data) { diff --git a/src/include/duckdb/common/unique_ptr.hpp b/src/include/duckdb/common/unique_ptr.hpp index d9f0b835832c..b98f8da00030 100644 --- a/src/include/duckdb/common/unique_ptr.hpp +++ b/src/include/duckdb/common/unique_ptr.hpp @@ -9,10 +9,10 @@ namespace duckdb { -template , bool SAFE = true> -class unique_ptr : public std::unique_ptr { // NOLINT: naming +template , bool SAFE = true> +class unique_ptr : public std::unique_ptr { // NOLINT: naming public: - using original = std::unique_ptr; + using original = std::unique_ptr; using original::original; // NOLINT private: @@ -53,9 
+53,9 @@ class unique_ptr : public std::unique_ptr { // NOLINT } }; -template -class unique_ptr - : public std::unique_ptr> { +// FIXME: DELETER is defined, but we use std::default_delete??? +template +class unique_ptr : public std::unique_ptr> { public: using original = std::unique_ptr>; using original::original; diff --git a/src/include/duckdb/common/weak_ptr.hpp b/src/include/duckdb/common/weak_ptr.ipp similarity index 87% rename from src/include/duckdb/common/weak_ptr.hpp rename to src/include/duckdb/common/weak_ptr.ipp index bf442e02ad6a..5fbe213c92bc 100644 --- a/src/include/duckdb/common/weak_ptr.hpp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -1,20 +1,18 @@ -#pragma once - -#include "duckdb/common/shared_ptr.hpp" -#include - namespace duckdb { template class weak_ptr { private: + template + friend class shared_ptr; std::weak_ptr internal; public: // Constructors weak_ptr() : internal() { } - template + // template ::value, int> = 0> + template weak_ptr(const shared_ptr &ptr) : internal(ptr.internal) { } weak_ptr(const weak_ptr &other) : internal(other.internal) { @@ -29,7 +27,7 @@ class weak_ptr { return *this; } - template + template ::value, int> = 0> weak_ptr &operator=(const shared_ptr &ptr) { internal = ptr; return *this; diff --git a/src/include/duckdb/main/buffered_data/buffered_data.hpp b/src/include/duckdb/main/buffered_data/buffered_data.hpp index 8065fbee2c73..a863d551be6f 100644 --- a/src/include/duckdb/main/buffered_data/buffered_data.hpp +++ b/src/include/duckdb/main/buffered_data/buffered_data.hpp @@ -15,7 +15,7 @@ #include "duckdb/common/optional_idx.hpp" #include "duckdb/execution/physical_operator_states.hpp" #include "duckdb/common/enums/pending_execution_result.hpp" -#include "duckdb/common/weak_ptr.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { diff --git a/src/include/duckdb/main/client_context.hpp b/src/include/duckdb/main/client_context.hpp index 9f608273b668..1af4e9007deb 100644 --- a/src/include/duckdb/main/client_context.hpp +++ b/src/include/duckdb/main/client_context.hpp @@ -59,7 +59,7 @@ struct PendingQueryParameters { //! The ClientContext holds information relevant to the current client session //! 
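A note on the weak_ptr changes above (the header now becomes weak_ptr.ipp): the commented-out enable_if in the constructor and the added constraint on operator= tighten which shared_ptr<U> may seed a weak_ptr<T>, while the run-time semantics stay those of std::weak_ptr. Several call sites in this patch, for example BlockManager::RegisterBlock further down, rely on the resulting weak-cache pattern: store weak references, lock() them to revive an entry, rebuild on expiry. A self-contained sketch with the std types and a hypothetical Block:

// Weak-cache sketch; Block is an invented stand-in, std types throughout.
#include <cassert>
#include <map>
#include <memory>

struct Block {
	int id;
	explicit Block(int id) : id(id) {
	}
};

std::map<int, std::weak_ptr<Block>> blocks;

std::shared_ptr<Block> RegisterBlock(int id) {
	auto entry = blocks.find(id);
	if (entry != blocks.end()) {
		// lock() yields a live shared_ptr, or null if the block expired
		if (auto existing = entry->second.lock()) {
			return existing;
		}
	}
	auto result = std::make_shared<Block>(id);
	blocks[id] = std::weak_ptr<Block>(result); // cache holds no strong ref
	return result;
}

int main() {
	auto a = RegisterBlock(1);
	auto b = RegisterBlock(1);
	assert(a.get() == b.get()); // revived from the weak cache
	a.reset();
	b.reset();
	assert(blocks[1].expired()); // nothing kept it alive
	auto c = RegisterBlock(1);   // rebuilt after expiry
	assert(c->id == 1);
}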
during execution -class ClientContext : public std::enable_shared_from_this { +class ClientContext : public enable_shared_from_this { friend class PendingQueryResult; // LockContext friend class SimpleBufferedData; // ExecuteTaskInternal friend class StreamQueryResult; // LockContext diff --git a/src/include/duckdb/main/database.hpp b/src/include/duckdb/main/database.hpp index 0b87cf8f0ea2..5ec2b68ab01f 100644 --- a/src/include/duckdb/main/database.hpp +++ b/src/include/duckdb/main/database.hpp @@ -27,7 +27,7 @@ class ObjectCache; struct AttachInfo; class DatabaseFileSystem; -class DatabaseInstance : public std::enable_shared_from_this { +class DatabaseInstance : public enable_shared_from_this { friend class DuckDB; public: diff --git a/src/include/duckdb/main/relation.hpp b/src/include/duckdb/main/relation.hpp index c494366208cb..6d0ffa1d1cab 100644 --- a/src/include/duckdb/main/relation.hpp +++ b/src/include/duckdb/main/relation.hpp @@ -34,7 +34,7 @@ class LogicalOperator; class QueryNode; class TableRef; -class Relation : public std::enable_shared_from_this { +class Relation : public enable_shared_from_this { public: Relation(const shared_ptr &context, RelationType type) : context(context), type(type) { } diff --git a/src/include/duckdb/parallel/event.hpp b/src/include/duckdb/parallel/event.hpp index 89a108d98a98..b65dd0443c68 100644 --- a/src/include/duckdb/parallel/event.hpp +++ b/src/include/duckdb/parallel/event.hpp @@ -16,7 +16,7 @@ namespace duckdb { class Executor; class Task; -class Event : public std::enable_shared_from_this { +class Event : public enable_shared_from_this { public: explicit Event(Executor &executor); virtual ~Event() = default; diff --git a/src/include/duckdb/parallel/interrupt.hpp b/src/include/duckdb/parallel/interrupt.hpp index fe5348bc9395..f3c54aa29cdf 100644 --- a/src/include/duckdb/parallel/interrupt.hpp +++ b/src/include/duckdb/parallel/interrupt.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/atomic.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/parallel/task.hpp" +#include "duckdb/common/shared_ptr.hpp" #include #include diff --git a/src/include/duckdb/parallel/meta_pipeline.hpp b/src/include/duckdb/parallel/meta_pipeline.hpp index 5bf58ef80cfd..f8f954fb4c62 100644 --- a/src/include/duckdb/parallel/meta_pipeline.hpp +++ b/src/include/duckdb/parallel/meta_pipeline.hpp @@ -14,7 +14,7 @@ namespace duckdb { //! MetaPipeline represents a set of pipelines that all have the same sink -class MetaPipeline : public std::enable_shared_from_this { +class MetaPipeline : public enable_shared_from_this { //! We follow these rules when building: //! 1. For joins, build out the blocking side before going down the probe side //! - The current streaming pipeline will have a dependency on it (dependency across MetaPipelines) diff --git a/src/include/duckdb/parallel/pipeline.hpp b/src/include/duckdb/parallel/pipeline.hpp index 28781200abb8..cb53777a2d61 100644 --- a/src/include/duckdb/parallel/pipeline.hpp +++ b/src/include/duckdb/parallel/pipeline.hpp @@ -66,7 +66,7 @@ class PipelineBuildState { }; //! 
The Pipeline class represents an execution pipeline starting at a -class Pipeline : public std::enable_shared_from_this { +class Pipeline : public enable_shared_from_this { friend class Executor; friend class PipelineExecutor; friend class PipelineEvent; diff --git a/src/include/duckdb/parallel/task.hpp b/src/include/duckdb/parallel/task.hpp index 2deadcbe3bad..2bcafeeaad5b 100644 --- a/src/include/duckdb/parallel/task.hpp +++ b/src/include/duckdb/parallel/task.hpp @@ -22,7 +22,7 @@ enum class TaskExecutionMode : uint8_t { PROCESS_ALL, PROCESS_PARTIAL }; enum class TaskExecutionResult : uint8_t { TASK_FINISHED, TASK_NOT_FINISHED, TASK_ERROR, TASK_BLOCKED }; //! Generic parallel task -class Task : public std::enable_shared_from_this { +class Task : public enable_shared_from_this { public: virtual ~Task() { } diff --git a/src/include/duckdb/planner/binder.hpp b/src/include/duckdb/planner/binder.hpp index 2c73e761f940..8d096266439a 100644 --- a/src/include/duckdb/planner/binder.hpp +++ b/src/include/duckdb/planner/binder.hpp @@ -80,7 +80,7 @@ struct CorrelatedColumnInfo { tables and columns in the catalog. In the process, it also resolves types of all expressions. */ -class Binder : public std::enable_shared_from_this { +class Binder : public enable_shared_from_this { friend class ExpressionBinder; friend class RecursiveDependentJoinPlanner; @@ -376,7 +376,7 @@ class Binder : public std::enable_shared_from_this { unique_ptr BindSummarize(ShowRef &ref); public: - // This should really be a private constructor, but make_shared does not allow it... + // This should really be a private constructor, but make_refcounted does not allow it... // If you are thinking about calling this, you should probably call Binder::CreateBinder Binder(bool i_know_what_i_am_doing, ClientContext &context, shared_ptr parent, bool inherit_ctes); }; diff --git a/src/include/duckdb/storage/object_cache.hpp b/src/include/duckdb/storage/object_cache.hpp index 25b6ab69d2ea..06a5c2d3a767 100644 --- a/src/include/duckdb/storage/object_cache.hpp +++ b/src/include/duckdb/storage/object_cache.hpp @@ -43,7 +43,7 @@ class ObjectCache { if (!object || object->GetObjectType() != T::ObjectType()) { return nullptr; } - return std::static_pointer_cast(object); + return shared_ptr_cast(object); } template @@ -52,7 +52,7 @@ class ObjectCache { auto entry = cache.find(key); if (entry == cache.end()) { - auto value = make_shared(args...); + auto value = make_refcounted(args...); cache[key] = value; return value; } @@ -60,7 +60,7 @@ class ObjectCache { if (!object || object->GetObjectType() != T::ObjectType()) { return nullptr; } - return std::static_pointer_cast(object); + return shared_ptr_cast(object); } void Put(string key, shared_ptr value) { diff --git a/src/include/duckdb/storage/serialization/types.json b/src/include/duckdb/storage/serialization/types.json index dd4cf2b7f147..5433f50a9d15 100644 --- a/src/include/duckdb/storage/serialization/types.json +++ b/src/include/duckdb/storage/serialization/types.json @@ -155,7 +155,7 @@ "class": "GenericTypeInfo", "base": "ExtraTypeInfo", "enum": "GENERIC_TYPE_INFO", - "custom_switch_code": "result = make_shared(type);\nbreak;" + "custom_switch_code": "result = make_refcounted(type);\nbreak;" }, { "class": "AnyTypeInfo", diff --git a/src/include/duckdb/transaction/local_storage.hpp b/src/include/duckdb/transaction/local_storage.hpp index 7481abd70e8c..a2140444dea1 100644 --- a/src/include/duckdb/transaction/local_storage.hpp +++ b/src/include/duckdb/transaction/local_storage.hpp @@ -21,7 
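The ObjectCache hunk above swaps std::static_pointer_cast for the new shared_ptr_cast because the cached values are now duckdb::shared_ptr<ObjectCacheEntry>. The surrounding get-or-create-with-type-check pattern is worth a standalone sketch; Entry, ParquetMeta, and Cache below are invented names, not the DuckDB classes:

// Typed object-cache sketch under the assumptions stated above.
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

struct Entry {
	virtual ~Entry() = default;
	virtual std::string Type() const = 0;
};

struct ParquetMeta : Entry {
	static std::string StaticType() {
		return "parquet";
	}
	std::string Type() const override {
		return StaticType();
	}
};

struct Cache {
	std::unordered_map<std::string, std::shared_ptr<Entry>> entries;

	template <class T>
	std::shared_ptr<T> GetOrCreate(const std::string &key) {
		auto it = entries.find(key);
		if (it == entries.end()) {
			auto value = std::make_shared<T>();
			entries[key] = value;
			return value;
		}
		// Reject an entry cached under the same key with another type
		if (it->second->Type() != T::StaticType()) {
			return nullptr;
		}
		return std::static_pointer_cast<T>(it->second);
	}
};

int main() {
	Cache cache;
	auto first = cache.GetOrCreate<ParquetMeta>("file.parquet");
	auto second = cache.GetOrCreate<ParquetMeta>("file.parquet");
	assert(first.get() == second.get());
}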
+21,7 @@ class WriteAheadLog; struct LocalAppendState; struct TableAppendState; -class LocalTableStorage : public std::enable_shared_from_this { +class LocalTableStorage : public enable_shared_from_this { public: // Create a new LocalTableStorage explicit LocalTableStorage(DataTable &table); diff --git a/src/include/duckdb/transaction/transaction.hpp b/src/include/duckdb/transaction/transaction.hpp index dff31db5701c..b1bc3952c330 100644 --- a/src/include/duckdb/transaction/transaction.hpp +++ b/src/include/duckdb/transaction/transaction.hpp @@ -13,7 +13,7 @@ #include "duckdb/transaction/undo_buffer.hpp" #include "duckdb/common/atomic.hpp" #include "duckdb/transaction/transaction_data.hpp" -#include "duckdb/common/weak_ptr.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { class SequenceCatalogEntry; diff --git a/src/main/capi/table_function-c.cpp b/src/main/capi/table_function-c.cpp index e6eb5e3549fd..57be51d01f4b 100644 --- a/src/main/capi/table_function-c.cpp +++ b/src/main/capi/table_function-c.cpp @@ -179,7 +179,7 @@ void CTableFunction(ClientContext &context, TableFunctionInput &data_p, DataChun duckdb_table_function duckdb_create_table_function() { auto function = new duckdb::TableFunction("", {}, duckdb::CTableFunction, duckdb::CTableFunctionBind, duckdb::CTableFunctionInit, duckdb::CTableFunctionLocalInit); - function->function_info = duckdb::make_shared(); + function->function_info = duckdb::make_refcounted(); function->cardinality = duckdb::CTableFunctionCardinality; return function; } diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp index 909593aa66f8..835210c76d1e 100644 --- a/src/main/client_context.cpp +++ b/src/main/client_context.cpp @@ -312,7 +312,7 @@ ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const st unique_ptr statement, optional_ptr> values) { StatementType statement_type = statement->type; - auto result = make_shared(statement_type); + auto result = make_refcounted(statement_type); auto &profiler = QueryProfiler::Get(*this); profiler.StartQuery(query, IsExplainAnalyze(statement.get()), true); diff --git a/src/main/client_data.cpp b/src/main/client_data.cpp index 1298df3ea84d..f00b237837c4 100644 --- a/src/main/client_data.cpp +++ b/src/main/client_data.cpp @@ -35,8 +35,8 @@ class ClientFileSystem : public OpenerFileSystem { ClientData::ClientData(ClientContext &context) : catalog_search_path(make_uniq(context)) { auto &db = DatabaseInstance::GetDatabase(context); - profiler = make_shared(context); - temporary_objects = make_shared(db, AttachedDatabaseType::TEMP_DATABASE); + profiler = make_refcounted(context); + temporary_objects = make_refcounted(db, AttachedDatabaseType::TEMP_DATABASE); temporary_objects->oid = DatabaseManager::Get(db).ModifyCatalog(); random_engine = make_uniq(); file_opener = make_uniq(context); diff --git a/src/main/connection.cpp b/src/main/connection.cpp index 432ca9c21727..b76d440647f7 100644 --- a/src/main/connection.cpp +++ b/src/main/connection.cpp @@ -18,7 +18,8 @@ namespace duckdb { -Connection::Connection(DatabaseInstance &database) : context(make_shared(database.shared_from_this())) { +Connection::Connection(DatabaseInstance &database) + : context(make_refcounted(database.shared_from_this())) { ConnectionManager::Get(database).AddConnection(*context); #ifdef DEBUG EnableProfiling(); @@ -186,7 +187,7 @@ shared_ptr Connection::Table(const string &schema_name, const string & if (!table_info) { throw CatalogException("Table '%s' does not exist!", table_name); } - return 
make_shared(context, std::move(table_info)); + return make_refcounted(context, std::move(table_info)); } shared_ptr Connection::View(const string &tname) { @@ -194,7 +195,7 @@ shared_ptr Connection::View(const string &tname) { } shared_ptr Connection::View(const string &schema_name, const string &table_name) { - return make_shared(context, schema_name, table_name); + return make_refcounted(context, schema_name, table_name); } shared_ptr Connection::TableFunction(const string &fname) { @@ -205,11 +206,11 @@ shared_ptr Connection::TableFunction(const string &fname) { shared_ptr Connection::TableFunction(const string &fname, const vector &values, const named_parameter_map_t &named_parameters) { - return make_shared(context, fname, values, named_parameters); + return make_refcounted(context, fname, values, named_parameters); } shared_ptr Connection::TableFunction(const string &fname, const vector &values) { - return make_shared(context, fname, values); + return make_refcounted(context, fname, values); } shared_ptr Connection::Values(const vector> &values) { @@ -219,7 +220,7 @@ shared_ptr Connection::Values(const vector> &values) { shared_ptr Connection::Values(const vector> &values, const vector &column_names, const string &alias) { - return make_shared(context, values, column_names, alias); + return make_refcounted(context, values, column_names, alias); } shared_ptr Connection::Values(const string &values) { @@ -228,7 +229,7 @@ shared_ptr Connection::Values(const string &values) { } shared_ptr Connection::Values(const string &values, const vector &column_names, const string &alias) { - return make_shared(context, values, column_names, alias); + return make_refcounted(context, values, column_names, alias); } shared_ptr Connection::ReadCSV(const string &csv_file) { @@ -237,7 +238,7 @@ shared_ptr Connection::ReadCSV(const string &csv_file) { } shared_ptr Connection::ReadCSV(const vector &csv_input, named_parameter_map_t &&options) { - return make_shared(context, csv_input, std::move(options)); + return make_refcounted(context, csv_input, std::move(options)); } shared_ptr Connection::ReadCSV(const string &csv_input, named_parameter_map_t &&options) { @@ -258,7 +259,7 @@ shared_ptr Connection::ReadCSV(const string &csv_file, const vector files {csv_file}; - return make_shared(context, files, std::move(options)); + return make_refcounted(context, files, std::move(options)); } shared_ptr Connection::ReadParquet(const string &parquet_file, bool binary_as_string) { @@ -277,7 +278,7 @@ shared_ptr Connection::RelationFromQuery(const string &query, const st } shared_ptr Connection::RelationFromQuery(unique_ptr select_stmt, const string &alias) { - return make_shared(context, std::move(select_stmt), alias); + return make_refcounted(context, std::move(select_stmt), alias); } void Connection::BeginTransaction() { diff --git a/src/main/database.cpp b/src/main/database.cpp index 2e3fd0204e3a..a350f04ad610 100644 --- a/src/main/database.cpp +++ b/src/main/database.cpp @@ -264,7 +264,7 @@ void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_conf scheduler->RelaunchThreads(); } -DuckDB::DuckDB(const char *path, DBConfig *new_config) : instance(make_shared()) { +DuckDB::DuckDB(const char *path, DBConfig *new_config) : instance(make_refcounted()) { instance->Initialize(path, new_config); if (instance->config.options.load_extensions) { ExtensionHelper::LoadAllExtensions(*this); @@ -368,7 +368,7 @@ void DatabaseInstance::Configure(DBConfig &new_config) { if (new_config.buffer_pool) { 
config.buffer_pool = std::move(new_config.buffer_pool); } else { - config.buffer_pool = make_shared(config.options.maximum_memory); + config.buffer_pool = make_refcounted(config.options.maximum_memory); } } diff --git a/src/main/db_instance_cache.cpp b/src/main/db_instance_cache.cpp index 342af27f3c75..6105066c6c88 100644 --- a/src/main/db_instance_cache.cpp +++ b/src/main/db_instance_cache.cpp @@ -66,7 +66,7 @@ shared_ptr DBInstanceCache::CreateInstanceInternal(const string &databas if (abs_database_path.rfind(IN_MEMORY_PATH, 0) == 0) { instance_path = IN_MEMORY_PATH; } - auto db_instance = make_shared(instance_path, &config); + auto db_instance = make_refcounted(instance_path, &config); if (cache_instance) { db_instances[abs_database_path] = db_instance; } diff --git a/src/main/relation.cpp b/src/main/relation.cpp index 87be11809ff4..f315906bd200 100644 --- a/src/main/relation.cpp +++ b/src/main/relation.cpp @@ -39,7 +39,7 @@ shared_ptr Relation::Project(const string &expression, const string &a shared_ptr Relation::Project(const string &select_list, const vector &aliases) { auto expressions = Parser::ParseExpressionList(select_list, context.GetContext()->GetParserOptions()); - return make_shared(shared_from_this(), std::move(expressions), aliases); + return make_refcounted(shared_from_this(), std::move(expressions), aliases); } shared_ptr Relation::Project(const vector &expressions) { @@ -49,7 +49,7 @@ shared_ptr Relation::Project(const vector &expressions) { shared_ptr Relation::Project(vector> expressions, const vector &aliases) { - return make_shared(shared_from_this(), std::move(expressions), aliases); + return make_refcounted(shared_from_this(), std::move(expressions), aliases); } static vector> StringListToExpressionList(ClientContext &context, @@ -70,7 +70,7 @@ static vector> StringListToExpressionList(ClientCon shared_ptr Relation::Project(const vector &expressions, const vector &aliases) { auto result_list = StringListToExpressionList(*context.GetContext(), expressions); - return make_shared(shared_from_this(), std::move(result_list), aliases); + return make_refcounted(shared_from_this(), std::move(result_list), aliases); } shared_ptr Relation::Filter(const string &expression) { @@ -82,7 +82,7 @@ shared_ptr Relation::Filter(const string &expression) { } shared_ptr Relation::Filter(unique_ptr expression) { - return make_shared(shared_from_this(), std::move(expression)); + return make_refcounted(shared_from_this(), std::move(expression)); } shared_ptr Relation::Filter(const vector &expressions) { @@ -95,11 +95,11 @@ shared_ptr Relation::Filter(const vector &expressions) { expr = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(expr), std::move(expression_list[i])); } - return make_shared(shared_from_this(), std::move(expr)); + return make_refcounted(shared_from_this(), std::move(expr)); } shared_ptr Relation::Limit(int64_t limit, int64_t offset) { - return make_shared(shared_from_this(), limit, offset); + return make_refcounted(shared_from_this(), limit, offset); } shared_ptr Relation::Order(const string &expression) { @@ -108,7 +108,7 @@ shared_ptr Relation::Order(const string &expression) { } shared_ptr Relation::Order(vector order_list) { - return make_shared(shared_from_this(), std::move(order_list)); + return make_refcounted(shared_from_this(), std::move(order_list)); } shared_ptr Relation::Order(const vector &expressions) { @@ -149,51 +149,51 @@ shared_ptr Relation::Join(const shared_ptr &other, } using_columns.push_back(colref.column_names[0]); } - return 
make_shared(shared_from_this(), other, std::move(using_columns), type, ref_type); + return make_refcounted(shared_from_this(), other, std::move(using_columns), type, ref_type); } else { // single expression that is not a column reference: use the expression as a join condition - return make_shared(shared_from_this(), other, std::move(expression_list[0]), type, ref_type); + return make_refcounted(shared_from_this(), other, std::move(expression_list[0]), type, ref_type); } } shared_ptr Relation::CrossProduct(const shared_ptr &other, JoinRefType join_ref_type) { - return make_shared(shared_from_this(), other, join_ref_type); + return make_refcounted(shared_from_this(), other, join_ref_type); } shared_ptr Relation::Union(const shared_ptr &other) { - return make_shared(shared_from_this(), other, SetOperationType::UNION, true); + return make_refcounted(shared_from_this(), other, SetOperationType::UNION, true); } shared_ptr Relation::Except(const shared_ptr &other) { - return make_shared(shared_from_this(), other, SetOperationType::EXCEPT, true); + return make_refcounted(shared_from_this(), other, SetOperationType::EXCEPT, true); } shared_ptr Relation::Intersect(const shared_ptr &other) { - return make_shared(shared_from_this(), other, SetOperationType::INTERSECT, true); + return make_refcounted(shared_from_this(), other, SetOperationType::INTERSECT, true); } shared_ptr Relation::Distinct() { - return make_shared(shared_from_this()); + return make_refcounted(shared_from_this()); } shared_ptr Relation::Alias(const string &alias) { - return make_shared(shared_from_this(), alias); + return make_refcounted(shared_from_this(), alias); } shared_ptr Relation::Aggregate(const string &aggregate_list) { auto expression_list = Parser::ParseExpressionList(aggregate_list, context.GetContext()->GetParserOptions()); - return make_shared(shared_from_this(), std::move(expression_list)); + return make_refcounted(shared_from_this(), std::move(expression_list)); } shared_ptr Relation::Aggregate(const string &aggregate_list, const string &group_list) { auto expression_list = Parser::ParseExpressionList(aggregate_list, context.GetContext()->GetParserOptions()); auto groups = Parser::ParseGroupByList(group_list, context.GetContext()->GetParserOptions()); - return make_shared(shared_from_this(), std::move(expression_list), std::move(groups)); + return make_refcounted(shared_from_this(), std::move(expression_list), std::move(groups)); } shared_ptr Relation::Aggregate(const vector &aggregates) { auto aggregate_list = StringListToExpressionList(*context.GetContext(), aggregates); - return make_shared(shared_from_this(), std::move(aggregate_list)); + return make_refcounted(shared_from_this(), std::move(aggregate_list)); } shared_ptr Relation::Aggregate(const vector &aggregates, const vector &groups) { @@ -204,7 +204,7 @@ shared_ptr Relation::Aggregate(const vector &aggregates, const shared_ptr Relation::Aggregate(vector> expressions, const string &group_list) { auto groups = Parser::ParseGroupByList(group_list, context.GetContext()->GetParserOptions()); - return make_shared(shared_from_this(), std::move(expressions), std::move(groups)); + return make_refcounted(shared_from_this(), std::move(expressions), std::move(groups)); } string Relation::GetAlias() { @@ -237,7 +237,7 @@ BoundStatement Relation::Bind(Binder &binder) { } shared_ptr Relation::InsertRel(const string &schema_name, const string &table_name) { - return make_shared(shared_from_this(), schema_name, table_name); + return make_refcounted(shared_from_this(), 
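Nearly every Relation method in this file constructs its child with shared_from_this(), so each derived relation keeps its parent alive for as long as the query tree exists; that is why Relation inherits from the new enable_shared_from_this. The ownership chain in miniature, where Node is a hypothetical stand-in for the Relation hierarchy:

// Ownership-chain sketch; Node is invented, std types stand in for the
// duckdb wrappers.
#include <cassert>
#include <memory>

struct Node : std::enable_shared_from_this<Node> {
	std::shared_ptr<Node> parent; // child keeps its parent alive

	std::shared_ptr<Node> Limit() {
		auto child = std::make_shared<Node>();
		child->parent = shared_from_this();
		return child;
	}
};

int main() {
	auto root = std::make_shared<Node>();
	auto limited = root->Limit();
	std::weak_ptr<Node> watch = root;
	root.reset();                // drop the caller's handle to the root...
	assert(!watch.expired());    // ...but the child still pins it
	limited.reset();
	assert(watch.expired());     // releasing the chain frees the root
}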
schema_name, table_name); } void Relation::Insert(const string &table_name) { @@ -255,12 +255,12 @@ void Relation::Insert(const string &schema_name, const string &table_name) { void Relation::Insert(const vector> &values) { vector column_names; - auto rel = make_shared(context.GetContext(), values, std::move(column_names), "values"); + auto rel = make_refcounted(context.GetContext(), values, std::move(column_names), "values"); rel->Insert(GetAlias()); } shared_ptr Relation::CreateRel(const string &schema_name, const string &table_name) { - return make_shared(shared_from_this(), schema_name, table_name); + return make_refcounted(shared_from_this(), schema_name, table_name); } void Relation::Create(const string &table_name) { @@ -277,7 +277,7 @@ void Relation::Create(const string &schema_name, const string &table_name) { } shared_ptr Relation::WriteCSVRel(const string &csv_file, case_insensitive_map_t> options) { - return make_shared(shared_from_this(), csv_file, std::move(options)); + return make_refcounted(shared_from_this(), csv_file, std::move(options)); } void Relation::WriteCSV(const string &csv_file, case_insensitive_map_t> options) { @@ -292,7 +292,7 @@ void Relation::WriteCSV(const string &csv_file, case_insensitive_map_t Relation::WriteParquetRel(const string &parquet_file, case_insensitive_map_t> options) { auto write_parquet = - make_shared(shared_from_this(), parquet_file, std::move(options)); + make_refcounted(shared_from_this(), parquet_file, std::move(options)); return std::move(write_parquet); } @@ -310,7 +310,7 @@ shared_ptr Relation::CreateView(const string &name, bool replace, bool } shared_ptr Relation::CreateView(const string &schema_name, const string &name, bool replace, bool temporary) { - auto view = make_shared(shared_from_this(), schema_name, name, replace, temporary); + auto view = make_refcounted(shared_from_this(), schema_name, name, replace, temporary); auto res = view->Execute(); if (res->HasError()) { const string prepended_message = "Failed to create view '" + name + "': "; @@ -329,7 +329,7 @@ unique_ptr Relation::Query(const string &name, const string &sql) { } unique_ptr Relation::Explain(ExplainType type) { - auto explain = make_shared(shared_from_this(), type); + auto explain = make_refcounted(shared_from_this(), type); return explain->Execute(); } @@ -343,12 +343,12 @@ void Relation::Delete(const string &condition) { shared_ptr Relation::TableFunction(const std::string &fname, const vector &values, const named_parameter_map_t &named_parameters) { - return make_shared(context.GetContext(), fname, values, named_parameters, - shared_from_this()); + return make_refcounted(context.GetContext(), fname, values, named_parameters, + shared_from_this()); } shared_ptr Relation::TableFunction(const std::string &fname, const vector &values) { - return make_shared(context.GetContext(), fname, values, shared_from_this()); + return make_refcounted(context.GetContext(), fname, values, shared_from_this()); } string Relation::ToString() { diff --git a/src/main/relation/read_csv_relation.cpp b/src/main/relation/read_csv_relation.cpp index 1529de9a7637..f63d535c13ab 100644 --- a/src/main/relation/read_csv_relation.cpp +++ b/src/main/relation/read_csv_relation.cpp @@ -56,7 +56,7 @@ ReadCSVRelation::ReadCSVRelation(const shared_ptr &context, const shared_ptr buffer_manager; context->RunFunctionInTransaction([&]() { - buffer_manager = make_shared(*context, csv_options, files[0], 0); + buffer_manager = make_refcounted(*context, csv_options, files[0], 0); CSVSniffer 
sniffer(csv_options, buffer_manager, CSVStateMachineCache::Get(*context)); auto sniffer_result = sniffer.SniffCSV(); auto &types = sniffer_result.return_types; diff --git a/src/main/relation/table_relation.cpp b/src/main/relation/table_relation.cpp index 2cdc0d9d945a..c37c88507849 100644 --- a/src/main/relation/table_relation.cpp +++ b/src/main/relation/table_relation.cpp @@ -56,14 +56,14 @@ void TableRelation::Update(const string &update_list, const string &condition) { vector> expressions; auto cond = ParseCondition(*context.GetContext(), condition); Parser::ParseUpdateList(update_list, update_columns, expressions, context.GetContext()->GetParserOptions()); - auto update = make_shared(context, std::move(cond), description->schema, description->table, - std::move(update_columns), std::move(expressions)); + auto update = make_refcounted(context, std::move(cond), description->schema, description->table, + std::move(update_columns), std::move(expressions)); update->Execute(); } void TableRelation::Delete(const string &condition) { auto cond = ParseCondition(*context.GetContext(), condition); - auto del = make_shared(context, std::move(cond), description->schema, description->table); + auto del = make_refcounted(context, std::move(cond), description->schema, description->table); del->Execute(); } diff --git a/src/parallel/executor.cpp b/src/parallel/executor.cpp index 41e710284c0e..a3be9b315843 100644 --- a/src/parallel/executor.cpp +++ b/src/parallel/executor.cpp @@ -73,10 +73,11 @@ void Executor::SchedulePipeline(const shared_ptr &meta_pipeline, S // create events/stack for the base pipeline auto base_pipeline = meta_pipeline->GetBasePipeline(); - auto base_initialize_event = make_shared(base_pipeline); - auto base_event = make_shared(base_pipeline); - auto base_finish_event = make_shared(base_pipeline); - auto base_complete_event = make_shared(base_pipeline->executor, event_data.initial_schedule); + auto base_initialize_event = make_refcounted(base_pipeline); + auto base_event = make_refcounted(base_pipeline); + auto base_finish_event = make_refcounted(base_pipeline); + auto base_complete_event = + make_refcounted(base_pipeline->executor, event_data.initial_schedule); PipelineEventStack base_stack(*base_initialize_event, *base_event, *base_finish_event, *base_complete_event); events.push_back(std::move(base_initialize_event)); events.push_back(std::move(base_event)); @@ -96,7 +97,7 @@ void Executor::SchedulePipeline(const shared_ptr &meta_pipeline, S D_ASSERT(pipeline); // create events/stack for this pipeline - auto pipeline_event = make_shared(pipeline); + auto pipeline_event = make_refcounted(pipeline); auto finish_group = meta_pipeline->GetFinishGroup(*pipeline); if (finish_group) { @@ -115,7 +116,7 @@ void Executor::SchedulePipeline(const shared_ptr &meta_pipeline, S event_map.insert(make_pair(reference(*pipeline), pipeline_stack)); } else if (meta_pipeline->HasFinishEvent(*pipeline)) { // this pipeline has its own finish event (despite going into the same sink - Finalize twice!) 
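In the ReadCSVRelation hunk above, the freshly created buffer manager is a shared_ptr that outlives the RunFunctionInTransaction callback it is used in; holding such objects by shared ownership is what makes handing them to deferred callbacks safe. A minimal sketch of that capture-by-shared_ptr rule, with Resource and the deferred queue as invented example names:

// Sketch: a by-value shared_ptr capture keeps the resource alive for the
// duration of a deferred callback.
#include <cassert>
#include <functional>
#include <memory>
#include <vector>

struct Resource {
	int data = 1;
};

int main() {
	std::vector<std::function<int()>> deferred;
	{
		auto resource = std::make_shared<Resource>();
		// The capture bumps the refcount; the callback co-owns the object.
		deferred.push_back([resource]() { return resource->data; });
	} // the caller's handle is gone here
	assert(deferred.front()() == 1); // still alive inside the callback
}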
- auto pipeline_finish_event = make_shared(pipeline); + auto pipeline_finish_event = make_refcounted(pipeline); PipelineEventStack pipeline_stack(base_stack.pipeline_initialize_event, *pipeline_event, *pipeline_finish_event, base_stack.pipeline_complete_event); events.push_back(std::move(pipeline_finish_event)); @@ -359,7 +360,7 @@ void Executor::InitializeInternal(PhysicalOperator &plan) { // build and ready the pipelines PipelineBuildState state; - auto root_pipeline = make_shared(*this, state, nullptr); + auto root_pipeline = make_refcounted(*this, state, nullptr); root_pipeline->Build(*physical_plan); root_pipeline->Ready(); @@ -570,7 +571,7 @@ shared_ptr Executor::CreateChildPipeline(Pipeline ¤t, PhysicalOp D_ASSERT(op.IsSource()); // found another operator that is a source, schedule a child pipeline // 'op' is the source, and the sink is the same - auto child_pipeline = make_shared(*this); + auto child_pipeline = make_refcounted(*this); child_pipeline->sink = current.sink; child_pipeline->source = &op; diff --git a/src/parallel/meta_pipeline.cpp b/src/parallel/meta_pipeline.cpp index ded1cb246112..b73515d8e694 100644 --- a/src/parallel/meta_pipeline.cpp +++ b/src/parallel/meta_pipeline.cpp @@ -82,7 +82,7 @@ void MetaPipeline::Ready() { } MetaPipeline &MetaPipeline::CreateChildMetaPipeline(Pipeline ¤t, PhysicalOperator &op) { - children.push_back(make_shared(executor, state, &op)); + children.push_back(make_refcounted(executor, state, &op)); auto child_meta_pipeline = children.back().get(); // child MetaPipeline must finish completely before this MetaPipeline can start current.AddDependency(child_meta_pipeline->GetBasePipeline()); @@ -92,7 +92,7 @@ MetaPipeline &MetaPipeline::CreateChildMetaPipeline(Pipeline ¤t, PhysicalO } Pipeline &MetaPipeline::CreatePipeline() { - pipelines.emplace_back(make_shared(executor)); + pipelines.emplace_back(make_refcounted(executor)); state.SetPipelineSink(*pipelines.back(), sink, next_batch_index++); return *pipelines.back(); } diff --git a/src/planner/bind_context.cpp b/src/planner/bind_context.cpp index eac1d69c2bd6..611c7b34414b 100644 --- a/src/planner/bind_context.cpp +++ b/src/planner/bind_context.cpp @@ -514,13 +514,13 @@ void BindContext::AddGenericBinding(idx_t index, const string &alias, const vect void BindContext::AddCTEBinding(idx_t index, const string &alias, const vector &names, const vector &types) { - auto binding = make_shared(BindingType::BASE, alias, types, names, index); + auto binding = make_refcounted(BindingType::BASE, alias, types, names, index); if (cte_bindings.find(alias) != cte_bindings.end()) { throw BinderException("Duplicate alias \"%s\" in query!", alias); } cte_bindings[alias] = std::move(binding); - cte_references[alias] = make_shared(0); + cte_references[alias] = make_refcounted(0); } void BindContext::AddContext(BindContext other) { diff --git a/src/planner/binder.cpp b/src/planner/binder.cpp index 75e2a0482364..9316c6f85f63 100644 --- a/src/planner/binder.cpp +++ b/src/planner/binder.cpp @@ -47,7 +47,7 @@ shared_ptr Binder::CreateBinder(ClientContext &context, optional_ptr(true, context, parent ? parent->shared_from_this() : nullptr, inherit_ctes); + return make_refcounted(true, context, parent ? 
parent->shared_from_this() : nullptr, inherit_ctes); } Binder::Binder(bool, ClientContext &context, shared_ptr parent_p, bool inherit_ctes_p) diff --git a/src/planner/bound_parameter_map.cpp b/src/planner/bound_parameter_map.cpp index 420fa3931fbb..61571cd8543f 100644 --- a/src/planner/bound_parameter_map.cpp +++ b/src/planner/bound_parameter_map.cpp @@ -33,7 +33,7 @@ shared_ptr BoundParameterMap::CreateOrGetData(const string & auto entry = parameters.find(identifier); if (entry == parameters.end()) { // no entry yet: create a new one - auto data = make_shared(); + auto data = make_refcounted(); data->return_type = GetReturnType(identifier); CreateNewParameter(identifier, data); diff --git a/src/planner/planner.cpp b/src/planner/planner.cpp index 2df3a698e185..37381c2b6ec9 100644 --- a/src/planner/planner.cpp +++ b/src/planner/planner.cpp @@ -101,7 +101,7 @@ shared_ptr Planner::PrepareSQLStatement(unique_ptr(copied_statement->type); + auto prepared_data = make_refcounted(copied_statement->type); prepared_data->unbound_statement = std::move(copied_statement); prepared_data->names = names; prepared_data->types = types; diff --git a/src/storage/buffer/block_manager.cpp b/src/storage/buffer/block_manager.cpp index 36c2b559f354..fdead1bc1e6b 100644 --- a/src/storage/buffer/block_manager.cpp +++ b/src/storage/buffer/block_manager.cpp @@ -23,7 +23,7 @@ shared_ptr BlockManager::RegisterBlock(block_id_t block_id) { } } // create a new block pointer for this block - auto result = make_shared(*this, block_id, MemoryTag::BASE_TABLE); + auto result = make_refcounted(*this, block_id, MemoryTag::BASE_TABLE); // register the block pointer in the set of blocks as a weak pointer blocks[block_id] = weak_ptr(result); return result; diff --git a/src/storage/checkpoint_manager.cpp b/src/storage/checkpoint_manager.cpp index 574fc0660d1b..fe91431d986b 100644 --- a/src/storage/checkpoint_manager.cpp +++ b/src/storage/checkpoint_manager.cpp @@ -425,7 +425,7 @@ void CheckpointReader::ReadIndex(ClientContext &context, Deserializer &deseriali // now we can look for the index in the catalog and assign the table info auto &index = catalog.CreateIndex(context, info)->Cast(); - index.info = make_shared(table.GetStorage().info, info.index_name); + index.info = make_refcounted(table.GetStorage().info, info.index_name); // insert the parsed expressions into the index so that we can (de)serialize them during consecutive checkpoints for (auto &parsed_expr : info.parsed_expressions) { diff --git a/src/storage/data_table.cpp b/src/storage/data_table.cpp index ac55613c9295..97b8bc3680ec 100644 --- a/src/storage/data_table.cpp +++ b/src/storage/data_table.cpp @@ -45,12 +45,12 @@ bool DataTableInfo::IsTemporary() const { DataTable::DataTable(AttachedDatabase &db, shared_ptr table_io_manager_p, const string &schema, const string &table, vector column_definitions_p, unique_ptr data) - : info(make_shared(db, std::move(table_io_manager_p), schema, table)), + : info(make_refcounted(db, std::move(table_io_manager_p), schema, table)), column_definitions(std::move(column_definitions_p)), db(db), is_root(true) { // initialize the table with the existing data from disk, if any auto types = GetTypes(); this->row_groups = - make_shared(info, TableIOManager::Get(*this).GetBlockManagerForRowData(), types, 0); + make_refcounted(info, TableIOManager::Get(*this).GetBlockManagerForRowData(), types, 0); if (data && data->row_group_count > 0) { this->row_groups->Initialize(*data); } else { diff --git a/src/storage/local_storage.cpp 
b/src/storage/local_storage.cpp index 2c7fb0fe1d79..235535607e51 100644 --- a/src/storage/local_storage.cpp +++ b/src/storage/local_storage.cpp @@ -18,8 +18,8 @@ LocalTableStorage::LocalTableStorage(DataTable &table) : table_ref(table), allocator(Allocator::Get(table.db)), deleted_rows(0), optimistic_writer(table), merged_storage(false) { auto types = table.GetTypes(); - row_groups = make_shared(table.info, TableIOManager::Get(table).GetBlockManagerForRowData(), - types, MAX_ROW_ID, 0); + row_groups = make_refcounted(table.info, TableIOManager::Get(table).GetBlockManagerForRowData(), + types, MAX_ROW_ID, 0); row_groups->InitializeEmpty(); table.info->indexes.Scan([&](Index &index) { @@ -250,7 +250,7 @@ LocalTableStorage &LocalTableManager::GetOrCreateStorage(DataTable &table) { lock_guard l(table_storage_lock); auto entry = table_storage.find(table); if (entry == table_storage.end()) { - auto new_storage = make_shared(table); + auto new_storage = make_refcounted(table); auto storage = new_storage.get(); table_storage.insert(make_pair(reference(table), std::move(new_storage))); return *storage; @@ -534,7 +534,7 @@ void LocalStorage::AddColumn(DataTable &old_dt, DataTable &new_dt, ColumnDefinit if (!storage) { return; } - auto new_storage = make_shared(context, new_dt, *storage, new_column, default_value); + auto new_storage = make_refcounted(context, new_dt, *storage, new_column, default_value); table_manager.InsertEntry(new_dt, std::move(new_storage)); } @@ -544,7 +544,7 @@ void LocalStorage::DropColumn(DataTable &old_dt, DataTable &new_dt, idx_t remove if (!storage) { return; } - auto new_storage = make_shared(new_dt, *storage, removed_column); + auto new_storage = make_refcounted(new_dt, *storage, removed_column); table_manager.InsertEntry(new_dt, std::move(new_storage)); } @@ -555,8 +555,8 @@ void LocalStorage::ChangeType(DataTable &old_dt, DataTable &new_dt, idx_t change if (!storage) { return; } - auto new_storage = - make_shared(context, new_dt, *storage, changed_idx, target_type, bound_columns, cast_expr); + auto new_storage = make_refcounted(context, new_dt, *storage, changed_idx, target_type, + bound_columns, cast_expr); table_manager.InsertEntry(new_dt, std::move(new_storage)); } diff --git a/src/storage/serialization/serialize_types.cpp b/src/storage/serialization/serialize_types.cpp index 0b75f518b8eb..1b3b37d87cd9 100644 --- a/src/storage/serialization/serialize_types.cpp +++ b/src/storage/serialization/serialize_types.cpp @@ -35,7 +35,7 @@ shared_ptr ExtraTypeInfo::Deserialize(Deserializer &deserializer) result = EnumTypeInfo::Deserialize(deserializer); break; case ExtraTypeInfoType::GENERIC_TYPE_INFO: - result = make_shared(type); + result = make_refcounted(type); break; case ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO: result = IntegerLiteralTypeInfo::Deserialize(deserializer); diff --git a/src/storage/standard_buffer_manager.cpp b/src/storage/standard_buffer_manager.cpp index 5b801134f08e..f0055e8cb3ef 100644 --- a/src/storage/standard_buffer_manager.cpp +++ b/src/storage/standard_buffer_manager.cpp @@ -98,8 +98,8 @@ shared_ptr StandardBufferManager::RegisterSmallMemory(idx_t block_s auto buffer = ConstructManagedBuffer(block_size, nullptr, FileBufferType::TINY_BUFFER); // create a new block pointer for this block - auto result = make_shared(*temp_block_manager, ++temporary_id, MemoryTag::BASE_TABLE, - std::move(buffer), false, block_size, std::move(reservation)); + auto result = make_refcounted(*temp_block_manager, ++temporary_id, MemoryTag::BASE_TABLE, + 
std::move(buffer), false, block_size, std::move(reservation)); #ifdef DUCKDB_DEBUG_DESTROY_BLOCKS // Initialize the memory with garbage data WriteGarbageIntoBuffer(*result->buffer); @@ -118,8 +118,8 @@ shared_ptr StandardBufferManager::RegisterMemory(MemoryTag tag, idx auto buffer = ConstructManagedBuffer(block_size, std::move(reusable_buffer)); // create a new block pointer for this block - return make_shared(*temp_block_manager, ++temporary_id, tag, std::move(buffer), can_destroy, - alloc_size, std::move(res)); + return make_refcounted(*temp_block_manager, ++temporary_id, tag, std::move(buffer), can_destroy, + alloc_size, std::move(res)); } BufferHandle StandardBufferManager::Allocate(MemoryTag tag, idx_t block_size, bool can_destroy, diff --git a/src/storage/statistics/column_statistics.cpp b/src/storage/statistics/column_statistics.cpp index e2c2b45b97f9..67d67417b671 100644 --- a/src/storage/statistics/column_statistics.cpp +++ b/src/storage/statistics/column_statistics.cpp @@ -14,7 +14,7 @@ ColumnStatistics::ColumnStatistics(BaseStatistics stats_p, unique_ptr ColumnStatistics::CreateEmptyStats(const LogicalType &type) { - return make_shared(BaseStatistics::CreateEmpty(type)); + return make_refcounted(BaseStatistics::CreateEmpty(type)); } void ColumnStatistics::Merge(ColumnStatistics &other) { @@ -53,7 +53,7 @@ void ColumnStatistics::UpdateDistinctStatistics(Vector &v, idx_t count) { } shared_ptr ColumnStatistics::Copy() const { - return make_shared(stats.Copy(), distinct_stats ? distinct_stats->Copy() : nullptr); + return make_refcounted(stats.Copy(), distinct_stats ? distinct_stats->Copy() : nullptr); } void ColumnStatistics::Serialize(Serializer &serializer) const { @@ -65,7 +65,7 @@ shared_ptr ColumnStatistics::Deserialize(Deserializer &deseria auto stats = deserializer.ReadProperty(100, "statistics"); auto distinct_stats = deserializer.ReadPropertyWithDefault>( 101, "distinct", unique_ptr()); - return make_shared(std::move(stats), std::move(distinct_stats)); + return make_refcounted(std::move(stats), std::move(distinct_stats)); } } // namespace duckdb diff --git a/src/storage/table/row_group.cpp b/src/storage/table/row_group.cpp index 824827d76a97..4d232e3f53c9 100644 --- a/src/storage/table/row_group.cpp +++ b/src/storage/table/row_group.cpp @@ -624,7 +624,7 @@ shared_ptr &RowGroup::GetOrCreateVersionInfoPtr() { if (!vinfo) { lock_guard lock(row_group_lock); if (!version_info) { - version_info = make_shared(start); + version_info = make_refcounted(start); } } return version_info; diff --git a/src/storage/table/row_group_collection.cpp b/src/storage/table/row_group_collection.cpp index 00e42d5b7b7f..c333ea6b1c10 100644 --- a/src/storage/table/row_group_collection.cpp +++ b/src/storage/table/row_group_collection.cpp @@ -55,7 +55,7 @@ RowGroupCollection::RowGroupCollection(shared_ptr info_p, BlockMa vector types_p, idx_t row_start_p, idx_t total_rows_p) : block_manager(block_manager), total_rows(total_rows_p), info(std::move(info_p)), types(std::move(types_p)), row_start(row_start_p), allocation_size(0) { - row_groups = make_shared(*this); + row_groups = make_refcounted(*this); } idx_t RowGroupCollection::GetTotalRows() const { @@ -1031,7 +1031,7 @@ shared_ptr RowGroupCollection::AddColumn(ClientContext &cont auto new_types = types; new_types.push_back(new_column.GetType()); auto result = - make_shared(info, block_manager, std::move(new_types), row_start, total_rows.load()); + make_refcounted(info, block_manager, std::move(new_types), row_start, total_rows.load()); 
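	// (Assumption, inferred from context: AddColumn builds a fresh refcounted
	// RowGroupCollection that shares the table's info and block manager, then fills
	// it from the old collection; readers still holding a shared_ptr to the previous
	// collection keep it alive, which is why these collections are refcounted at all.)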
ExpressionExecutor executor(context); DataChunk dummy_chunk; @@ -1059,7 +1059,7 @@ shared_ptr RowGroupCollection::RemoveColumn(idx_t col_idx) { new_types.erase(new_types.begin() + col_idx); auto result = - make_shared(info, block_manager, std::move(new_types), row_start, total_rows.load()); + make_refcounted(info, block_manager, std::move(new_types), row_start, total_rows.load()); result->stats.InitializeRemoveColumn(stats, col_idx); for (auto ¤t_row_group : row_groups->Segments()) { @@ -1077,7 +1077,7 @@ shared_ptr RowGroupCollection::AlterType(ClientContext &cont new_types[changed_idx] = target_type; auto result = - make_shared(info, block_manager, std::move(new_types), row_start, total_rows.load()); + make_refcounted(info, block_manager, std::move(new_types), row_start, total_rows.load()); result->stats.InitializeAlterType(stats, changed_idx, target_type); vector scan_types; diff --git a/src/storage/table/row_version_manager.cpp b/src/storage/table/row_version_manager.cpp index ead21f89234f..711daa7d1881 100644 --- a/src/storage/table/row_version_manager.cpp +++ b/src/storage/table/row_version_manager.cpp @@ -212,7 +212,7 @@ shared_ptr RowVersionManager::Deserialize(MetaBlockPointer de if (!delete_pointer.IsValid()) { return nullptr; } - auto version_info = make_shared(start); + auto version_info = make_refcounted(start); MetadataReader source(manager, delete_pointer, &version_info->storage_pointers); auto chunk_count = source.Read(); D_ASSERT(chunk_count > 0); diff --git a/src/storage/wal_replay.cpp b/src/storage/wal_replay.cpp index 9699bbac3e9d..b3fea7dae501 100644 --- a/src/storage/wal_replay.cpp +++ b/src/storage/wal_replay.cpp @@ -580,7 +580,7 @@ void WriteAheadLogDeserializer::ReplayCreateIndex() { // create the index in the catalog auto &table = catalog.GetEntry(context, create_info->schema, info.table).Cast(); auto &index = catalog.CreateIndex(context, info)->Cast(); - index.info = make_shared(table.GetStorage().info, index.name); + index.info = make_refcounted(table.GetStorage().info, index.name); // insert the parsed expressions into the index so that we can (de)serialize them during consecutive checkpoints for (auto &parsed_expr : info.parsed_expressions) { From 1e0fe2a96910637abc0037bc762393093df3a3c9 Mon Sep 17 00:00:00 2001 From: Tishj Date: Sat, 6 Apr 2024 21:08:11 +0200 Subject: [PATCH 054/201] almost compiling --- extension/httpfs/httpfs.cpp | 4 +- extension/httpfs/s3fs.cpp | 2 +- extension/json/json_functions/copy_json.cpp | 2 +- extension/json/json_functions/read_json.cpp | 16 ++--- .../json/json_functions/read_json_objects.cpp | 10 +-- extension/parquet/column_reader.cpp | 10 +-- .../include/templated_column_reader.hpp | 2 +- extension/parquet/parquet_crypto.cpp | 13 ++-- extension/parquet/parquet_extension.cpp | 4 +- extension/parquet/parquet_reader.cpp | 4 +- extension/parquet/parquet_writer.cpp | 2 +- extension/sqlsmith/statement_generator.cpp | 2 +- .../sqlsmith/third_party/sqlsmith/expr.cc | 46 ++++++------- .../sqlsmith/third_party/sqlsmith/grammar.cc | 66 +++++++++---------- .../sqlsmith/third_party/sqlsmith/sqlsmith.cc | 12 ++-- .../csv_scanner/sniffer/dialect_detection.cpp | 1 + .../operator/order/physical_order.cpp | 1 + src/include/duckdb/common/shared_ptr.ipp | 32 ++++++--- src/include/duckdb/common/unique_ptr.hpp | 1 + src/include/duckdb/common/weak_ptr.ipp | 18 +++-- .../csv_scanner/column_count_scanner.hpp | 1 + test/api/test_object_cache.cpp | 2 +- test/api/test_relation_api.cpp | 16 ++--- test/sql/storage/test_buffer_manager.cpp | 6 +- 24 files 
changed, 151 insertions(+), 122 deletions(-) diff --git a/extension/httpfs/httpfs.cpp b/extension/httpfs/httpfs.cpp index 9240df3cb0ac..1127d18de541 100644 --- a/extension/httpfs/httpfs.cpp +++ b/extension/httpfs/httpfs.cpp @@ -556,7 +556,7 @@ static optional_ptr TryGetMetadataCache(optional_ptrregistered_state.find("http_cache"); if (lookup == client_context->registered_state.end()) { - auto cache = make_shared(true, true); + auto cache = make_refcounted(true, true); client_context->registered_state["http_cache"] = cache; return cache.get(); } else { @@ -571,7 +571,7 @@ void HTTPFileHandle::Initialize(optional_ptr opener) { auto &hfs = file_system.Cast(); state = HTTPState::TryGetState(opener); if (!state) { - state = make_shared(); + state = make_refcounted(); } auto current_cache = TryGetMetadataCache(opener, hfs); diff --git a/extension/httpfs/s3fs.cpp b/extension/httpfs/s3fs.cpp index bdc03ec10bb8..5ba265c52c80 100644 --- a/extension/httpfs/s3fs.cpp +++ b/extension/httpfs/s3fs.cpp @@ -567,7 +567,7 @@ shared_ptr S3FileHandle::GetBuffer(uint16_t write_buffer_idx) { auto buffer_handle = s3fs.Allocate(part_size, config_params.max_upload_threads); auto new_write_buffer = - make_shared(write_buffer_idx * part_size, part_size, std::move(buffer_handle)); + make_refcounted(write_buffer_idx * part_size, part_size, std::move(buffer_handle)); { unique_lock lck(write_buffers_lock); auto lookup_result = write_buffers.find(write_buffer_idx); diff --git a/extension/json/json_functions/copy_json.cpp b/extension/json/json_functions/copy_json.cpp index dd26e6d0c0b1..a42c75aede5c 100644 --- a/extension/json/json_functions/copy_json.cpp +++ b/extension/json/json_functions/copy_json.cpp @@ -184,7 +184,7 @@ CopyFunction JSONFunctions::GetJSONCopyFunction() { function.plan = CopyToJSONPlan; function.copy_from_bind = CopyFromJSONBind; - function.copy_from_function = JSONFunctions::GetReadJSONTableFunction(make_shared( + function.copy_from_function = JSONFunctions::GetReadJSONTableFunction(make_refcounted( JSONScanType::READ_JSON, JSONFormat::NEWLINE_DELIMITED, JSONRecordType::RECORDS, false)); return function; diff --git a/extension/json/json_functions/read_json.cpp b/extension/json/json_functions/read_json.cpp index 3640e0a7e42d..db4a43e751be 100644 --- a/extension/json/json_functions/read_json.cpp +++ b/extension/json/json_functions/read_json.cpp @@ -381,26 +381,26 @@ TableFunctionSet CreateJSONFunctionInfo(string name, shared_ptr in } TableFunctionSet JSONFunctions::GetReadJSONFunction() { - auto info = - make_shared(JSONScanType::READ_JSON, JSONFormat::AUTO_DETECT, JSONRecordType::AUTO_DETECT, true); + auto info = make_refcounted(JSONScanType::READ_JSON, JSONFormat::AUTO_DETECT, + JSONRecordType::AUTO_DETECT, true); return CreateJSONFunctionInfo("read_json", std::move(info)); } TableFunctionSet JSONFunctions::GetReadNDJSONFunction() { - auto info = make_shared(JSONScanType::READ_JSON, JSONFormat::NEWLINE_DELIMITED, - JSONRecordType::AUTO_DETECT, true); + auto info = make_refcounted(JSONScanType::READ_JSON, JSONFormat::NEWLINE_DELIMITED, + JSONRecordType::AUTO_DETECT, true); return CreateJSONFunctionInfo("read_ndjson", std::move(info)); } TableFunctionSet JSONFunctions::GetReadJSONAutoFunction() { - auto info = - make_shared(JSONScanType::READ_JSON, JSONFormat::AUTO_DETECT, JSONRecordType::AUTO_DETECT, true); + auto info = make_refcounted(JSONScanType::READ_JSON, JSONFormat::AUTO_DETECT, + JSONRecordType::AUTO_DETECT, true); return CreateJSONFunctionInfo("read_json_auto", std::move(info)); } 
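All four read_json flavors registered here differ only in the JSONFormat/JSONRecordType flags baked into one shared scan-info object. A minimal sketch of that registration pattern, with hypothetical stand-in types (ScanInfo and TableFn are illustrative only, not the extension's real API):

    #include <memory>
    #include <string>
    #include <vector>

    struct ScanInfo {              // stand-in for JSONScanInfo
        std::string format;        // e.g. "auto_detect" or "newline_delimited"
        bool record_auto_detect;
    };

    struct TableFn {               // stand-in for TableFunction
        std::string name;
        std::shared_ptr<const ScanInfo> info;
    };

    // Every overload in the set aliases the same immutable, refcounted info object,
    // so a later make_shared -> make_refcounted style swap only touches one call site.
    std::vector<TableFn> CreateReadJsonSet(const std::string &name) {
        auto info = std::make_shared<const ScanInfo>(ScanInfo {"auto_detect", true});
        // One overload takes a single path, the other a list of paths; both share info.
        return {{name + "(VARCHAR)", info}, {name + "(VARCHAR[])", info}};
    }
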
TableFunctionSet JSONFunctions::GetReadNDJSONAutoFunction() { - auto info = make_shared(JSONScanType::READ_JSON, JSONFormat::NEWLINE_DELIMITED, - JSONRecordType::AUTO_DETECT, true); + auto info = make_refcounted(JSONScanType::READ_JSON, JSONFormat::NEWLINE_DELIMITED, + JSONRecordType::AUTO_DETECT, true); return CreateJSONFunctionInfo("read_ndjson_auto", std::move(info)); } diff --git a/extension/json/json_functions/read_json_objects.cpp b/extension/json/json_functions/read_json_objects.cpp index 197a6a34843b..891d142dfaf2 100644 --- a/extension/json/json_functions/read_json_objects.cpp +++ b/extension/json/json_functions/read_json_objects.cpp @@ -61,7 +61,7 @@ TableFunction GetReadJSONObjectsTableFunction(bool list_parameter, shared_ptr(JSONScanType::READ_JSON_OBJECTS, JSONFormat::ARRAY, JSONRecordType::RECORDS); + make_refcounted(JSONScanType::READ_JSON_OBJECTS, JSONFormat::ARRAY, JSONRecordType::RECORDS); function_set.AddFunction(GetReadJSONObjectsTableFunction(false, function_info)); function_set.AddFunction(GetReadJSONObjectsTableFunction(true, function_info)); return function_set; @@ -69,8 +69,8 @@ TableFunctionSet JSONFunctions::GetReadJSONObjectsFunction() { TableFunctionSet JSONFunctions::GetReadNDJSONObjectsFunction() { TableFunctionSet function_set("read_ndjson_objects"); - auto function_info = make_shared(JSONScanType::READ_JSON_OBJECTS, JSONFormat::NEWLINE_DELIMITED, - JSONRecordType::RECORDS); + auto function_info = make_refcounted(JSONScanType::READ_JSON_OBJECTS, JSONFormat::NEWLINE_DELIMITED, + JSONRecordType::RECORDS); function_set.AddFunction(GetReadJSONObjectsTableFunction(false, function_info)); function_set.AddFunction(GetReadJSONObjectsTableFunction(true, function_info)); return function_set; @@ -78,8 +78,8 @@ TableFunctionSet JSONFunctions::GetReadNDJSONObjectsFunction() { TableFunctionSet JSONFunctions::GetReadJSONObjectsAutoFunction() { TableFunctionSet function_set("read_json_objects_auto"); - auto function_info = - make_shared(JSONScanType::READ_JSON_OBJECTS, JSONFormat::AUTO_DETECT, JSONRecordType::RECORDS); + auto function_info = make_refcounted(JSONScanType::READ_JSON_OBJECTS, JSONFormat::AUTO_DETECT, + JSONRecordType::RECORDS); function_set.AddFunction(GetReadJSONObjectsTableFunction(false, function_info)); function_set.AddFunction(GetReadJSONObjectsTableFunction(true, function_info)); return function_set; diff --git a/extension/parquet/column_reader.cpp b/extension/parquet/column_reader.cpp index b21d85fde215..285e9572e7ad 100644 --- a/extension/parquet/column_reader.cpp +++ b/extension/parquet/column_reader.cpp @@ -303,7 +303,7 @@ void ColumnReader::PreparePageV2(PageHeader &page_hdr) { void ColumnReader::AllocateBlock(idx_t size) { if (!block) { - block = make_shared(GetAllocator(), size); + block = make_refcounted(GetAllocator(), size); } else { block->resize(GetAllocator(), size); } @@ -515,7 +515,7 @@ idx_t ColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr result); } else if (dbp_decoder) { // TODO keep this in the state - auto read_buf = make_shared(); + auto read_buf = make_refcounted(); switch (schema.type) { case duckdb_parquet::format::Type::INT32: @@ -536,7 +536,7 @@ idx_t ColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr } else if (rle_decoder) { // RLE encoding for boolean D_ASSERT(type.id() == LogicalTypeId::BOOLEAN); - auto read_buf = make_shared(); + auto read_buf = make_refcounted(); read_buf->resize(reader.allocator, sizeof(bool) * (read_now - null_count)); 
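	// (Assumption, inferred from the surrounding decoders: the RLE path first expands
	// the compressed booleans into this refcounted scratch buffer, then funnels them
	// through the same PlainTemplated copy-out used for uncompressed pages; keeping
	// the buffer in a shared_ptr lets the templated reader share it without copying.)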
rle_decoder->GetBatch(read_buf->ptr, read_now - null_count); PlainTemplated>(read_buf, define_out, read_now, filter, @@ -545,7 +545,7 @@ idx_t ColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr // DELTA_BYTE_ARRAY or DELTA_LENGTH_BYTE_ARRAY DeltaByteArray(define_out, read_now, filter, result_offset, result); } else if (bss_decoder) { - auto read_buf = make_shared(); + auto read_buf = make_refcounted(); switch (schema.type) { case duckdb_parquet::format::Type::FLOAT: @@ -661,7 +661,7 @@ void StringColumnReader::Dictionary(shared_ptr data, idx_t num static shared_ptr ReadDbpData(Allocator &allocator, ResizeableBuffer &buffer, idx_t &value_count) { auto decoder = make_uniq(buffer.ptr, buffer.len); value_count = decoder->TotalValues(); - auto result = make_shared(); + auto result = make_refcounted(); result->resize(allocator, sizeof(uint32_t) * value_count); decoder->GetBatch(result->ptr, value_count); decoder->Finalize(); diff --git a/extension/parquet/include/templated_column_reader.hpp b/extension/parquet/include/templated_column_reader.hpp index 59a1c13c4781..c2c7740c0902 100644 --- a/extension/parquet/include/templated_column_reader.hpp +++ b/extension/parquet/include/templated_column_reader.hpp @@ -43,7 +43,7 @@ class TemplatedColumnReader : public ColumnReader { public: void AllocateDict(idx_t size) { if (!dict) { - dict = make_shared(GetAllocator(), size); + dict = make_refcounted(GetAllocator(), size); } else { dict->resize(GetAllocator(), size); } diff --git a/extension/parquet/parquet_crypto.cpp b/extension/parquet/parquet_crypto.cpp index 6982d366da4c..d629899a672f 100644 --- a/extension/parquet/parquet_crypto.cpp +++ b/extension/parquet/parquet_crypto.cpp @@ -13,7 +13,7 @@ namespace duckdb { ParquetKeys &ParquetKeys::Get(ClientContext &context) { auto &cache = ObjectCache::GetObjectCache(context); if (!cache.Get(ParquetKeys::ObjectType())) { - cache.Put(ParquetKeys::ObjectType(), make_shared()); + cache.Put(ParquetKeys::ObjectType(), make_refcounted()); } return *cache.Get(ParquetKeys::ObjectType()); } @@ -300,13 +300,14 @@ class SimpleReadTransport : public TTransport { uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key) { // Create decryption protocol TCompactProtocolFactoryT tproto_factory; - auto dprot = tproto_factory.getProtocol(make_shared(iprot, key)); + auto dprot = tproto_factory.getProtocol(std::make_shared(iprot, key)); auto &dtrans = reinterpret_cast(*dprot->getTransport()); // We have to read the whole thing otherwise thrift throws an error before we realize we're decryption is wrong auto all = dtrans.ReadAll(); TCompactProtocolFactoryT tsimple_proto_factory; - auto simple_prot = tsimple_proto_factory.getProtocol(make_shared(all.get(), all.GetSize())); + auto simple_prot = + tsimple_proto_factory.getProtocol(std::make_shared(all.get(), all.GetSize())); // Read the object object.read(simple_prot.get()); @@ -317,7 +318,7 @@ uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key) uint32_t ParquetCrypto::Write(const TBase &object, TProtocol &oprot, const string &key) { // Create encryption protocol TCompactProtocolFactoryT tproto_factory; - auto eprot = tproto_factory.getProtocol(make_shared(oprot, key)); + auto eprot = tproto_factory.getProtocol(std::make_shared(oprot, key)); auto &etrans = reinterpret_cast(*eprot->getTransport()); // Write the object in memory @@ -331,7 +332,7 @@ uint32_t ParquetCrypto::ReadData(TProtocol &iprot, const data_ptr_t buffer, cons const string &key) { // Create 
decryption protocol TCompactProtocolFactoryT tproto_factory; - auto dprot = tproto_factory.getProtocol(make_shared(iprot, key)); + auto dprot = tproto_factory.getProtocol(std::make_shared(iprot, key)); auto &dtrans = reinterpret_cast(*dprot->getTransport()); // Read buffer @@ -346,7 +347,7 @@ uint32_t ParquetCrypto::WriteData(TProtocol &oprot, const const_data_ptr_t buffe // FIXME: we know the size upfront so we could do a streaming write instead of this // Create encryption protocol TCompactProtocolFactoryT tproto_factory; - auto eprot = tproto_factory.getProtocol(make_shared(oprot, key)); + auto eprot = tproto_factory.getProtocol(std::make_shared(oprot, key)); auto &etrans = reinterpret_cast(*eprot->getTransport()); // Write the data in memory diff --git a/extension/parquet/parquet_extension.cpp b/extension/parquet/parquet_extension.cpp index 36919143c85f..fb900c266ae9 100644 --- a/extension/parquet/parquet_extension.cpp +++ b/extension/parquet/parquet_extension.cpp @@ -540,7 +540,7 @@ class ParquetScanFunction { result->initial_reader = result->readers[0]; } else { result->initial_reader = - make_shared(context, bind_data.files[0], bind_data.parquet_options); + make_refcounted(context, bind_data.files[0], bind_data.parquet_options); result->readers[0] = result->initial_reader; } result->file_states[0] = ParquetFileState::OPEN; @@ -746,7 +746,7 @@ class ParquetScanFunction { shared_ptr reader; try { - reader = make_shared(context, file, pq_options); + reader = make_refcounted(context, file, pq_options); InitializeParquetReader(*reader, bind_data, parallel_state.column_ids, parallel_state.filters, context); } catch (...) { diff --git a/extension/parquet/parquet_reader.cpp b/extension/parquet/parquet_reader.cpp index efc16ff0ede2..ecf73d1ec542 100644 --- a/extension/parquet/parquet_reader.cpp +++ b/extension/parquet/parquet_reader.cpp @@ -49,7 +49,7 @@ using duckdb_parquet::format::Type; static unique_ptr CreateThriftFileProtocol(Allocator &allocator, FileHandle &file_handle, bool prefetch_mode) { - auto transport = make_shared(allocator, file_handle, prefetch_mode); + auto transport = make_refcounted(allocator, file_handle, prefetch_mode); return make_uniq>(std::move(transport)); } @@ -112,7 +112,7 @@ LoadMetadata(Allocator &allocator, FileHandle &file_handle, metadata->read(file_proto.get()); } - return make_shared(std::move(metadata), current_time); + return make_refcounted(std::move(metadata), current_time); } LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, bool binary_as_string) { diff --git a/extension/parquet/parquet_writer.cpp b/extension/parquet/parquet_writer.cpp index 7c299fa793a1..1e130964780a 100644 --- a/extension/parquet/parquet_writer.cpp +++ b/extension/parquet/parquet_writer.cpp @@ -366,7 +366,7 @@ ParquetWriter::ParquetWriter(FileSystem &fs, string file_name_p, vectorWriteData(const_data_ptr_cast("PAR1"), 4); } TCompactProtocolFactoryT tproto_factory; - protocol = tproto_factory.getProtocol(make_shared(*writer)); + protocol = tproto_factory.getProtocol(std::make_shared(*writer)); file_meta_data.num_rows = 0; file_meta_data.version = 1; diff --git a/extension/sqlsmith/statement_generator.cpp b/extension/sqlsmith/statement_generator.cpp index 973287df5dab..01e8609043e9 100644 --- a/extension/sqlsmith/statement_generator.cpp +++ b/extension/sqlsmith/statement_generator.cpp @@ -46,7 +46,7 @@ StatementGenerator::~StatementGenerator() { } shared_ptr StatementGenerator::GetDatabaseState(ClientContext &context) { - auto result = make_shared(); + auto 
result = make_refcounted(); result->test_types = TestAllTypesFun::GetTestTypes(); auto schemas = Catalog::GetAllSchemas(context); diff --git a/extension/sqlsmith/third_party/sqlsmith/expr.cc b/extension/sqlsmith/third_party/sqlsmith/expr.cc index 12e34a2ddfca..95d5fbfc813d 100644 --- a/extension/sqlsmith/third_party/sqlsmith/expr.cc +++ b/extension/sqlsmith/third_party/sqlsmith/expr.cc @@ -17,21 +17,21 @@ using impedance::matched; shared_ptr value_expr::factory(prod *p, sqltype *type_constraint) { try { if (1 == d20() && p->level < d6() && window_function::allowed(p)) - return make_shared(p, type_constraint); + return make_refcounted(p, type_constraint); else if (1 == d42() && p->level < d6()) - return make_shared(p, type_constraint); + return make_refcounted(p, type_constraint); else if (1 == d42() && p->level < d6()) - return make_shared(p, type_constraint); + return make_refcounted(p, type_constraint); else if (p->level < d6() && d6() == 1) - return make_shared(p, type_constraint); + return make_refcounted(p, type_constraint); else if (d12() == 1) - return make_shared(p, type_constraint); + return make_refcounted(p, type_constraint); else if (p->level < d6() && d9() == 1) - return make_shared(p, type_constraint); + return make_refcounted(p, type_constraint); else if (p->scope->refs.size() && d20() > 1) - return make_shared(p, type_constraint); + return make_refcounted(p, type_constraint); else - return make_shared(p, type_constraint); + return make_refcounted(p, type_constraint); } catch (runtime_error &e) { } p->retry(); @@ -89,18 +89,18 @@ column_reference::column_reference(prod *p, sqltype *type_constraint) : value_ex shared_ptr bool_expr::factory(prod *p) { try { if (p->level > d100()) - return make_shared(p); + return make_refcounted(p); if (d6() < 4) - return make_shared(p); + return make_refcounted(p); else if (d6() < 4) - return make_shared(p); + return make_refcounted(p); else if (d6() < 4) - return make_shared(p); + return make_refcounted(p); else if (d6() < 4) - return make_shared(p); + return make_refcounted(p); else - return make_shared(p); - // return make_shared(q); + return make_refcounted(p); + // return make_refcounted(q); } catch (runtime_error &e) { } p->retry(); @@ -108,7 +108,7 @@ shared_ptr bool_expr::factory(prod *p) { } exists_predicate::exists_predicate(prod *p) : bool_expr(p) { - subquery = make_shared(this, scope); + subquery = make_refcounted(this, scope); } void exists_predicate::accept(prod_visitor *v) { @@ -123,8 +123,8 @@ void exists_predicate::out(std::ostream &out) { } distinct_pred::distinct_pred(prod *p) : bool_binop(p) { - lhs = make_shared(this); - rhs = make_shared(this, lhs->type); + lhs = make_refcounted(this); + rhs = make_refcounted(this, lhs->type); } comparison_op::comparison_op(prod *p) : bool_binop(p) { @@ -330,15 +330,15 @@ void window_function::out(std::ostream &out) { window_function::window_function(prod *p, sqltype *type_constraint) : value_expr(p) { match(); - aggregate = make_shared(this, type_constraint, true); + aggregate = make_refcounted(this, type_constraint, true); type = aggregate->type; - partition_by.push_back(make_shared(this)); + partition_by.push_back(make_refcounted(this)); while (d6() > 4) - partition_by.push_back(make_shared(this)); + partition_by.push_back(make_refcounted(this)); - order_by.push_back(make_shared(this)); + order_by.push_back(make_refcounted(this)); while (d6() > 4) - order_by.push_back(make_shared(this)); + order_by.push_back(make_refcounted(this)); } bool window_function::allowed(prod *p) { diff 
--git a/extension/sqlsmith/third_party/sqlsmith/grammar.cc b/extension/sqlsmith/third_party/sqlsmith/grammar.cc index 040c2d135f0d..3335371e8820 100644 --- a/extension/sqlsmith/third_party/sqlsmith/grammar.cc +++ b/extension/sqlsmith/third_party/sqlsmith/grammar.cc @@ -16,14 +16,14 @@ shared_ptr table_ref::factory(prod *p) { try { if (p->level < 3 + d6()) { if (d6() > 3 && p->level < d6()) - return make_shared(p); + return make_refcounted(p); if (d6() > 3) - return make_shared(p); + return make_refcounted(p); } if (d6() > 3) - return make_shared(p); + return make_refcounted(p); else - return make_shared(p); + return make_refcounted(p); } catch (runtime_error &e) { p->retry(); } @@ -32,7 +32,7 @@ shared_ptr table_ref::factory(prod *p) { table_or_query_name::table_or_query_name(prod *p) : table_ref(p) { t = random_pick(scope->tables); - refs.push_back(make_shared(scope->stmt_uid("ref"), t)); + refs.push_back(make_refcounted(scope->stmt_uid("ref"), t)); } void table_or_query_name::out(std::ostream &out) { @@ -46,7 +46,7 @@ target_table::target_table(prod *p, table *victim) : table_ref(p) { retry(); } victim_ = victim; - refs.push_back(make_shared(scope->stmt_uid("target"), victim)); + refs.push_back(make_refcounted(scope->stmt_uid("target"), victim)); } void target_table::out(std::ostream &out) { @@ -62,7 +62,7 @@ table_sample::table_sample(prod *p) : table_ref(p) { retry(); } while (!t || !t->is_base_table); - refs.push_back(make_shared(scope->stmt_uid("sample"), t)); + refs.push_back(make_refcounted(scope->stmt_uid("sample"), t)); percent = 0.1 * d100(); method = (d6() > 2) ? "system" : "bernoulli"; } @@ -72,10 +72,10 @@ void table_sample::out(std::ostream &out) { } table_subquery::table_subquery(prod *p, bool lateral) : table_ref(p), is_lateral(lateral) { - query = make_shared(this, scope, lateral); + query = make_refcounted(this, scope, lateral); string alias = scope->stmt_uid("subq"); relation *aliased_rel = &query->select_list->derived_table; - refs.push_back(make_shared(alias, aliased_rel)); + refs.push_back(make_refcounted(alias, aliased_rel)); } table_subquery::~table_subquery() { @@ -89,9 +89,9 @@ void table_subquery::accept(prod_visitor *v) { shared_ptr join_cond::factory(prod *p, table_ref &lhs, table_ref &rhs) { try { if (d6() < 6) - return make_shared(p, lhs, rhs); + return make_refcounted(p, lhs, rhs); else - return make_shared(p, lhs, rhs); + return make_refcounted(p, lhs, rhs); } catch (runtime_error &e) { p->retry(); } @@ -196,7 +196,7 @@ from_clause::from_clause(prod *p) : prod(p) { // add a lateral subquery if (!impedance::matched(typeid(lateral_subquery))) break; - reflist.push_back(make_shared(this)); + reflist.push_back(make_refcounted(this)); for (auto r : reflist.back()->refs) scope->refs.push_back(&*r); } @@ -302,8 +302,8 @@ query_spec::query_spec(prod *p, struct scope *s, bool lateral) : prod(p), myscop if (lateral) scope->refs = s->refs; - from_clause = make_shared(this); - select_list = make_shared(this); + from_clause = make_refcounted(this); + select_list = make_refcounted(this); set_quantifier = (d100() == 1) ? 
"distinct" : ""; @@ -341,7 +341,7 @@ delete_stmt::delete_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, delete_returning::delete_returning(prod *p, struct scope *s, table *victim) : delete_stmt(p, s, victim) { match(); - select_list = make_shared(this); + select_list = make_refcounted(this); } insert_stmt::insert_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) { @@ -399,7 +399,7 @@ void set_list::out(std::ostream &out) { update_stmt::update_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) { scope->refs.push_back(victim); search = bool_expr::factory(this); - set_list = make_shared(this, victim); + set_list = make_refcounted(this, victim); } void update_stmt::out(std::ostream &out) { @@ -409,7 +409,7 @@ void update_stmt::out(std::ostream &out) { update_returning::update_returning(prod *p, struct scope *s, table *v) : update_stmt(p, s, v) { match(); - select_list = make_shared(this); + select_list = make_refcounted(this); } upsert_stmt::upsert_stmt(prod *p, struct scope *s, table *v) : insert_stmt(p, s, v) { @@ -427,20 +427,20 @@ shared_ptr statement_factory(struct scope *s) { try { s->new_stmt(); if (d42() == 1) - return make_shared((struct prod *)0, s); + return make_refcounted((struct prod *)0, s); if (d42() == 1) - return make_shared((struct prod *)0, s); + return make_refcounted((struct prod *)0, s); else if (d42() == 1) - return make_shared((struct prod *)0, s); + return make_refcounted((struct prod *)0, s); else if (d42() == 1) { - return make_shared((struct prod *)0, s); + return make_refcounted((struct prod *)0, s); } else if (d42() == 1) - return make_shared((struct prod *)0, s); + return make_refcounted((struct prod *)0, s); else if (d6() > 4) - return make_shared((struct prod *)0, s); + return make_refcounted((struct prod *)0, s); else if (d6() > 5) - return make_shared((struct prod *)0, s); - return make_shared((struct prod *)0, s); + return make_refcounted((struct prod *)0, s); + return make_refcounted((struct prod *)0, s); } catch (runtime_error &e) { return statement_factory(s); } @@ -456,11 +456,11 @@ void common_table_expression::accept(prod_visitor *v) { common_table_expression::common_table_expression(prod *parent, struct scope *s) : prod(parent), myscope(s) { scope = &myscope; do { - shared_ptr query = make_shared(this, s); + shared_ptr query = make_refcounted(this, s); with_queries.push_back(query); string alias = scope->stmt_uid("jennifer"); relation *relation = &query->select_list->derived_table; - auto aliased_rel = make_shared(alias, relation); + auto aliased_rel = make_refcounted(alias, relation); refs.push_back(aliased_rel); scope->tables.push_back(&*aliased_rel); @@ -472,7 +472,7 @@ common_table_expression::common_table_expression(prod *parent, struct scope *s) scope->tables.push_back(pick); } while (d6() > 3); try { - query = make_shared(this, scope); + query = make_refcounted(this, scope); } catch (runtime_error &e) { retry(); goto retry; @@ -495,11 +495,11 @@ void common_table_expression::out(std::ostream &out) { merge_stmt::merge_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) { match(); - target_table_ = make_shared(this, victim); + target_table_ = make_refcounted(this, victim); data_source = table_ref::factory(this); // join_condition = join_cond::factory(this, *target_table_, // *data_source); - join_condition = make_shared(this, *target_table_, *data_source); + join_condition = make_refcounted(this, *target_table_, *data_source); /* Put data_source into scope but not target_table. 
Visibility of the latter varies depending on kind of when clause. */ @@ -604,12 +604,12 @@ shared_ptr when_clause::factory(struct merge_stmt *p) { switch (d6()) { case 1: case 2: - return make_shared(p); + return make_refcounted(p); case 3: case 4: - return make_shared(p); + return make_refcounted(p); default: - return make_shared(p); + return make_refcounted(p); } } catch (runtime_error &e) { p->retry(); diff --git a/extension/sqlsmith/third_party/sqlsmith/sqlsmith.cc b/extension/sqlsmith/third_party/sqlsmith/sqlsmith.cc index b9eced4b3506..7673b6ac4392 100644 --- a/extension/sqlsmith/third_party/sqlsmith/sqlsmith.cc +++ b/extension/sqlsmith/third_party/sqlsmith/sqlsmith.cc @@ -83,7 +83,7 @@ int32_t run_sqlsmith(duckdb::DatabaseInstance &database, SQLSmithOptions opt) { try { shared_ptr schema; - schema = make_shared(database, opt.exclude_catalog, opt.verbose_output); + schema = make_refcounted(database, opt.exclude_catalog, opt.verbose_output); scope scope; long queries_generated = 0; @@ -97,20 +97,20 @@ int32_t run_sqlsmith(duckdb::DatabaseInstance &database, SQLSmithOptions opt) { duckdb::vector> loggers; - loggers.push_back(make_shared()); + loggers.push_back(make_refcounted()); if (opt.verbose_output) { - auto l = make_shared(); + auto l = make_refcounted(); global_cerr_logger = &*l; loggers.push_back(l); signal(SIGINT, cerr_log_handler); } if (opt.dump_all_graphs) - loggers.push_back(make_shared()); + loggers.push_back(make_refcounted()); if (opt.dump_all_queries) - loggers.push_back(make_shared()); + loggers.push_back(make_refcounted()); // if (options.count("dry-run")) { // while (1) { @@ -128,7 +128,7 @@ int32_t run_sqlsmith(duckdb::DatabaseInstance &database, SQLSmithOptions opt) { shared_ptr dut; - dut = make_shared(database); + dut = make_refcounted(database); if (opt.verbose_output) cerr << "Running queries..." 
<< endl; diff --git a/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp b/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp index 80bddc9d158f..45f1b1803b55 100644 --- a/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp +++ b/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp @@ -1,5 +1,6 @@ #include "duckdb/execution/operator/csv_scanner/csv_sniffer.hpp" #include "duckdb/main/client_data.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { diff --git a/src/execution/operator/order/physical_order.cpp b/src/execution/operator/order/physical_order.cpp index 7aa77a7a30c0..8687acfe5ee1 100644 --- a/src/execution/operator/order/physical_order.cpp +++ b/src/execution/operator/order/physical_order.cpp @@ -6,6 +6,7 @@ #include "duckdb/parallel/base_pipeline_event.hpp" #include "duckdb/parallel/executor_task.hpp" #include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index f95901521664..5db96b1b8cb6 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -10,6 +10,8 @@ private: template friend class weak_ptr; std::shared_ptr internal; + template + friend class shared_ptr; public: // Constructors @@ -24,6 +26,9 @@ public: template shared_ptr(T *ptr, Deleter deleter) : internal(ptr, deleter) { } + template + shared_ptr(const shared_ptr &__r, T *__p) noexcept : internal(__r.internal, __p) { + } shared_ptr(const shared_ptr &other) : internal(other.internal) { } @@ -37,17 +42,16 @@ public: explicit shared_ptr(weak_ptr other) : internal(other.internal) { } - template ::value && __compatible_with::value && - std::is_convertible::pointer, T *>::value, - int> = 0> - shared_ptr(unique_ptr other) : internal(other.release()) { +#if _LIBCPP_STD_VER <= 14 || defined(_LIBCPP_ENABLE_CXX17_REMOVED_AUTO_PTR) + template ::value, int> = 0> + shared_ptr(std::auto_ptr &&__r) : internal(__r.release()) { } +#endif template ::value && __compatible_with::value && - std::is_convertible::pointer, T *>::value, - int> = 0> + typename std::enable_if<__compatible_with::value && + std::is_convertible::pointer, T *>::value, + int>::type = 0> shared_ptr(unique_ptr &&other) : internal(other.release()) { } @@ -60,7 +64,10 @@ public: return *this; } - template + template ::value && + std::is_convertible::pointer, T *>::value, + int>::type = 0> shared_ptr &operator=(unique_ptr &&__r) { shared_ptr(std::move(__r)).swap(*this); return *this; @@ -81,6 +88,10 @@ public: internal.reset(ptr, deleter); } + void swap(shared_ptr &r) noexcept { + internal.swap(r.internal); + } + // Observers T *get() const { return internal.get(); @@ -122,6 +133,9 @@ public: bool operator!=(const shared_ptr &other) const noexcept { return internal != other.internal; } + bool operator!=(std::nullptr_t) const noexcept { + return internal != nullptr; + } template bool operator<(const shared_ptr &other) const noexcept { diff --git a/src/include/duckdb/common/unique_ptr.hpp b/src/include/duckdb/common/unique_ptr.hpp index b98f8da00030..0689aeb6624b 100644 --- a/src/include/duckdb/common/unique_ptr.hpp +++ b/src/include/duckdb/common/unique_ptr.hpp @@ -14,6 +14,7 @@ class unique_ptr : public std::unique_ptr { // NOLINT: namin public: using original = std::unique_ptr; using original::original; // NOLINT + using pointer = typename original::pointer; private: static inline void AssertNotNull(const bool null) { 
diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index 5fbe213c92bc..2fd95699de2d 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -11,13 +11,23 @@ public: // Constructors weak_ptr() : internal() { } - // template ::value, int> = 0> + template - weak_ptr(const shared_ptr &ptr) : internal(ptr.internal) { + weak_ptr(shared_ptr const &ptr, typename std::enable_if<__compatible_with::value, int>::type = 0) noexcept + : internal(ptr.internal) { } - weak_ptr(const weak_ptr &other) : internal(other.internal) { + weak_ptr(weak_ptr const &other) noexcept : internal(other.internal) { + } + template + weak_ptr(weak_ptr const &ptr, typename std::enable_if<__compatible_with::value, int>::type = 0) noexcept + : internal(ptr.internal) { + } + weak_ptr(weak_ptr &&ptr) noexcept : internal(ptr.internal) { + } + template + weak_ptr(weak_ptr &&ptr, typename std::enable_if<__compatible_with::value, int>::type = 0) noexcept + : internal(ptr.internal) { } - // Destructor ~weak_ptr() = default; diff --git a/src/include/duckdb/execution/operator/csv_scanner/column_count_scanner.hpp b/src/include/duckdb/execution/operator/csv_scanner/column_count_scanner.hpp index ce2da9606d66..de25f08f235c 100644 --- a/src/include/duckdb/execution/operator/csv_scanner/column_count_scanner.hpp +++ b/src/include/duckdb/execution/operator/csv_scanner/column_count_scanner.hpp @@ -13,6 +13,7 @@ #include "duckdb/execution/operator/csv_scanner/scanner_boundary.hpp" #include "duckdb/execution/operator/csv_scanner/string_value_scanner.hpp" #include "duckdb/execution/operator/csv_scanner/base_scanner.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { diff --git a/test/api/test_object_cache.cpp b/test/api/test_object_cache.cpp index a2dd2b5d8ab2..04f4ac7c91de 100644 --- a/test/api/test_object_cache.cpp +++ b/test/api/test_object_cache.cpp @@ -42,7 +42,7 @@ TEST_CASE("Test ObjectCache", "[api]") { auto &cache = ObjectCache::GetObjectCache(context); REQUIRE(cache.GetObject("test") == nullptr); - cache.Put("test", make_shared(42)); + cache.Put("test", make_refcounted(42)); REQUIRE(cache.GetObject("test") != nullptr); diff --git a/test/api/test_relation_api.cpp b/test/api/test_relation_api.cpp index 6d63bb3bbc43..740adedef620 100644 --- a/test/api/test_relation_api.cpp +++ b/test/api/test_relation_api.cpp @@ -12,7 +12,7 @@ TEST_CASE("Test simple relation API", "[relation_api]") { Connection con(db); con.EnableQueryVerification(); duckdb::unique_ptr result; - shared_ptr tbl, filter, proj, proj2, v1, v2, v3; + duckdb::shared_ptr tbl, filter, proj, proj2, v1, v2, v3; // create some tables REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)")); @@ -215,7 +215,7 @@ TEST_CASE("Test combinations of set operations", "[relation_api]") { Connection con(db); con.EnableQueryVerification(); duckdb::unique_ptr result; - shared_ptr values, v1, v2, v3; + duckdb::shared_ptr values, v1, v2, v3; REQUIRE_NOTHROW(values = con.Values({{1, 10}, {2, 5}, {3, 4}}, {"i", "j"})); @@ -282,7 +282,7 @@ TEST_CASE("Test combinations of joins", "[relation_api]") { Connection con(db); con.EnableQueryVerification(); duckdb::unique_ptr result; - shared_ptr values, vjoin; + duckdb::shared_ptr values, vjoin; REQUIRE_NOTHROW(values = con.Values({{1, 10}, {2, 5}, {3, 4}}, {"i", "j"})); @@ -370,7 +370,7 @@ TEST_CASE("Test crossproduct relation", "[relation_api]") { Connection con(db); con.EnableQueryVerification(); duckdb::unique_ptr result; - shared_ptr values, vcross; + 
duckdb::shared_ptr values, vcross; REQUIRE_NOTHROW(values = con.Values({{1, 10}, {2, 5}, {3, 4}}, {"i", "j"}), "v1"); REQUIRE_NOTHROW(values = con.Values({{1, 10}, {2, 5}, {3, 4}}, {"i", "j"}), "v2"); @@ -401,7 +401,7 @@ TEST_CASE("Test view creation of relations", "[relation_api]") { Connection con(db); con.EnableQueryVerification(); duckdb::unique_ptr result; - shared_ptr tbl, filter, proj, proj2; + duckdb::shared_ptr tbl, filter, proj, proj2; // create some tables REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)")); @@ -478,7 +478,7 @@ TEST_CASE("Test table creations using the relation API", "[relation_api]") { Connection con(db); con.EnableQueryVerification(); duckdb::unique_ptr result; - shared_ptr values; + duckdb::shared_ptr values; // create a table from a Values statement REQUIRE_NOTHROW(values = con.Values({{1, 10}, {2, 5}, {3, 4}}, {"i", "j"})); @@ -878,7 +878,7 @@ TEST_CASE("Test query relation", "[relation_api]") { Connection con(db); con.EnableQueryVerification(); duckdb::unique_ptr result; - shared_ptr tbl; + duckdb::shared_ptr tbl; // create some tables REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)")); @@ -905,7 +905,7 @@ TEST_CASE("Test TopK relation", "[relation_api]") { Connection con(db); con.EnableQueryVerification(); duckdb::unique_ptr result; - shared_ptr tbl; + duckdb::shared_ptr tbl; REQUIRE_NO_FAIL(con.Query("CREATE TABLE test (i integer,j VARCHAR, k varchar )")); REQUIRE_NO_FAIL(con.Query("insert into test values (10,'a','a'), (20,'a','b')")); diff --git a/test/sql/storage/test_buffer_manager.cpp b/test/sql/storage/test_buffer_manager.cpp index e730fe892b92..572a6be9bb41 100644 --- a/test/sql/storage/test_buffer_manager.cpp +++ b/test/sql/storage/test_buffer_manager.cpp @@ -152,7 +152,7 @@ TEST_CASE("Test buffer reallocation", "[storage][.]") { CHECK(buffer_manager.GetUsedMemory() == 0); idx_t requested_size = Storage::BLOCK_SIZE; - shared_ptr block; + duckdb::shared_ptr block; auto handle = buffer_manager.Allocate(MemoryTag::EXTENSION, requested_size, false, &block); CHECK(buffer_manager.GetUsedMemory() == BufferManager::GetAllocSize(requested_size)); for (; requested_size < limit; requested_size *= 2) { @@ -196,7 +196,7 @@ TEST_CASE("Test buffer manager variable size allocations", "[storage][.]") { CHECK(buffer_manager.GetUsedMemory() == 0); idx_t requested_size = 424242; - shared_ptr block; + duckdb::shared_ptr block; auto pin = buffer_manager.Allocate(MemoryTag::EXTENSION, requested_size, false, &block); CHECK(buffer_manager.GetUsedMemory() >= requested_size + Storage::BLOCK_HEADER_SIZE); @@ -224,7 +224,7 @@ TEST_CASE("Test buffer manager buffer re-use", "[storage][.]") { // Create 40 blocks, but don't hold the pin // They will be added to the eviction queue and the buffers will be re-used idx_t block_count = 40; - duckdb::vector> blocks; + duckdb::vector> blocks; blocks.reserve(block_count); for (idx_t i = 0; i < block_count; i++) { blocks.emplace_back(); From da66e1f49aff53d36ce7b627e9811504070307d2 Mon Sep 17 00:00:00 2001 From: Tishj Date: Sat, 6 Apr 2024 21:56:28 +0200 Subject: [PATCH 055/201] compiling now --- extension/parquet/parquet_reader.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/parquet/parquet_reader.cpp b/extension/parquet/parquet_reader.cpp index ecf73d1ec542..51c028e569ab 100644 --- a/extension/parquet/parquet_reader.cpp +++ b/extension/parquet/parquet_reader.cpp @@ -49,7 +49,7 @@ using duckdb_parquet::format::Type; static unique_ptr CreateThriftFileProtocol(Allocator &allocator, 
FileHandle &file_handle, bool prefetch_mode) { - auto transport = make_refcounted(allocator, file_handle, prefetch_mode); + auto transport = std::make_shared(allocator, file_handle, prefetch_mode); return make_uniq>(std::move(transport)); } From fb62ea24dcdf7908f1623aff4d744e7518defbdc Mon Sep 17 00:00:00 2001 From: Tishj Date: Sun, 7 Apr 2024 13:27:36 +0200 Subject: [PATCH 056/201] initialize 'shared_from_this' --- src/include/duckdb/common/shared_ptr.ipp | 35 ++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index 5db96b1b8cb6..433f58adf478 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -4,15 +4,20 @@ namespace duckdb { template class weak_ptr; +template +class enable_shared_from_this; + template class shared_ptr { private: template friend class weak_ptr; - std::shared_ptr internal; template friend class shared_ptr; +private: + std::shared_ptr internal; + public: // Constructors shared_ptr() : internal() { @@ -21,10 +26,12 @@ public: } // Implicit conversion template explicit shared_ptr(U *ptr) : internal(ptr) { + __enable_weak_this(internal.get(), internal.get()); } // Constructor with custom deleter template shared_ptr(T *ptr, Deleter deleter) : internal(ptr, deleter) { + __enable_weak_this(internal.get(), internal.get()); } template shared_ptr(const shared_ptr &__r, T *__p) noexcept : internal(__r.internal, __p) { @@ -33,9 +40,12 @@ public: shared_ptr(const shared_ptr &other) : internal(other.internal) { } - shared_ptr(std::shared_ptr other) : internal(std::move(other)) { + shared_ptr(std::shared_ptr other) : internal(other) { + // FIXME: should we __enable_weak_this here? + // *our* enable_shared_from_this hasn't initialized yet, so I think so? + __enable_weak_this(internal.get(), internal.get()); } - shared_ptr(shared_ptr &&other) : internal(std::move(other.internal)) { + shared_ptr(shared_ptr &&other) : internal(other.internal) { } template @@ -45,6 +55,7 @@ public: #if _LIBCPP_STD_VER <= 14 || defined(_LIBCPP_ENABLE_CXX17_REMOVED_AUTO_PTR) template ::value, int> = 0> shared_ptr(std::auto_ptr &&__r) : internal(__r.release()) { + __enable_weak_this(internal.get(), internal.get()); } #endif @@ -53,6 +64,7 @@ public: std::is_convertible::pointer, T *>::value, int>::type = 0> shared_ptr(unique_ptr &&other) : internal(other.release()) { + __enable_weak_this(internal.get(), internal.get()); } // Destructor @@ -159,6 +171,23 @@ public: template friend shared_ptr shared_ptr_cast(shared_ptr src); + +private: + // This overload is used when the class inherits from 'enable_shared_from_this' + template *>::value, + int>::type = 0> + void __enable_weak_this(const enable_shared_from_this *__e, _OrigPtr *__ptr) noexcept { + typedef typename std::remove_cv::type NonConstU; + if (__e && __e->__weak_this_.expired()) { + // __weak_this__ is the mutable variable returned by 'shared_from_this' + // it is initialized here + __e->__weak_this_ = shared_ptr(*this, const_cast(static_cast(__ptr))); + } + } + + void __enable_weak_this(...) 
noexcept { + } }; } // namespace duckdb From 18e58509356ad86b059139f2e1c07253293110d4 Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 8 Apr 2024 15:14:52 +0200 Subject: [PATCH 057/201] fix some compilation issues --- src/include/duckdb/common/shared_ptr.ipp | 25 ++++++++++++++++------- src/include/duckdb/common/types/value.hpp | 1 + 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index 433f58adf478..3cec8de3b587 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -24,7 +24,7 @@ public: } shared_ptr(std::nullptr_t) : internal(nullptr) { } // Implicit conversion - template + template ::value, int>::type = 0> explicit shared_ptr(U *ptr) : internal(ptr) { __enable_weak_this(internal.get(), internal.get()); } @@ -36,6 +36,16 @@ public: template shared_ptr(const shared_ptr &__r, T *__p) noexcept : internal(__r.internal, __p) { } + template + shared_ptr(shared_ptr &&__r, T *__p) noexcept : internal(__r.internal, __p) { + } + + template ::value, int>::type = 0> + shared_ptr(const shared_ptr &__r) noexcept : internal(__r.internal) { + } + template ::value, int>::type = 0> + shared_ptr(shared_ptr &&__r) noexcept : internal(__r.internal) { + } shared_ptr(const shared_ptr &other) : internal(other.internal) { } @@ -72,7 +82,13 @@ public: // Assignment operators shared_ptr &operator=(const shared_ptr &other) { - internal = other.internal; + shared_ptr(other).swap(*this); + return *this; + } + + template ::value, int>::type = 0> + shared_ptr &operator=(const shared_ptr &other) { + shared_ptr(other).swap(*this); return *this; } @@ -117,11 +133,6 @@ public: return internal.operator bool(); } - template - operator shared_ptr() const noexcept { - return shared_ptr(internal); - } - // Element access std::__add_lvalue_reference_t operator*() const { return *internal; diff --git a/src/include/duckdb/common/types/value.hpp b/src/include/duckdb/common/types/value.hpp index 0328ee6f1c73..2e27fdc4409f 100644 --- a/src/include/duckdb/common/types/value.hpp +++ b/src/include/duckdb/common/types/value.hpp @@ -17,6 +17,7 @@ #include "duckdb/common/types/date.hpp" #include "duckdb/common/types/datetime.hpp" #include "duckdb/common/types/interval.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { From deaf43a38cc6637932c9df86f3ea9a08f42ca5f8 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 9 Apr 2024 11:04:33 +0200 Subject: [PATCH 058/201] move constructor should actually move --- src/include/duckdb/common/helper.hpp | 2 +- src/include/duckdb/common/shared_ptr.ipp | 74 ++++++++++++++---------- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/include/duckdb/common/helper.hpp b/src/include/duckdb/common/helper.hpp index 81da4bd79513..2f12dca961f0 100644 --- a/src/include/duckdb/common/helper.hpp +++ b/src/include/duckdb/common/helper.hpp @@ -70,7 +70,7 @@ inline shared_ptr make_refcounted(ARGS&&... 
args) // NOLINT: mimic std style { - return shared_ptr(new DATA_TYPE(std::forward(args)...)); + return shared_ptr(std::make_shared(std::forward(args)...)); } template diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index 3cec8de3b587..f6976b91f3d9 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -12,9 +12,13 @@ class shared_ptr { private: template friend class weak_ptr; + template friend class shared_ptr; + template + friend shared_ptr shared_ptr_cast(shared_ptr src); + private: std::shared_ptr internal; @@ -23,45 +27,55 @@ public: shared_ptr() : internal() { } shared_ptr(std::nullptr_t) : internal(nullptr) { - } // Implicit conversion + } + + // From raw pointer of type U convertible to T template ::value, int>::type = 0> explicit shared_ptr(U *ptr) : internal(ptr) { __enable_weak_this(internal.get(), internal.get()); } - // Constructor with custom deleter + // From raw pointer of type T with custom Deleter template shared_ptr(T *ptr, Deleter deleter) : internal(ptr, deleter) { __enable_weak_this(internal.get(), internal.get()); } + // Aliasing constructor: shares ownership information with __r but contains __p instead + // When the created shared_ptr goes out of scope, it will call the Deleter of __r, will not delete __p template shared_ptr(const shared_ptr &__r, T *__p) noexcept : internal(__r.internal, __p) { } +#if _LIBCPP_STD_VER >= 20 template - shared_ptr(shared_ptr &&__r, T *__p) noexcept : internal(__r.internal, __p) { + shared_ptr(shared_ptr &&__r, T *__p) noexcept : internal(std::move(__r.internal), __p) { } +#endif + // Copy constructor, share ownership with __r template ::value, int>::type = 0> shared_ptr(const shared_ptr &__r) noexcept : internal(__r.internal) { } + shared_ptr(const shared_ptr &other) : internal(other.internal) { + } + // Move constructor, share ownership with __r template ::value, int>::type = 0> - shared_ptr(shared_ptr &&__r) noexcept : internal(__r.internal) { + shared_ptr(shared_ptr &&__r) noexcept : internal(std::move(__r.internal)) { } - - shared_ptr(const shared_ptr &other) : internal(other.internal) { + shared_ptr(shared_ptr &&other) : internal(std::move(other.internal)) { } + // Construct from std::shared_ptr shared_ptr(std::shared_ptr other) : internal(other) { // FIXME: should we __enable_weak_this here? // *our* enable_shared_from_this hasn't initialized yet, so I think so? 
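	// (Assumption, mirroring the libc++ technique this mimics: the call below is
	// overload-dispatched. If T derives from duckdb::enable_shared_from_this, the
	// typed overload seeds its mutable __weak_this_ from this new owner; otherwise
	// the variadic __enable_weak_this(...) catch-all is chosen and nothing happens.)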
__enable_weak_this(internal.get(), internal.get()); } - shared_ptr(shared_ptr &&other) : internal(other.internal) { - } + // Construct from weak_ptr template explicit shared_ptr(weak_ptr other) : internal(other.internal) { } + // Construct from auto_ptr #if _LIBCPP_STD_VER <= 14 || defined(_LIBCPP_ENABLE_CXX17_REMOVED_AUTO_PTR) template ::value, int> = 0> shared_ptr(std::auto_ptr &&__r) : internal(__r.release()) { @@ -69,29 +83,43 @@ public: } #endif + // Construct from unique_ptr, takes over ownership of the unique_ptr template ::value && std::is_convertible::pointer, T *>::value, int>::type = 0> - shared_ptr(unique_ptr &&other) : internal(other.release()) { + shared_ptr(unique_ptr &&other) : internal(std::move(other)) { __enable_weak_this(internal.get(), internal.get()); } // Destructor ~shared_ptr() = default; - // Assignment operators - shared_ptr &operator=(const shared_ptr &other) { + // Assign from shared_ptr copy + shared_ptr &operator=(const shared_ptr &other) noexcept { + // Create a new shared_ptr using the copy constructor, then swap out the ownership to *this shared_ptr(other).swap(*this); return *this; } - template ::value, int>::type = 0> - shared_ptr &operator=(const shared_ptr &other) { + shared_ptr &operator=(const shared_ptr &other) { shared_ptr(other).swap(*this); return *this; } + // Assign from moved shared_ptr + shared_ptr &operator=(shared_ptr &&other) noexcept { + // Create a new shared_ptr using the move constructor, then swap out the ownership to *this + shared_ptr(std::move(other)).swap(*this); + return *this; + } + template ::value, int>::type = 0> + shared_ptr &operator=(shared_ptr &&other) { + shared_ptr(std::move(other)).swap(*this); + return *this; + } + + // Assign from moved unique_ptr template ::value && std::is_convertible::pointer, T *>::value, @@ -101,16 +129,13 @@ public: return *this; } - // Modifiers void reset() { internal.reset(); } - template void reset(U *ptr) { internal.reset(ptr); } - template void reset(U *ptr, Deleter deleter) { internal.reset(ptr, deleter); @@ -120,7 +145,6 @@ public: internal.swap(r.internal); } - // Observers T *get() const { return internal.get(); } @@ -133,7 +157,6 @@ public: return internal.operator bool(); } - // Element access std::__add_lvalue_reference_t operator*() const { return *internal; } @@ -147,15 +170,14 @@ public: bool operator==(const shared_ptr &other) const noexcept { return internal == other.internal; } - - bool operator==(std::nullptr_t) const noexcept { - return internal == nullptr; - } - template bool operator!=(const shared_ptr &other) const noexcept { return internal != other.internal; } + + bool operator==(std::nullptr_t) const noexcept { + return internal == nullptr; + } bool operator!=(std::nullptr_t) const noexcept { return internal != nullptr; } @@ -164,25 +186,19 @@ public: bool operator<(const shared_ptr &other) const noexcept { return internal < other.internal; } - template bool operator<=(const shared_ptr &other) const noexcept { return internal <= other.internal; } - template bool operator>(const shared_ptr &other) const noexcept { return internal > other.internal; } - template bool operator>=(const shared_ptr &other) const noexcept { return internal >= other.internal; } - template - friend shared_ptr shared_ptr_cast(shared_ptr src); - private: // This overload is used when the class inherits from 'enable_shared_from_this' template Date: Tue, 9 Apr 2024 11:10:11 +0200 Subject: [PATCH 059/201] make constructor from std::shared_ptr explicit --- extension/parquet/include/parquet_reader.hpp 
| 2 +- extension/parquet/include/parquet_writer.hpp | 2 +- .../duckdb/common/serializer/serialization_traits.hpp | 11 +++++++++++ src/include/duckdb/common/shared_ptr.ipp | 2 +- src/include/duckdb/common/weak_ptr.ipp | 2 +- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/extension/parquet/include/parquet_reader.hpp b/extension/parquet/include/parquet_reader.hpp index 3393749373c4..6e65d46d8375 100644 --- a/extension/parquet/include/parquet_reader.hpp +++ b/extension/parquet/include/parquet_reader.hpp @@ -53,7 +53,7 @@ struct ParquetReaderScanState { idx_t group_offset; unique_ptr file_handle; unique_ptr root_reader; - unique_ptr thrift_file_proto; + std::unique_ptr thrift_file_proto; bool finished; SelectionVector sel; diff --git a/extension/parquet/include/parquet_writer.hpp b/extension/parquet/include/parquet_writer.hpp index 6b71b8196b26..44625d5db69d 100644 --- a/extension/parquet/include/parquet_writer.hpp +++ b/extension/parquet/include/parquet_writer.hpp @@ -108,7 +108,7 @@ class ParquetWriter { shared_ptr encryption_config; unique_ptr writer; - shared_ptr protocol; + std::shared_ptr protocol; duckdb_parquet::format::FileMetaData file_meta_data; std::mutex lock; diff --git a/src/include/duckdb/common/serializer/serialization_traits.hpp b/src/include/duckdb/common/serializer/serialization_traits.hpp index 616d90f2320f..5230a75e4f29 100644 --- a/src/include/duckdb/common/serializer/serialization_traits.hpp +++ b/src/include/duckdb/common/serializer/serialization_traits.hpp @@ -50,6 +50,13 @@ struct has_deserialize< T, typename std::enable_if(Deserializer &)>::value, T>::type> : std::true_type {}; +// Accept `static shared_ptr Deserialize(Deserializer& deserializer)` +template +struct has_deserialize< + T, + typename std::enable_if(Deserializer &)>::value, T>::type> + : std::true_type {}; + // Accept `static T Deserialize(Deserializer& deserializer)` template struct has_deserialize< @@ -105,6 +112,10 @@ template struct is_shared_ptr> : std::true_type { typedef T ELEMENT_TYPE; }; +template +struct is_shared_ptr> : std::true_type { + typedef T ELEMENT_TYPE; +}; template struct is_optional_ptr : std::false_type {}; diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index f6976b91f3d9..80ffc96507c7 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -64,7 +64,7 @@ public: } // Construct from std::shared_ptr - shared_ptr(std::shared_ptr other) : internal(other) { + explicit shared_ptr(std::shared_ptr other) : internal(other) { // FIXME: should we __enable_weak_this here? // *our* enable_shared_from_this hasn't initialized yet, so I think so? 
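To see what making this conversion explicit buys, here is a minimal sketch under assumed names: wrapper_ptr below is hypothetical and stands in for the wrapper type above; only the explicit keyword on the converting constructor mirrors the change in this patch. Once the constructor is explicit, silent conversions from std::shared_ptr at call sites stop compiling and must be spelled out, which is also why lock() in the weak_ptr change below now wraps its result in an explicit shared_ptr(...) construction.

#include <memory>

// Hypothetical minimal wrapper; only the 'explicit' converting constructor
// mirrors the change in this patch.
template <class T>
class wrapper_ptr {
public:
	explicit wrapper_ptr(std::shared_ptr<T> other) : internal(std::move(other)) {
	}
	T *get() const {
		return internal.get();
	}

private:
	std::shared_ptr<T> internal;
};

int main() {
	auto std_ptr = std::make_shared<int>(42);
	// wrapper_ptr<int> implicit = std_ptr; // no longer compiles once explicit
	wrapper_ptr<int> converted(std_ptr); // the conversion must now be spelled out
	return *converted.get() == 42 ? 0 : 1;
}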
__enable_weak_this(internal.get(), internal.get()); diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index 2fd95699de2d..9ab76a0553eb 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -58,7 +58,7 @@ public: } shared_ptr lock() const { - return internal.lock(); + return shared_ptr(internal.lock()); } // Relational operators From 2cbe49e96c70f370830ba73f0ae0de3a9e790765 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 9 Apr 2024 11:26:50 +0200 Subject: [PATCH 060/201] fix move in weak_ptr constructor, add SAFE template parameter --- .../duckdb/common/enable_shared_from_this.ipp | 24 ++++++++++--------- src/include/duckdb/common/shared_ptr.ipp | 23 +++++++++++------- src/include/duckdb/common/weak_ptr.ipp | 16 +++++++++---- 3 files changed, 38 insertions(+), 25 deletions(-) diff --git a/src/include/duckdb/common/enable_shared_from_this.ipp b/src/include/duckdb/common/enable_shared_from_this.ipp index 6472db9c2b12..d68d20033aa7 100644 --- a/src/include/duckdb/common/enable_shared_from_this.ipp +++ b/src/include/duckdb/common/enable_shared_from_this.ipp @@ -1,8 +1,13 @@ namespace duckdb { -template +template class enable_shared_from_this { - mutable weak_ptr<_Tp> __weak_this_; +public: + template + friend class shared_ptr; + +private: + mutable weak_ptr __weak_this_; protected: constexpr enable_shared_from_this() noexcept { @@ -16,25 +21,22 @@ protected: } public: - shared_ptr<_Tp> shared_from_this() { - return shared_ptr<_Tp>(__weak_this_); + shared_ptr shared_from_this() { + return shared_ptr(__weak_this_); } - shared_ptr<_Tp const> shared_from_this() const { - return shared_ptr(__weak_this_); + shared_ptr shared_from_this() const { + return shared_ptr(__weak_this_); } #if _LIBCPP_STD_VER >= 17 - weak_ptr<_Tp> weak_from_this() noexcept { + weak_ptr weak_from_this() noexcept { return __weak_this_; } - weak_ptr weak_from_this() const noexcept { + weak_ptr weak_from_this() const noexcept { return __weak_this_; } #endif // _LIBCPP_STD_VER >= 17 - - template - friend class shared_ptr; }; } // namespace duckdb diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index 80ffc96507c7..78d5365a03cc 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -1,26 +1,31 @@ namespace duckdb { -template +template class weak_ptr; template class enable_shared_from_this; -template +template class shared_ptr { +public: + using original = std::shared_ptr; + using element_type = typename original::element_type; + using weak_type = weak_ptr; + private: - template + template friend class weak_ptr; - template + template friend class shared_ptr; template friend shared_ptr shared_ptr_cast(shared_ptr src); private: - std::shared_ptr internal; + original internal; public: // Constructors @@ -84,11 +89,11 @@ public: #endif // Construct from unique_ptr, takes over ownership of the unique_ptr - template ::value && std::is_convertible::pointer, T *>::value, int>::type = 0> - shared_ptr(unique_ptr &&other) : internal(std::move(other)) { + shared_ptr(unique_ptr &&other) : internal(std::move(other)) { __enable_weak_this(internal.get(), internal.get()); } @@ -120,11 +125,11 @@ public: } // Assign from moved unique_ptr - template ::value && std::is_convertible::pointer, T *>::value, int>::type = 0> - shared_ptr &operator=(unique_ptr &&__r) { + shared_ptr &operator=(unique_ptr &&__r) { shared_ptr(std::move(__r)).swap(*this); return *this; } diff 
--git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index 9ab76a0553eb..4b04adaa49dc 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -1,11 +1,17 @@ namespace duckdb { -template +template class weak_ptr { +public: + using original = std::weak_ptr; + using element_type = typename original::element_type; + private: - template + template friend class shared_ptr; - std::weak_ptr internal; + +private: + original internal; public: // Constructors @@ -22,11 +28,11 @@ public: weak_ptr(weak_ptr const &ptr, typename std::enable_if<__compatible_with::value, int>::type = 0) noexcept : internal(ptr.internal) { } - weak_ptr(weak_ptr &&ptr) noexcept : internal(ptr.internal) { + weak_ptr(weak_ptr &&ptr) noexcept : internal(std::move(ptr.internal)) { } template weak_ptr(weak_ptr &&ptr, typename std::enable_if<__compatible_with::value, int>::type = 0) noexcept - : internal(ptr.internal) { + : internal(std::move(ptr.internal)) { } // Destructor ~weak_ptr() = default; From dfdf55cd829b508f60f5f429bcf4113d081d597d Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 9 Apr 2024 11:49:19 +0200 Subject: [PATCH 061/201] add memory safety for shared_ptr and weak_ptr operations --- src/include/duckdb/common/shared_ptr.hpp | 15 ++++++++- src/include/duckdb/common/shared_ptr.ipp | 43 +++++++++++++++++++++--- src/include/duckdb/common/weak_ptr.ipp | 9 ++--- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.hpp b/src/include/duckdb/common/shared_ptr.hpp index fe9d31ee40c9..82cb63157313 100644 --- a/src/include/duckdb/common/shared_ptr.hpp +++ b/src/include/duckdb/common/shared_ptr.hpp @@ -8,9 +8,12 @@ #pragma once +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/likely.hpp" +#include "duckdb/common/memory_safety.hpp" + #include #include -#include "duckdb/common/unique_ptr.hpp" #if _LIBCPP_STD_VER >= 17 template @@ -29,3 +32,13 @@ struct __compatible_with : std::is_convertible<_Yp *, _Tp *> {}; #include "duckdb/common/shared_ptr.ipp" #include "duckdb/common/weak_ptr.ipp" #include "duckdb/common/enable_shared_from_this.ipp" + +namespace duckdb { + +template +using unsafe_shared_ptr = shared_ptr; + +template +using unsafe_weak_ptr = weak_ptr; + +} // namespace duckdb diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index 78d5365a03cc..44496781a2ec 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -1,4 +1,3 @@ - namespace duckdb { template @@ -14,6 +13,17 @@ public: using element_type = typename original::element_type; using weak_type = weak_ptr; +private: + static inline void AssertNotNull(const bool null) { +#if defined(DUCKDB_DEBUG_NO_SAFETY) || defined(DUCKDB_CLANG_TIDY) + return; +#else + if (DUCKDB_UNLIKELY(null)) { + throw duckdb::InternalException("Attempted to dereference shared_ptr that is NULL!"); + } +#endif + } + private: template friend class weak_ptr; @@ -134,13 +144,26 @@ public: return *this; } - void reset() { +#ifdef DUCKDB_CLANG_TIDY + // This is necessary to tell clang-tidy that it reinitializes the variable after a move + [[clang::reinitializes]] +#endif + void + reset() { internal.reset(); } +#ifdef DUCKDB_CLANG_TIDY + // This is necessary to tell clang-tidy that it reinitializes the variable after a move + [[clang::reinitializes]] +#endif template void reset(U *ptr) { internal.reset(ptr); } +#ifdef DUCKDB_CLANG_TIDY + // This is necessary to tell 
clang-tidy that it reinitializes the variable after a move + [[clang::reinitializes]] +#endif template void reset(U *ptr, Deleter deleter) { internal.reset(ptr, deleter); @@ -163,11 +186,23 @@ public: } std::__add_lvalue_reference_t operator*() const { - return *internal; + if (MemorySafety::ENABLED) { + const auto ptr = internal.get(); + AssertNotNull(!ptr); + return *ptr; + } else { + return *internal; + } } T *operator->() const { - return internal.operator->(); + if (MemorySafety::ENABLED) { + const auto ptr = internal.get(); + AssertNotNull(!ptr); + return ptr; + } else { + return internal.operator->(); + } } // Relational operators diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index 4b04adaa49dc..f602aacb4830 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -19,7 +19,8 @@ public: } template - weak_ptr(shared_ptr const &ptr, typename std::enable_if<__compatible_with::value, int>::type = 0) noexcept + weak_ptr(shared_ptr const &ptr, + typename std::enable_if<__compatible_with::value, int>::type = 0) noexcept : internal(ptr.internal) { } weak_ptr(weak_ptr const &other) noexcept : internal(other.internal) { @@ -44,7 +45,7 @@ public: } template ::value, int> = 0> - weak_ptr &operator=(const shared_ptr &ptr) { + weak_ptr &operator=(const shared_ptr &ptr) { internal = ptr; return *this; } @@ -63,8 +64,8 @@ public: return internal.expired(); } - shared_ptr lock() const { - return shared_ptr(internal.lock()); + shared_ptr lock() const { + return shared_ptr(internal.lock()); } // Relational operators From 624f8adad66a85001100add98ee2b63ea48550d3 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 9 Apr 2024 12:06:57 +0200 Subject: [PATCH 062/201] fix compilation error --- src/include/duckdb/common/shared_ptr.ipp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index 44496781a2ec..e5c96f6af02e 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -92,7 +92,7 @@ public: // Construct from auto_ptr #if _LIBCPP_STD_VER <= 14 || defined(_LIBCPP_ENABLE_CXX17_REMOVED_AUTO_PTR) - template ::value, int> = 0> + template ::value, int>::type = 0> shared_ptr(std::auto_ptr &&__r) : internal(__r.release()) { __enable_weak_this(internal.get(), internal.get()); } From df40291af5fc7eb9b4a4fbfd27863920c9461d2e Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 9 Apr 2024 12:23:48 +0200 Subject: [PATCH 063/201] use INVALID_INDEX-1 to indicate unlimited swap space --- src/main/settings/settings.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/main/settings/settings.cpp b/src/main/settings/settings.cpp index 5d25adc855c1..6d07bc839562 100644 --- a/src/main/settings/settings.cpp +++ b/src/main/settings/settings.cpp @@ -975,20 +975,18 @@ Value MaximumMemorySetting::GetSetting(ClientContext &context) { // Maximum Temp Directory Size //===--------------------------------------------------------------------===// void MaximumTempDirectorySize::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - idx_t maximum_swap_space = DConstants::INVALID_INDEX; - if (input.ToString() != "-1") { - maximum_swap_space = DBConfig::ParseMemoryLimit(input.ToString()); + auto maximum_swap_space = DBConfig::ParseMemoryLimit(input.ToString()); + if (maximum_swap_space == DConstants::INVALID_INDEX) { + // We use INVALID_INDEX to indicate that the 
value is not set by the user + // use one lower to indicate 'unlimited' + maximum_swap_space--; } if (!db) { config.options.maximum_swap_space = maximum_swap_space; return; } auto &buffer_manager = BufferManager::GetBufferManager(*db); - if (maximum_swap_space == DConstants::INVALID_INDEX) { - buffer_manager.SetSwapLimit(); - } else { - buffer_manager.SetSwapLimit(maximum_swap_space); - } + buffer_manager.SetSwapLimit(maximum_swap_space); config.options.maximum_swap_space = maximum_swap_space; } From 9ce6c39d4e2c052f52993542125f019ddf1cc024 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 9 Apr 2024 12:44:50 +0200 Subject: [PATCH 064/201] hex and array slice, my old nemesis --- src/core_functions/scalar/bit/bitstring.cpp | 8 +- .../scalar/list/array_slice.cpp | 34 ++-- src/core_functions/scalar/string/hex.cpp | 26 +-- src/include/duckdb/common/bit_utils.hpp | 161 +++++++----------- 4 files changed, 99 insertions(+), 130 deletions(-) diff --git a/src/core_functions/scalar/bit/bitstring.cpp b/src/core_functions/scalar/bit/bitstring.cpp index babfadfe01e7..fc1768850f07 100644 --- a/src/core_functions/scalar/bit/bitstring.cpp +++ b/src/core_functions/scalar/bit/bitstring.cpp @@ -19,9 +19,9 @@ static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &r idx_t len; Bit::TryGetBitStringSize(input, len, nullptr); // string verification - len = Bit::ComputeBitstringLen(n); + len = Bit::ComputeBitstringLen(UnsafeNumericCast(n)); string_t target = StringVector::EmptyString(result, len); - Bit::BitString(input, n, target); + Bit::BitString(input, UnsafeNumericCast(n), target); target.Finalize(); return target; }); @@ -41,7 +41,7 @@ struct GetBitOperator { throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), NumericHelper::ToString(Bit::BitLength(input) - 1)); } - return UnsafeNumericCast(Bit::GetBit(input, n)); + return UnsafeNumericCast(Bit::GetBit(input, UnsafeNumericCast(n))); } }; @@ -66,7 +66,7 @@ static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &res } string_t target = StringVector::EmptyString(result, input.GetSize()); memcpy(target.GetDataWriteable(), input.GetData(), input.GetSize()); - Bit::SetBit(target, n, new_value); + Bit::SetBit(target, UnsafeNumericCast(n), UnsafeNumericCast(new_value)); return target; }); } diff --git a/src/core_functions/scalar/list/array_slice.cpp b/src/core_functions/scalar/list/array_slice.cpp index 7651f410d023..f708caf431ff 100644 --- a/src/core_functions/scalar/list/array_slice.cpp +++ b/src/core_functions/scalar/list/array_slice.cpp @@ -47,15 +47,17 @@ static idx_t CalculateSliceLength(idx_t begin, idx_t end, INDEX_TYPE step, bool if (step == 0 && svalid) { throw InvalidInputException("Slice step cannot be zero"); } + auto step_unsigned = UnsafeNumericCast(step); // we called abs() on this above. 
+ if (step == 1) { - return NumericCast(end - begin); - } else if (static_cast(step) >= (end - begin)) { + return NumericCast(end - begin); + } else if (step_unsigned >= (end - begin)) { return 1; } - if ((end - begin) % step != 0) { - return (end - begin) / step + 1; + if ((end - begin) % step_unsigned != 0) { + return (end - begin) / step_unsigned + 1; } - return (end - begin) / step; + return (end - begin) / step_unsigned; } template @@ -65,7 +67,7 @@ INDEX_TYPE ValueLength(const INPUT_TYPE &value) { template <> int64_t ValueLength(const list_entry_t &value) { - return value.length; + return UnsafeNumericCast(value.length); } template <> @@ -119,8 +121,8 @@ INPUT_TYPE SliceValue(Vector &result, INPUT_TYPE input, INDEX_TYPE begin, INDEX_ template <> list_entry_t SliceValue(Vector &result, list_entry_t input, int64_t begin, int64_t end) { - input.offset += begin; - input.length = end - begin; + input.offset = UnsafeNumericCast(UnsafeNumericCast(input.offset) + begin); + input.length = UnsafeNumericCast(end - begin); return input; } @@ -144,15 +146,15 @@ list_entry_t SliceValueWithSteps(Vector &result, SelectionVector &sel, list_entr input.offset = sel_idx; return input; } - input.length = CalculateSliceLength(begin, end, step, true); - idx_t child_idx = input.offset + begin; + input.length = CalculateSliceLength(UnsafeNumericCast(begin), UnsafeNumericCast(end), step, true); + int64_t child_idx = UnsafeNumericCast(input.offset) + begin; if (step < 0) { - child_idx = input.offset + end - 1; + child_idx = UnsafeNumericCast(input.offset) + end - 1; } input.offset = sel_idx; for (idx_t i = 0; i < input.length; i++) { - sel.set_index(sel_idx, child_idx); - child_idx += step; + sel.set_index(sel_idx, UnsafeNumericCast(child_idx)); + child_idx = UnsafeNumericCast(child_idx) + step; sel_idx++; } return input; @@ -194,7 +196,8 @@ static void ExecuteConstantSlice(Vector &result, Vector &str_vector, Vector &beg idx_t sel_length = 0; bool sel_valid = false; if (step_vector && step_valid && str_valid && begin_valid && end_valid && step != 1 && end - begin > 0) { - sel_length = CalculateSliceLength(begin, end, step, step_valid); + sel_length = + CalculateSliceLength(UnsafeNumericCast(begin), UnsafeNumericCast(end), step, step_valid); sel.Initialize(sel_length); sel_valid = true; } @@ -268,7 +271,8 @@ static void ExecuteFlatSlice(Vector &result, Vector &list_vector, Vector &begin_ idx_t length = 0; if (end - begin > 0) { - length = CalculateSliceLength(begin, end, step, step_valid); + length = + CalculateSliceLength(UnsafeNumericCast(begin), UnsafeNumericCast(end), step, step_valid); } sel_length += length; diff --git a/src/core_functions/scalar/string/hex.cpp b/src/core_functions/scalar/string/hex.cpp index dffbae70d030..f399b65cfbc2 100644 --- a/src/core_functions/scalar/string/hex.cpp +++ b/src/core_functions/scalar/string/hex.cpp @@ -90,7 +90,7 @@ struct HexIntegralOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + auto num_leading_zero = CountZeros::Leading(static_cast(input)); idx_t num_bits_to_check = 64 - num_leading_zero; D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); @@ -109,7 +109,7 @@ struct HexIntegralOperator { auto target = StringVector::EmptyString(result, buffer_size); auto output = target.GetDataWriteable(); - WriteHexBytes(input, output, buffer_size); + WriteHexBytes(static_cast(input), output, buffer_size); target.Finalize(); return target; @@ -120,7 +120,7 @@ struct HexHugeIntOperator { 
template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + idx_t num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); // Special case: All bits are zero @@ -147,7 +147,7 @@ struct HexUhugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + idx_t num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); // Special case: All bits are zero @@ -189,7 +189,7 @@ struct BinaryStrOperator { auto output = target.GetDataWriteable(); for (idx_t i = 0; i < size; ++i) { - uint8_t byte = data[i]; + auto byte = static_cast(data[i]); for (idx_t i = 8; i >= 1; --i) { *output = ((byte >> (i - 1)) & 0x01) + '0'; output++; @@ -205,7 +205,7 @@ struct BinaryIntegralOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + auto num_leading_zero = CountZeros::Leading(static_cast(input)); idx_t num_bits_to_check = 64 - num_leading_zero; D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); @@ -224,7 +224,7 @@ struct BinaryIntegralOperator { auto target = StringVector::EmptyString(result, buffer_size); auto output = target.GetDataWriteable(); - WriteBinBytes(input, output, buffer_size); + WriteBinBytes(static_cast(input), output, buffer_size); target.Finalize(); return target; @@ -234,7 +234,7 @@ struct BinaryIntegralOperator { struct BinaryHugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + auto num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; // Special case: All bits are zero @@ -259,7 +259,7 @@ struct BinaryHugeIntOperator { struct BinaryUhugeIntOperator { template static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - idx_t num_leading_zero = CountZeros::Leading(input); + auto num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; // Special case: All bits are zero @@ -301,7 +301,7 @@ struct FromHexOperator { // Treated as a single byte idx_t i = 0; if (size % 2 != 0) { - *output = StringUtil::GetHexValue(data[i]); + *output = static_cast(StringUtil::GetHexValue(data[i])); i++; output++; } @@ -309,7 +309,7 @@ struct FromHexOperator { for (; i < size; i += 2) { uint8_t major = StringUtil::GetHexValue(data[i]); uint8_t minor = StringUtil::GetHexValue(data[i + 1]); - *output = UnsafeNumericCast((major << 4) | minor); + *output = static_cast((major << 4) | minor); output++; } @@ -343,7 +343,7 @@ struct FromBinaryOperator { byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); i++; } - *output = byte; + *output = static_cast(byte); // binary eh output++; } @@ -353,7 +353,7 @@ struct FromBinaryOperator { byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); i++; } - *output = byte; + *output = static_cast(byte); output++; } diff --git a/src/include/duckdb/common/bit_utils.hpp b/src/include/duckdb/common/bit_utils.hpp index 3c4f4a6bfcb8..0b9c325d9c23 100644 --- a/src/include/duckdb/common/bit_utils.hpp +++ b/src/include/duckdb/common/bit_utils.hpp @@ -10,68 +10,7 @@ #include "duckdb/common/hugeint.hpp" #include "duckdb/common/uhugeint.hpp" - -#if 
defined(_MSC_VER) && !defined(__clang__)
-#define __restrict__
-#define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__
-#define __ORDER_LITTLE_ENDIAN__ 2
-#include <intrin.h>
-static inline int __builtin_ctzll(unsigned long long x) {
-#ifdef _WIN64
-	unsigned long ret;
-	_BitScanForward64(&ret, x);
-	return (int)ret;
-#else
-	unsigned long low, high;
-	bool low_set = _BitScanForward(&low, (unsigned __int32)(x)) != 0;
-	_BitScanForward(&high, (unsigned __int32)(x >> 32));
-	high += 32;
-	return low_set ? low : high;
-#endif
-}
-static inline int __builtin_clzll(unsigned long long mask) {
-	unsigned long where;
-// BitScanReverse scans from MSB to LSB for first set bit.
-// Returns 0 if no set bit is found.
-#if defined(_WIN64)
-	if (_BitScanReverse64(&where, mask))
-		return static_cast<int>(63 - where);
-#elif defined(_WIN32)
-	// Scan the high 32 bits.
-	if (_BitScanReverse(&where, static_cast<unsigned long>(mask >> 32)))
-		return static_cast<int>(63 - (where + 32)); // Create a bit offset from the MSB.
-	// Scan the low 32 bits.
-	if (_BitScanReverse(&where, static_cast<unsigned long>(mask)))
-		return static_cast<int>(63 - where);
-#else
-#error "Implementation of __builtin_clzll required"
-#endif
-	return 64; // Undefined Behavior.
-}
-
-static inline int __builtin_ctz(unsigned int value) {
-	unsigned long trailing_zero = 0;
-
-	if (_BitScanForward(&trailing_zero, value)) {
-		return trailing_zero;
-	} else {
-		// This is undefined, I better choose 32 than 0
-		return 32;
-	}
-}
-
-static inline int __builtin_clz(unsigned int value) {
-	unsigned long leading_zero = 0;
-
-	if (_BitScanReverse(&leading_zero, value)) {
-		return 31 - leading_zero;
-	} else {
-		// Same remarks as above
-		return 32;
-	}
-}
-
-#endif
+#include "duckdb/common/numeric_utils.hpp"
 
 namespace duckdb {
 
@@ -79,60 +18,86 @@
 template <class T>
 struct CountZeros {};
 
 template <>
-struct CountZeros<uint32_t> {
-	inline static int Leading(uint32_t value) {
-		if (!value) {
-			return 32;
+struct CountZeros<uint64_t> {
+	// see here: https://en.wikipedia.org/wiki/De_Bruijn_sequence
+	inline static idx_t Leading(const uint64_t value_in) {
+		if (!value_in) {
+			return 64;
 		}
-		return __builtin_clz(value);
+
+		uint64_t value = value_in;
+
+		constexpr uint64_t index64msb[] = {0,  47, 1,  56, 48, 27, 2,  60, 57, 49, 41, 37, 28, 16, 3,  61,
+		                                   54, 58, 35, 52, 50, 42, 21, 44, 38, 32, 29, 23, 17, 11, 4,  62,
+		                                   46, 55, 26, 59, 40, 36, 15, 53, 34, 51, 20, 43, 31, 22, 10, 45,
+		                                   25, 39, 14, 33, 19, 30, 9,  24, 13, 18, 8,  12, 7,  6,  5,  63};
+
+		constexpr uint64_t debruijn64msb = 0X03F79D71B4CB0A89;
+
+		value |= value >> 1;
+		value |= value >> 2;
+		value |= value >> 4;
+		value |= value >> 8;
+		value |= value >> 16;
+		value |= value >> 32;
+		auto result = 63 - index64msb[(value * debruijn64msb) >> 58];
+#ifdef __clang__
+		D_ASSERT(result == static_cast<idx_t>(__builtin_clzl(value_in)));
+#endif
+		return result;
 	}
-	inline static int Trailing(uint32_t value) {
-		if (!value) {
-			return 32;
+	inline static idx_t Trailing(uint64_t value_in) {
+		if (!value_in) {
+			return 64;
 		}
-		return __builtin_ctz(value);
+		uint64_t value = value_in;
+
+		constexpr uint64_t index64lsb[] = {63, 0,  58, 1,  59, 47, 53, 2,  60, 39, 48, 27, 54, 33, 42, 3,
+		                                   61, 51, 37, 40, 49, 18, 28, 20, 55, 30, 34, 11, 43, 14, 22, 4,
+		                                   62, 57, 46, 52, 38, 26, 32, 41, 50, 36, 17, 19, 29, 10, 13, 21,
+		                                   56, 45, 25, 31, 35, 16, 9,  12, 44, 24, 15, 8,  23, 7,  6,  5};
+		constexpr uint64_t debruijn64lsb = 0x07EDD5E59A4E28C2ULL;
+		auto result = index64lsb[((value & -value) * debruijn64lsb) >> 58] - 1;
+#ifdef __clang__
+		D_ASSERT(result == static_cast<idx_t>(__builtin_clzl(value_in)));
+#endif
+		return result;
 	}
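The branch-free scan in Leading() works by smearing the most significant set bit into every lower position, so only 64 distinct smeared values can reach the multiply; multiplying by the De Bruijn constant then places a unique 6-bit pattern in the top bits, and the lookup table maps that pattern back to a bit position. A standalone sketch using the same table and constant as Leading() above; the exhaustive check in main is added here purely for illustration:

#include <cassert>
#include <cstdint>

// Same De Bruijn table and constant as Leading() above; the loop in main is a
// naive reference check, not part of the technique.
static int clz64(uint64_t value) {
	static const int index64msb[] = {0,  47, 1,  56, 48, 27, 2,  60, 57, 49, 41, 37, 28, 16, 3,  61,
	                                 54, 58, 35, 52, 50, 42, 21, 44, 38, 32, 29, 23, 17, 11, 4,  62,
	                                 46, 55, 26, 59, 40, 36, 15, 53, 34, 51, 20, 43, 31, 22, 10, 45,
	                                 25, 39, 14, 33, 19, 30, 9,  24, 13, 18, 8,  12, 7,  6,  5,  63};
	const uint64_t debruijn64msb = 0x03F79D71B4CB0A89ULL;
	if (!value) {
		return 64;
	}
	// Smear the most significant set bit into every lower position, so the
	// multiply-shift maps each possible input to one of 64 unique indexes.
	value |= value >> 1;
	value |= value >> 2;
	value |= value >> 4;
	value |= value >> 8;
	value |= value >> 16;
	value |= value >> 32;
	return 63 - index64msb[(value * debruijn64msb) >> 58];
}

int main() {
	for (int bit = 0; bit < 64; bit++) {
		assert(clz64(uint64_t(1) << bit) == 63 - bit);
	}
	assert(clz64(0) == 64);
	return 0;
}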
}; template <> -struct CountZeros { - inline static int Leading(uint64_t value) { - if (!value) { - return 64; - } - return __builtin_clzll(value); +struct CountZeros { + inline static idx_t Leading(uint32_t value) { + return CountZeros::Leading(static_cast(value)) - 32; } - inline static int Trailing(uint64_t value) { - if (!value) { - return 64; - } - return __builtin_ctzll(value); + inline static idx_t Trailing(uint32_t value) { + return CountZeros::Trailing(static_cast(value)); } }; template <> struct CountZeros { - inline static int Leading(hugeint_t value) { - const uint64_t upper = (uint64_t)value.upper; + inline static idx_t Leading(hugeint_t value) { + const uint64_t upper = static_cast(value.upper); const uint64_t lower = value.lower; if (upper) { - return __builtin_clzll(upper); + return CountZeros::Leading(upper); } else if (lower) { - return 64 + __builtin_clzll(lower); + return 64 + CountZeros::Leading(lower); } else { return 128; } } - inline static int Trailing(hugeint_t value) { - const uint64_t upper = (uint64_t)value.upper; + inline static idx_t Trailing(hugeint_t value) { + const uint64_t upper = static_cast(value.upper); const uint64_t lower = value.lower; if (lower) { - return __builtin_ctzll(lower); + return CountZeros::Trailing(lower); } else if (upper) { - return 64 + __builtin_ctzll(upper); + return 64 + CountZeros::Trailing(upper); } else { return 128; } @@ -141,27 +106,27 @@ struct CountZeros { template <> struct CountZeros { - inline static int Leading(uhugeint_t value) { - const uint64_t upper = (uint64_t)value.upper; + inline static idx_t Leading(uhugeint_t value) { + const uint64_t upper = static_cast(value.upper); const uint64_t lower = value.lower; if (upper) { - return __builtin_clzll(upper); + return CountZeros::Leading(upper); } else if (lower) { - return 64 + __builtin_clzll(lower); + return 64 + CountZeros::Leading(lower); } else { return 128; } } - inline static int Trailing(uhugeint_t value) { - const uint64_t upper = (uint64_t)value.upper; + inline static idx_t Trailing(uhugeint_t value) { + const uint64_t upper = static_cast(value.upper); const uint64_t lower = value.lower; if (lower) { - return __builtin_ctzll(lower); + return CountZeros::Trailing(lower); } else if (upper) { - return 64 + __builtin_ctzll(upper); + return 64 + CountZeros::Trailing(upper); } else { return 128; } From 27b6d2864d9fb7924e52b00b6f1828d9d4a31ba3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 9 Apr 2024 14:07:44 +0200 Subject: [PATCH 065/201] strftime --- src/function/scalar/strftime_format.cpp | 69 +++++++++++++------------ src/include/duckdb/common/bit_utils.hpp | 4 +- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/src/function/scalar/strftime_format.cpp b/src/function/scalar/strftime_format.cpp index 5181b005ed45..5784ed91b895 100644 --- a/src/function/scalar/strftime_format.cpp +++ b/src/function/scalar/strftime_format.cpp @@ -80,7 +80,7 @@ idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date if (0 <= year && year <= 9999) { return 4; } else { - return NumericHelper::SignedLength(year); + return UnsafeNumericCast(NumericHelper::SignedLength(year)); } } case StrTimeSpecifier::MONTH_DECIMAL: { @@ -129,11 +129,14 @@ idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date return len; } case StrTimeSpecifier::DAY_OF_MONTH: - return NumericHelper::UnsignedLength(Date::ExtractDay(date)); + return UnsafeNumericCast( + 
NumericHelper::UnsignedLength(UnsafeNumericCast(Date::ExtractDay(date)))); case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: - return NumericHelper::UnsignedLength(Date::ExtractDayOfTheYear(date)); + return UnsafeNumericCast( + NumericHelper::UnsignedLength(UnsafeNumericCast(Date::ExtractDayOfTheYear(date)))); case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - return NumericHelper::UnsignedLength(AbsValue(Date::ExtractYear(date)) % 100); + return UnsafeNumericCast(NumericHelper::UnsignedLength( + UnsafeNumericCast(AbsValue(Date::ExtractYear(date)) % 100))); default: throw InternalException("Unimplemented specifier for GetSpecifierLength"); } @@ -195,13 +198,13 @@ char *StrfTimeFormat::WritePadded(char *target, uint32_t value, size_t padding) D_ASSERT(padding > 1); if (padding % 2) { int decimals = value % 1000; - WritePadded3(target + padding - 3, decimals); + WritePadded3(target + padding - 3, UnsafeNumericCast(decimals)); value /= 1000; padding -= 3; } for (size_t i = 0; i < padding / 2; i++) { int decimals = value % 100; - WritePadded2(target + padding - 2 * (i + 1), decimals); + WritePadded2(target + padding - 2 * (i + 1), UnsafeNumericCast(decimals)); value /= 100; } return target + padding; @@ -245,26 +248,26 @@ char *StrfTimeFormat::WriteDateSpecifier(StrTimeSpecifier specifier, date_t date } case StrTimeSpecifier::DAY_OF_YEAR_PADDED: { int32_t doy = Date::ExtractDayOfTheYear(date); - target = WritePadded3(target, doy); + target = WritePadded3(target, UnsafeNumericCast(doy)); break; } case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - target = WritePadded2(target, Date::ExtractWeekNumberRegular(date, true)); + target = WritePadded2(target, UnsafeNumericCast(Date::ExtractWeekNumberRegular(date, true))); break; case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - target = WritePadded2(target, Date::ExtractWeekNumberRegular(date, false)); + target = WritePadded2(target, UnsafeNumericCast(Date::ExtractWeekNumberRegular(date, false))); break; case StrTimeSpecifier::WEEK_NUMBER_ISO: - target = WritePadded2(target, Date::ExtractISOWeekNumber(date)); + target = WritePadded2(target, UnsafeNumericCast(Date::ExtractISOWeekNumber(date))); break; case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: { - uint32_t doy = Date::ExtractDayOfTheYear(date); + auto doy = UnsafeNumericCast(Date::ExtractDayOfTheYear(date)); target += NumericHelper::UnsignedLength(doy); NumericHelper::FormatUnsigned(doy, target); break; } case StrTimeSpecifier::YEAR_ISO: - target = WritePadded(target, Date::ExtractISOYearNumber(date), 4); + target = WritePadded(target, UnsafeNumericCast(Date::ExtractISOYearNumber(date)), 4); break; case StrTimeSpecifier::WEEKDAY_ISO: *target = char('0' + uint8_t(Date::ExtractISODayOfTheWeek(date))); @@ -281,7 +284,7 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t // data contains [0] year, [1] month, [2] day, [3] hour, [4] minute, [5] second, [6] msec, [7] utc switch (specifier) { case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - target = WritePadded2(target, data[2]); + target = WritePadded2(target, UnsafeNumericCast(data[2])); break; case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: { auto &month_name = Date::MONTH_NAMES_ABBREVIATED[data[1] - 1]; @@ -292,14 +295,14 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t return WriteString(target, month_name); } case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - target = WritePadded2(target, data[1]); + target = WritePadded2(target, UnsafeNumericCast(data[1])); break; case 
StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - target = WritePadded2(target, AbsValue(data[0]) % 100); + target = WritePadded2(target, UnsafeNumericCast(AbsValue(data[0]) % 100)); break; case StrTimeSpecifier::YEAR_DECIMAL: if (data[0] >= 0 && data[0] <= 9999) { - target = WritePadded(target, data[0], 4); + target = WritePadded(target, UnsafeNumericCast(data[0]), 4); } else { int32_t year = data[0]; if (data[0] < 0) { @@ -307,13 +310,13 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t year = -year; target++; } - auto len = NumericHelper::UnsignedLength(year); + auto len = NumericHelper::UnsignedLength(UnsafeNumericCast(year)); NumericHelper::FormatUnsigned(year, target + len); target += len; } break; case StrTimeSpecifier::HOUR_24_PADDED: { - target = WritePadded2(target, data[3]); + target = WritePadded2(target, UnsafeNumericCast(data[3])); break; } case StrTimeSpecifier::HOUR_12_PADDED: { @@ -321,7 +324,7 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t if (hour == 0) { hour = 12; } - target = WritePadded2(target, hour); + target = WritePadded2(target, UnsafeNumericCast(hour)); break; } case StrTimeSpecifier::AM_PM: @@ -329,20 +332,20 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t *target++ = 'M'; break; case StrTimeSpecifier::MINUTE_PADDED: { - target = WritePadded2(target, data[4]); + target = WritePadded2(target, UnsafeNumericCast(data[4])); break; } case StrTimeSpecifier::SECOND_PADDED: - target = WritePadded2(target, data[5]); + target = WritePadded2(target, UnsafeNumericCast(data[5])); break; case StrTimeSpecifier::NANOSECOND_PADDED: - target = WritePadded(target, data[6] * Interval::NANOS_PER_MICRO, 9); + target = WritePadded(target, UnsafeNumericCast(data[6] * Interval::NANOS_PER_MICRO), 9); break; case StrTimeSpecifier::MICROSECOND_PADDED: - target = WritePadded(target, data[6], 6); + target = WritePadded(target, UnsafeNumericCast(data[6]), 6); break; case StrTimeSpecifier::MILLISECOND_PADDED: - target = WritePadded3(target, data[6] / Interval::MICROS_PER_MSEC); + target = WritePadded3(target, UnsafeNumericCast(data[6] / Interval::MICROS_PER_MSEC)); break; case StrTimeSpecifier::UTC_OFFSET: { *target++ = (data[7] < 0) ? 
'-' : '+'; @@ -350,10 +353,10 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t auto offset = abs(data[7]); auto offset_hours = offset / Interval::MINS_PER_HOUR; auto offset_minutes = offset % Interval::MINS_PER_HOUR; - target = WritePadded2(target, offset_hours); + target = WritePadded2(target, UnsafeNumericCast(offset_hours)); if (offset_minutes) { *target++ = ':'; - target = WritePadded2(target, offset_minutes); + target = WritePadded2(target, UnsafeNumericCast(offset_minutes)); } break; } @@ -364,7 +367,7 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t } break; case StrTimeSpecifier::DAY_OF_MONTH: { - target = Write2(target, data[2] % 100); + target = Write2(target, UnsafeNumericCast(data[2] % 100)); break; } case StrTimeSpecifier::MONTH_DECIMAL: { @@ -372,7 +375,7 @@ char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t break; } case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: { - target = Write2(target, AbsValue(data[0]) % 100); + target = Write2(target, UnsafeNumericCast(AbsValue(data[0]) % 100)); break; } case StrTimeSpecifier::HOUR_24_DECIMAL: { @@ -845,9 +848,9 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const { // numeric specifier: parse a number uint64_t number = 0; size_t start_pos = pos; - size_t end_pos = start_pos + numeric_width[i]; + size_t end_pos = start_pos + UnsafeNumericCast(numeric_width[i]); while (pos < size && pos < end_pos && StringUtil::CharacterIsDigit(data[pos])) { - number = number * 10 + data[pos] - '0'; + number = number * 10 + UnsafeNumericCast(data[pos]) - '0'; pos++; } if (pos == start_pos) { @@ -1229,7 +1232,7 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const { // But tz must not be empty. 
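The error_position assignment just below narrows a pointer difference (a signed ptrdiff_t) to an unsigned index, which is why it now goes through an explicit cast; the same pattern repeats throughout these patches, where every narrowing or sign-changing conversion is made visible at the call site. The checked_cast below is a hypothetical stand-in sketching the checked flavour of such a cast; it is not DuckDB's NumericCast or UnsafeNumericCast, and it assumes C++20 for std::in_range:

#include <cstdint>
#include <stdexcept>
#include <utility> // std::in_range, C++20

// Hypothetical helper, not DuckDB's: throws when the value does not fit in TO.
template <class TO, class FROM>
TO checked_cast(FROM value) {
	if (!std::in_range<TO>(value)) {
		throw std::out_of_range("narrowing cast would lose information");
	}
	return static_cast<TO>(value);
}

int main() {
	int32_t day = 17;
	auto padded_input = checked_cast<uint32_t>(day); // fine: 17 fits in uint32_t
	(void)padded_input;
	try {
		checked_cast<uint32_t>(int32_t(-1)); // throws: negative value, unsigned target
	} catch (const std::out_of_range &) {
		return 0;
	}
	return 1;
}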
if (tz_end == tz_begin) { error_message = "Empty Time Zone name"; - error_position = tz_begin - data; + error_position = UnsafeNumericCast(tz_begin - data); return false; } result.tz.assign(tz_begin, tz_end); @@ -1288,7 +1291,9 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const { case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: { // Adjust weekday to be 0-based for the week type if (has_weekday) { - weekday = (weekday + 7 - int(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST)) % 7; + weekday = (weekday + 7 - + static_cast(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST)) % + 7; } // Get the start of week 1, move back 7 days and then weekno * 7 + weekday gives the date const auto jan1 = Date::FromDate(result_data[0], 1, 1); diff --git a/src/include/duckdb/common/bit_utils.hpp b/src/include/duckdb/common/bit_utils.hpp index 0b9c325d9c23..766e951e9269 100644 --- a/src/include/duckdb/common/bit_utils.hpp +++ b/src/include/duckdb/common/bit_utils.hpp @@ -57,9 +57,9 @@ struct CountZeros { 62, 57, 46, 52, 38, 26, 32, 41, 50, 36, 17, 19, 29, 10, 13, 21, 56, 45, 25, 31, 35, 16, 9, 12, 44, 24, 15, 8, 23, 7, 6, 5}; constexpr uint64_t debruijn64lsb = 0x07EDD5E59A4E28C2ULL; - auto result = index64lsb[((value & -value) * debruijn64lsb) >> 58] - 1; + auto result = index64lsb[((value & -value) * debruijn64lsb) >> 58]; #ifdef __clang__ - D_ASSERT(result == static_cast(__builtin_clzl(value_in))); + D_ASSERT(result == static_cast(__builtin_ctzl(value_in))); #endif return result; } From f72498717469088b339efcc29878b6fb2179b8c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 9 Apr 2024 14:25:58 +0200 Subject: [PATCH 066/201] list_extract --- src/function/scalar/list/list_extract.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/function/scalar/list/list_extract.cpp b/src/function/scalar/list/list_extract.cpp index 822b9079cdc9..f70f12c1629e 100644 --- a/src/function/scalar/list/list_extract.cpp +++ b/src/function/scalar/list/list_extract.cpp @@ -62,13 +62,13 @@ void ListExtractTemplate(idx_t count, UnifiedVectorFormat &list_data, UnifiedVec result_mask.SetInvalid(i); continue; } - child_offset = list_entry.offset + list_entry.length + offsets_entry; + child_offset = UnsafeNumericCast(UnsafeNumericCast(list_entry.offset + list_entry.length) + offsets_entry); } else { if ((idx_t)offsets_entry >= list_entry.length) { result_mask.SetInvalid(i); continue; } - child_offset = list_entry.offset + offsets_entry; + child_offset = UnsafeNumericCast(UnsafeNumericCast(list_entry.offset) + offsets_entry); } auto child_index = child_format.sel->get_index(child_offset); if (child_format.validity.RowIsValid(child_index)) { From a2bee34cdfcb6935662024ae2cecb4e45ceba598 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 9 Apr 2024 15:00:54 +0200 Subject: [PATCH 067/201] list_extract --- src/function/scalar/list/list_extract.cpp | 3 ++- src/function/scalar/list/list_select.cpp | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/function/scalar/list/list_extract.cpp b/src/function/scalar/list/list_extract.cpp index f70f12c1629e..aa8278ce63c6 100644 --- a/src/function/scalar/list/list_extract.cpp +++ b/src/function/scalar/list/list_extract.cpp @@ -62,7 +62,8 @@ void ListExtractTemplate(idx_t count, UnifiedVectorFormat &list_data, UnifiedVec result_mask.SetInvalid(i); continue; } - child_offset = UnsafeNumericCast(UnsafeNumericCast(list_entry.offset + 
list_entry.length) + offsets_entry); + child_offset = UnsafeNumericCast(UnsafeNumericCast(list_entry.offset + list_entry.length) + + offsets_entry); } else { if ((idx_t)offsets_entry >= list_entry.length) { result_mask.SetInvalid(i); diff --git a/src/function/scalar/list/list_select.cpp b/src/function/scalar/list/list_select.cpp index 70ae15194fcf..bf77d9e9dcc0 100644 --- a/src/function/scalar/list/list_select.cpp +++ b/src/function/scalar/list/list_select.cpp @@ -12,9 +12,10 @@ struct SetSelectionVectorSelect { idx_t &target_offset, idx_t selection_offset, idx_t input_offset, idx_t target_length) { auto sel_idx = selection_entry.GetValue(selection_offset + child_idx).GetValue() - 1; - if (sel_idx < target_length) { - selection_vector.set_index(target_offset, input_offset + sel_idx); - if (!input_validity.RowIsValid(input_offset + sel_idx)) { + if (sel_idx >= 0 && sel_idx < UnsafeNumericCast(target_length)) { + auto sel_idx_unsigned = UnsafeNumericCast(sel_idx); + selection_vector.set_index(target_offset, input_offset + sel_idx_unsigned); + if (!input_validity.RowIsValid(input_offset + sel_idx_unsigned)) { validity_mask.SetInvalid(target_offset); } } else { From 36e8260677a9eb7e5b4440a139454f0c75cfb638 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 9 Apr 2024 15:20:01 +0200 Subject: [PATCH 068/201] cast_helpers my other nemesis --- src/common/operator/string_cast.cpp | 16 ++++++++-------- .../duckdb/common/types/cast_helpers.hpp | 17 ++++++++++------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/common/operator/string_cast.cpp b/src/common/operator/string_cast.cpp index 0fd1f5f0a0ca..be3d604910bf 100644 --- a/src/common/operator/string_cast.cpp +++ b/src/common/operator/string_cast.cpp @@ -24,37 +24,37 @@ string_t StringCast::Operation(bool input, Vector &vector) { template <> string_t StringCast::Operation(int8_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); + return NumericHelper::FormatSigned(input, vector); } template <> string_t StringCast::Operation(int16_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); + return NumericHelper::FormatSigned(input, vector); } template <> string_t StringCast::Operation(int32_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); + return NumericHelper::FormatSigned(input, vector); } template <> string_t StringCast::Operation(int64_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); + return NumericHelper::FormatSigned(input, vector); } template <> duckdb::string_t StringCast::Operation(uint8_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); + return NumericHelper::FormatSigned(input, vector); } template <> duckdb::string_t StringCast::Operation(uint16_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); + return NumericHelper::FormatSigned(input, vector); } template <> duckdb::string_t StringCast::Operation(uint32_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); + return NumericHelper::FormatSigned(input, vector); } template <> duckdb::string_t StringCast::Operation(uint64_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); + return NumericHelper::FormatSigned(input, vector); } template <> diff --git a/src/include/duckdb/common/types/cast_helpers.hpp b/src/include/duckdb/common/types/cast_helpers.hpp index 358438da3c17..7dd91f3a5f8a 100644 --- 
a/src/include/duckdb/common/types/cast_helpers.hpp +++ b/src/include/duckdb/common/types/cast_helpers.hpp @@ -59,16 +59,19 @@ class NumericHelper { return ptr; } - template - static string_t FormatSigned(SIGNED value, Vector &vector) { - int sign = -(value < 0); - UNSIGNED unsigned_value = UnsafeNumericCast(UNSIGNED(value ^ sign) - sign); - int length = UnsignedLength(unsigned_value) - sign; - string_t result = StringVector::EmptyString(vector, NumericCast(length)); + template + static string_t FormatSigned(T value, Vector &vector) { + auto is_negative = (value < 0); + auto unsigned_value = static_cast::type>(AbsValue(value)); + auto length = UnsignedLength(unsigned_value); + if (is_negative) { + length++; + } + auto result = StringVector::EmptyString(vector, UnsafeNumericCast(length)); auto dataptr = result.GetDataWriteable(); auto endptr = dataptr + length; endptr = FormatUnsigned(unsigned_value, endptr); - if (sign) { + if (is_negative) { *--endptr = '-'; } result.Finalize(); From 9f6a64a953156185669285fc0a54482482164dad Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 9 Apr 2024 17:58:39 +0200 Subject: [PATCH 069/201] fix compilation issues --- src/include/duckdb/common/shared_ptr.ipp | 10 +--------- src/include/duckdb/common/weak_ptr.ipp | 4 ++-- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index e5c96f6af02e..0be8ab5bdfe0 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -90,14 +90,6 @@ public: explicit shared_ptr(weak_ptr other) : internal(other.internal) { } - // Construct from auto_ptr -#if _LIBCPP_STD_VER <= 14 || defined(_LIBCPP_ENABLE_CXX17_REMOVED_AUTO_PTR) - template ::value, int>::type = 0> - shared_ptr(std::auto_ptr &&__r) : internal(__r.release()) { - __enable_weak_this(internal.get(), internal.get()); - } -#endif - // Construct from unique_ptr, takes over ownership of the unique_ptr template ::value && @@ -185,7 +177,7 @@ public: return internal.operator bool(); } - std::__add_lvalue_reference_t operator*() const { + typename std::add_lvalue_reference::type operator*() const { if (MemorySafety::ENABLED) { const auto ptr = internal.get(); AssertNotNull(!ptr); diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index f602aacb4830..dc8115b33717 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -44,9 +44,9 @@ public: return *this; } - template ::value, int> = 0> + template ::value, int>::type = 0> weak_ptr &operator=(const shared_ptr &ptr) { - internal = ptr; + internal = ptr.internal; return *this; } From 6b8c631a6b5ac4699bf56f5366b167724facbe4b Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 9 Apr 2024 17:59:13 +0200 Subject: [PATCH 070/201] test setting the maximum swap space to unlimited --- .../max_swap_space_unlimited.test | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 test/sql/storage/temp_directory/max_swap_space_unlimited.test diff --git a/test/sql/storage/temp_directory/max_swap_space_unlimited.test b/test/sql/storage/temp_directory/max_swap_space_unlimited.test new file mode 100644 index 000000000000..86dfdd7f4570 --- /dev/null +++ b/test/sql/storage/temp_directory/max_swap_space_unlimited.test @@ -0,0 +1,18 @@ +# name: test/sql/storage/temp_directory/max_swap_space_unlimited.test +# group: [temp_directory] + +require skip_reload + +require noforcestorage + +# Set a temp_directory to offload 
data +statement ok +set temp_directory='__TEST_DIR__/max_swap_space_reached' + +statement ok +PRAGMA max_temp_directory_size='-1'; + +query I +select current_setting('max_temp_directory_size') +---- +16383.9 PiB From 7097f39819edc1cc40f7d55e854ce6628e2d05fb Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 9 Apr 2024 22:23:14 +0200 Subject: [PATCH 071/201] surprisingly painless conversion from std::shared_ptr to duckdb::shared_ptr for the python package --- tools/odbc/include/duckdb_odbc.hpp | 3 +- .../duckdb_python/expression/pyexpression.hpp | 6 +- .../conversions/pyconnection_default.hpp | 2 +- .../duckdb_python/pybind11/pybind_wrapper.hpp | 1 + .../pyconnection/pyconnection.hpp | 3 +- .../src/include/duckdb_python/pyrelation.hpp | 2 +- .../src/include/duckdb_python/pytype.hpp | 2 +- tools/pythonpkg/src/pyconnection.cpp | 18 +++--- .../src/pyconnection/type_creation.cpp | 16 ++--- tools/pythonpkg/src/pyexpression.cpp | 26 ++++---- tools/pythonpkg/src/pyrelation.cpp | 6 +- tools/pythonpkg/src/typing/pytype.cpp | 25 ++++---- tools/pythonpkg/src/typing/typing.cpp | 64 +++++++++---------- 13 files changed, 90 insertions(+), 84 deletions(-) diff --git a/tools/odbc/include/duckdb_odbc.hpp b/tools/odbc/include/duckdb_odbc.hpp index 580cc86e9bdf..14b92ff1df5c 100644 --- a/tools/odbc/include/duckdb_odbc.hpp +++ b/tools/odbc/include/duckdb_odbc.hpp @@ -3,6 +3,7 @@ // needs to be first because BOOL #include "duckdb.hpp" +#include "duckdb/common/shared_ptr.hpp" #include "duckdb/common/windows.hpp" #include "descriptor.hpp" @@ -41,7 +42,7 @@ struct OdbcHandleEnv : public OdbcHandle { OdbcHandleEnv() : OdbcHandle(OdbcHandleType::ENV) { duckdb::DBConfig ODBC_CONFIG; ODBC_CONFIG.SetOptionByName("duckdb_api", "odbc"); - db = make_shared(nullptr, &ODBC_CONFIG); + db = make_refcounted(nullptr, &ODBC_CONFIG); }; shared_ptr db; diff --git a/tools/pythonpkg/src/include/duckdb_python/expression/pyexpression.hpp b/tools/pythonpkg/src/include/duckdb_python/expression/pyexpression.hpp index 2083a488a239..2ac7bc1f91b5 100644 --- a/tools/pythonpkg/src/include/duckdb_python/expression/pyexpression.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/expression/pyexpression.hpp @@ -22,14 +22,14 @@ namespace duckdb { -struct DuckDBPyExpression : public std::enable_shared_from_this { +struct DuckDBPyExpression : public enable_shared_from_this { public: explicit DuckDBPyExpression(unique_ptr expr, OrderType order_type = OrderType::ORDER_DEFAULT, OrderByNullType null_order = OrderByNullType::ORDER_DEFAULT); public: - std::shared_ptr shared_from_this() { - return std::enable_shared_from_this::shared_from_this(); + shared_ptr shared_from_this() { + return enable_shared_from_this::shared_from_this(); } public: diff --git a/tools/pythonpkg/src/include/duckdb_python/pybind11/conversions/pyconnection_default.hpp b/tools/pythonpkg/src/include/duckdb_python/pybind11/conversions/pyconnection_default.hpp index 1c8908b98966..d6ad6979111d 100644 --- a/tools/pythonpkg/src/include/duckdb_python/pybind11/conversions/pyconnection_default.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pybind11/conversions/pyconnection_default.hpp @@ -37,7 +37,7 @@ class type_caster> }; template <> -struct is_holder_type> : std::true_type {}; +struct is_holder_type> : std::true_type {}; } // namespace detail } // namespace PYBIND11_NAMESPACE diff --git a/tools/pythonpkg/src/include/duckdb_python/pybind11/pybind_wrapper.hpp b/tools/pythonpkg/src/include/duckdb_python/pybind11/pybind_wrapper.hpp index 5c3be44eb6eb..56200f0bb59e 100644 --- 
a/tools/pythonpkg/src/include/duckdb_python/pybind11/pybind_wrapper.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pybind11/pybind_wrapper.hpp @@ -17,6 +17,7 @@ #include PYBIND11_DECLARE_HOLDER_TYPE(T, duckdb::unique_ptr) +PYBIND11_DECLARE_HOLDER_TYPE(T, duckdb::shared_ptr) namespace pybind11 { diff --git a/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp b/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp index a0d572dc9a91..91cf2be1a898 100644 --- a/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pyconnection/pyconnection.hpp @@ -22,6 +22,7 @@ #include "duckdb/function/scalar_function.hpp" #include "duckdb_python/pybind11/conversions/exception_handling_enum.hpp" #include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { @@ -37,7 +38,7 @@ class RegisteredArrow : public RegisteredObject { unique_ptr arrow_factory; }; -struct DuckDBPyConnection : public std::enable_shared_from_this { +struct DuckDBPyConnection : public enable_shared_from_this { public: shared_ptr database; unique_ptr connection; diff --git a/tools/pythonpkg/src/include/duckdb_python/pyrelation.hpp b/tools/pythonpkg/src/include/duckdb_python/pyrelation.hpp index e6cd967558b4..629299f6b3db 100644 --- a/tools/pythonpkg/src/include/duckdb_python/pyrelation.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pyrelation.hpp @@ -69,7 +69,7 @@ struct DuckDBPyRelation { py::str GetAlias(); - static unique_ptr EmptyResult(const std::shared_ptr &context, + static unique_ptr EmptyResult(const shared_ptr &context, const vector &types, vector names); unique_ptr SetAlias(const string &expr); diff --git a/tools/pythonpkg/src/include/duckdb_python/pytype.hpp b/tools/pythonpkg/src/include/duckdb_python/pytype.hpp index 72758bf112f1..a6e13dfd68e9 100644 --- a/tools/pythonpkg/src/include/duckdb_python/pytype.hpp +++ b/tools/pythonpkg/src/include/duckdb_python/pytype.hpp @@ -21,7 +21,7 @@ class PyUnionType : public py::object { static bool check_(const py::handle &object); }; -class DuckDBPyType : public std::enable_shared_from_this { +class DuckDBPyType : public enable_shared_from_this { public: explicit DuckDBPyType(LogicalType type); diff --git a/tools/pythonpkg/src/pyconnection.cpp b/tools/pythonpkg/src/pyconnection.cpp index dd366334e71e..548d83a0331d 100644 --- a/tools/pythonpkg/src/pyconnection.cpp +++ b/tools/pythonpkg/src/pyconnection.cpp @@ -50,6 +50,7 @@ #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" #include "duckdb/main/pending_query_result.hpp" #include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/common/shared_ptr.hpp" #include @@ -639,7 +640,7 @@ void DuckDBPyConnection::RegisterArrowObject(const py::object &arrow_object, con } vector> dependencies; dependencies.push_back( - make_shared(make_uniq(std::move(stream_factory), arrow_object))); + make_refcounted(make_uniq(std::move(stream_factory), arrow_object))); connection->context->external_dependencies[name] = std::move(dependencies); } @@ -664,8 +665,8 @@ shared_ptr DuckDBPyConnection::RegisterPythonObject(const st // keep a reference vector> dependencies; - dependencies.push_back(make_shared(make_uniq(python_object), - make_uniq(new_df))); + dependencies.push_back(make_refcounted(make_uniq(python_object), + make_uniq(new_df))); connection->context->external_dependencies[name] = std::move(dependencies); } } else if (IsAcceptedArrowObject(python_object) 
|| IsPolarsDataframe(python_object)) { @@ -774,7 +775,8 @@ unique_ptr DuckDBPyConnection::ReadJSON(const string &name, co auto_detect = true; } - auto read_json_relation = make_shared(connection->context, name, std::move(options), auto_detect); + auto read_json_relation = + make_refcounted(connection->context, name, std::move(options), auto_detect); if (read_json_relation == nullptr) { throw BinderException("read_json can only be used when the JSON extension is (statically) loaded"); } @@ -1317,7 +1319,7 @@ shared_ptr DuckDBPyConnection::Cursor() { if (!connection) { throw ConnectionException("Connection has already been closed"); } - auto res = make_shared(); + auto res = make_refcounted(); res->database = database; res->connection = make_uniq(*res->database); cursors.push_back(res); @@ -1596,7 +1598,7 @@ static void SetDefaultConfigArguments(ClientContext &context) { } static shared_ptr FetchOrCreateInstance(const string &database, DBConfig &config) { - auto res = make_shared(); + auto res = make_refcounted(); res->database = instance_cache.GetInstance(database, config); if (!res->database) { //! No cached database, we must create a new instance @@ -1674,7 +1676,7 @@ shared_ptr DuckDBPyConnection::DefaultConnection() { PythonImportCache *DuckDBPyConnection::ImportCache() { if (!import_cache) { - import_cache = make_shared(); + import_cache = make_refcounted(); } return import_cache.get(); } @@ -1688,7 +1690,7 @@ ModifiedMemoryFileSystem &DuckDBPyConnection::GetObjectFileSystem() { throw InvalidInputException( "This operation could not be completed because required module 'fsspec' is not installed"); } - internal_object_filesystem = make_shared(modified_memory_fs()); + internal_object_filesystem = make_refcounted(modified_memory_fs()); auto &abstract_fs = reinterpret_cast(*internal_object_filesystem); RegisterFilesystem(abstract_fs); } diff --git a/tools/pythonpkg/src/pyconnection/type_creation.cpp b/tools/pythonpkg/src/pyconnection/type_creation.cpp index 5888fd84bb24..91860e7f936e 100644 --- a/tools/pythonpkg/src/pyconnection/type_creation.cpp +++ b/tools/pythonpkg/src/pyconnection/type_creation.cpp @@ -5,17 +5,17 @@ namespace duckdb { shared_ptr DuckDBPyConnection::MapType(const shared_ptr &key_type, const shared_ptr &value_type) { auto map_type = LogicalType::MAP(key_type->Type(), value_type->Type()); - return make_shared(map_type); + return make_refcounted(map_type); } shared_ptr DuckDBPyConnection::ListType(const shared_ptr &type) { auto array_type = LogicalType::LIST(type->Type()); - return make_shared(array_type); + return make_refcounted(array_type); } shared_ptr DuckDBPyConnection::ArrayType(const shared_ptr &type, idx_t size) { auto array_type = LogicalType::ARRAY(type->Type(), size); - return make_shared(array_type); + return make_refcounted(array_type); } static child_list_t GetChildList(const py::object &container) { @@ -59,7 +59,7 @@ shared_ptr DuckDBPyConnection::StructType(const py::object &fields throw InvalidInputException("Can not create an empty struct type!"); } auto struct_type = LogicalType::STRUCT(std::move(types)); - return make_shared(struct_type); + return make_refcounted(struct_type); } shared_ptr DuckDBPyConnection::UnionType(const py::object &members) { @@ -69,7 +69,7 @@ shared_ptr DuckDBPyConnection::UnionType(const py::object &members throw InvalidInputException("Can not create an empty union type!"); } auto union_type = LogicalType::UNION(std::move(types)); - return make_shared(union_type); + return make_refcounted(union_type); } shared_ptr 
DuckDBPyConnection::EnumType(const string &name, const shared_ptr &type, @@ -79,7 +79,7 @@ shared_ptr DuckDBPyConnection::EnumType(const string &name, const shared_ptr DuckDBPyConnection::DecimalType(int width, int scale) { auto decimal_type = LogicalType::DECIMAL(width, scale); - return make_shared(decimal_type); + return make_refcounted(decimal_type); } shared_ptr DuckDBPyConnection::StringType(const string &collation) { @@ -89,14 +89,14 @@ shared_ptr DuckDBPyConnection::StringType(const string &collation) } else { type = LogicalType::VARCHAR_COLLATION(collation); } - return make_shared(type); + return make_refcounted(type); } shared_ptr DuckDBPyConnection::Type(const string &type_str) { if (!connection) { throw ConnectionException("Connection already closed!"); } - return make_shared(TransformStringToLogicalType(type_str, *connection->context)); + return make_refcounted(TransformStringToLogicalType(type_str, *connection->context)); } } // namespace duckdb diff --git a/tools/pythonpkg/src/pyexpression.cpp b/tools/pythonpkg/src/pyexpression.cpp index d389d1672ecb..09031706acdc 100644 --- a/tools/pythonpkg/src/pyexpression.cpp +++ b/tools/pythonpkg/src/pyexpression.cpp @@ -34,19 +34,19 @@ const ParsedExpression &DuckDBPyExpression::GetExpression() const { shared_ptr DuckDBPyExpression::Copy() const { auto expr = GetExpression().Copy(); - return make_shared(std::move(expr), order_type, null_order); + return make_refcounted(std::move(expr), order_type, null_order); } shared_ptr DuckDBPyExpression::SetAlias(const string &name) const { auto copied_expression = GetExpression().Copy(); copied_expression->alias = name; - return make_shared(std::move(copied_expression)); + return make_refcounted(std::move(copied_expression)); } shared_ptr DuckDBPyExpression::Cast(const DuckDBPyType &type) const { auto copied_expression = GetExpression().Copy(); auto case_expr = make_uniq(type.Type(), std::move(copied_expression)); - return make_shared(std::move(case_expr)); + return make_refcounted(std::move(case_expr)); } // Case Expression modifiers @@ -64,7 +64,7 @@ shared_ptr DuckDBPyExpression::InternalWhen(unique_ptrcase_checks.push_back(std::move(check)); - return make_shared(std::move(expr)); + return make_refcounted(std::move(expr)); } shared_ptr DuckDBPyExpression::When(const DuckDBPyExpression &condition, @@ -82,7 +82,7 @@ shared_ptr DuckDBPyExpression::Else(const DuckDBPyExpression auto expr = unique_ptr_cast(std::move(expr_p)); expr->else_expr = value.GetExpression().Copy(); - return make_shared(std::move(expr)); + return make_refcounted(std::move(expr)); } // Binary operators @@ -181,7 +181,7 @@ shared_ptr DuckDBPyExpression::In(const py::args &args) { expressions.push_back(std::move(expr)); } auto operator_expr = make_uniq(ExpressionType::COMPARE_IN, std::move(expressions)); - return make_shared(std::move(operator_expr)); + return make_refcounted(std::move(operator_expr)); } shared_ptr DuckDBPyExpression::NotIn(const py::args &args) { @@ -249,7 +249,7 @@ shared_ptr DuckDBPyExpression::StarExpression(const py::list case_insensitive_set_t exclude; auto star = make_uniq(); PopulateExcludeList(star->exclude_list, exclude_list); - return make_shared(std::move(star)); + return make_refcounted(std::move(star)); } shared_ptr DuckDBPyExpression::ColumnExpression(const string &column_name) { @@ -267,7 +267,7 @@ shared_ptr DuckDBPyExpression::ColumnExpression(const string } column_names.push_back(qualified_name.name); - return make_shared(make_uniq(std::move(column_names))); + return 
make_refcounted(make_uniq(std::move(column_names))); } shared_ptr DuckDBPyExpression::ConstantExpression(const py::object &value) { @@ -292,14 +292,14 @@ DuckDBPyExpression::InternalFunctionExpression(const string &function_name, vector> children, bool is_operator) { auto function_expression = make_uniq(function_name, std::move(children), nullptr, nullptr, false, is_operator); - return make_shared(std::move(function_expression)); + return make_refcounted(std::move(function_expression)); } shared_ptr DuckDBPyExpression::InternalUnaryOperator(ExpressionType type, const DuckDBPyExpression &arg) { auto expr = arg.GetExpression().Copy(); auto operator_expression = make_uniq(type, std::move(expr)); - return make_shared(std::move(operator_expression)); + return make_refcounted(std::move(operator_expression)); } shared_ptr DuckDBPyExpression::InternalConjunction(ExpressionType type, @@ -311,11 +311,11 @@ shared_ptr DuckDBPyExpression::InternalConjunction(Expressio children.push_back(other.GetExpression().Copy()); auto operator_expression = make_uniq(type, std::move(children)); - return make_shared(std::move(operator_expression)); + return make_refcounted(std::move(operator_expression)); } shared_ptr DuckDBPyExpression::InternalConstantExpression(Value val) { - return make_shared(make_uniq(std::move(val))); + return make_refcounted(make_uniq(std::move(val))); } shared_ptr DuckDBPyExpression::ComparisonExpression(ExpressionType type, @@ -323,7 +323,7 @@ shared_ptr DuckDBPyExpression::ComparisonExpression(Expressi const DuckDBPyExpression &right_p) { auto left = left_p.GetExpression().Copy(); auto right = right_p.GetExpression().Copy(); - return make_shared( + return make_refcounted( make_uniq(type, std::move(left), std::move(right))); } diff --git a/tools/pythonpkg/src/pyrelation.cpp b/tools/pythonpkg/src/pyrelation.cpp index 457e0f72e617..c71f09ee433f 100644 --- a/tools/pythonpkg/src/pyrelation.cpp +++ b/tools/pythonpkg/src/pyrelation.cpp @@ -146,7 +146,7 @@ unique_ptr DuckDBPyRelation::ProjectFromTypes(const py::object return ProjectFromExpression(projection); } -unique_ptr DuckDBPyRelation::EmptyResult(const std::shared_ptr &context, +unique_ptr DuckDBPyRelation::EmptyResult(const shared_ptr &context, const vector &types, vector names) { vector dummy_values; D_ASSERT(types.size() == names.size()); @@ -157,7 +157,7 @@ unique_ptr DuckDBPyRelation::EmptyResult(const std::shared_ptr } vector> single_row(1, dummy_values); auto values_relation = - make_uniq(make_shared(context, single_row, std::move(names))); + make_uniq(make_refcounted(context, single_row, std::move(names))); // Add a filter on an impossible condition return values_relation->FilterFromExpression("true = false"); } @@ -1198,7 +1198,7 @@ unique_ptr DuckDBPyRelation::Query(const string &view_name, co if (statement.type == StatementType::SELECT_STATEMENT) { auto select_statement = unique_ptr_cast(std::move(parser.statements[0])); auto query_relation = - make_shared(rel->context.GetContext(), std::move(select_statement), "query_relation"); + make_refcounted(rel->context.GetContext(), std::move(select_statement), "query_relation"); return make_uniq(std::move(query_relation)); } else if (IsDescribeStatement(statement)) { auto query = PragmaShow(view_name); diff --git a/tools/pythonpkg/src/typing/pytype.cpp b/tools/pythonpkg/src/typing/pytype.cpp index ad9828876b3d..00edd97af4f9 100644 --- a/tools/pythonpkg/src/typing/pytype.cpp +++ b/tools/pythonpkg/src/typing/pytype.cpp @@ -56,20 +56,20 @@ shared_ptr DuckDBPyType::GetAttribute(const string 
&name) const { for (idx_t i = 0; i < children.size(); i++) { auto &child = children[i]; if (StringUtil::CIEquals(child.first, name)) { - return make_shared(StructType::GetChildType(type, i)); + return make_refcounted(StructType::GetChildType(type, i)); } } } if (type.id() == LogicalTypeId::LIST && StringUtil::CIEquals(name, "child")) { - return make_shared(ListType::GetChildType(type)); + return make_refcounted(ListType::GetChildType(type)); } if (type.id() == LogicalTypeId::MAP) { auto is_key = StringUtil::CIEquals(name, "key"); auto is_value = StringUtil::CIEquals(name, "value"); if (is_key) { - return make_shared(MapType::KeyType(type)); + return make_refcounted(MapType::KeyType(type)); } else if (is_value) { - return make_shared(MapType::ValueType(type)); + return make_refcounted(MapType::ValueType(type)); } else { throw py::attribute_error(StringUtil::Format("Tried to get a child from a map by the name of '%s', but " "this type only has 'key' and 'value' children", @@ -313,19 +313,19 @@ void DuckDBPyType::Initialize(py::handle &m) { type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { auto ltype = FromString(type_str, std::move(connection)); - return make_shared(ltype); + return make_refcounted(ltype); })); type_module.def(py::init<>([](const PyGenericAlias &obj) { auto ltype = FromGenericAlias(obj); - return make_shared(ltype); + return make_refcounted(ltype); })); type_module.def(py::init<>([](const PyUnionType &obj) { auto ltype = FromUnionType(obj); - return make_shared(ltype); + return make_refcounted(ltype); })); type_module.def(py::init<>([](const py::object &obj) { auto ltype = FromObject(obj); - return make_shared(ltype); + return make_refcounted(ltype); })); type_module.def("__getattr__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); @@ -357,7 +357,7 @@ py::list DuckDBPyType::Children() const { py::list children; auto id = type.id(); if (id == LogicalTypeId::LIST) { - children.append(py::make_tuple("child", make_shared(ListType::GetChildType(type)))); + children.append(py::make_tuple("child", make_refcounted(ListType::GetChildType(type)))); return children; } // FIXME: where is ARRAY?? 
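For readers following along, a minimal usage sketch of how the child-type accessors touched by these hunks behave from Python. The module-level type constructors and the repr output shown in the comments are assumptions about a recent duckdb build, not part of this patch:

import duckdb
from duckdb.typing import INTEGER, VARCHAR

t = duckdb.map_type(VARCHAR, INTEGER)
print(t)           # MAP(VARCHAR, INTEGER)
print(t.children)  # [('key', VARCHAR), ('value', INTEGER)]
print(t["key"])    # VARCHAR -- __getitem__ routes through GetAttribute above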
@@ -366,13 +366,14 @@ py::list DuckDBPyType::Children() const { auto &struct_children = StructType::GetChildTypes(type); for (idx_t i = 0; i < struct_children.size(); i++) { auto &child = struct_children[i]; - children.append(py::make_tuple(child.first, make_shared(StructType::GetChildType(type, i)))); + children.append( + py::make_tuple(child.first, make_refcounted(StructType::GetChildType(type, i)))); } return children; } if (id == LogicalTypeId::MAP) { - children.append(py::make_tuple("key", make_shared(MapType::KeyType(type)))); - children.append(py::make_tuple("value", make_shared(MapType::ValueType(type)))); + children.append(py::make_tuple("key", make_refcounted(MapType::KeyType(type)))); + children.append(py::make_tuple("value", make_refcounted(MapType::ValueType(type)))); return children; } if (id == LogicalTypeId::DECIMAL) { diff --git a/tools/pythonpkg/src/typing/typing.cpp b/tools/pythonpkg/src/typing/typing.cpp index 0c1793ed703e..8064acfc22e1 100644 --- a/tools/pythonpkg/src/typing/typing.cpp +++ b/tools/pythonpkg/src/typing/typing.cpp @@ -4,38 +4,38 @@ namespace duckdb { static void DefineBaseTypes(py::handle &m) { - m.attr("SQLNULL") = make_shared(LogicalType::SQLNULL); - m.attr("BOOLEAN") = make_shared(LogicalType::BOOLEAN); - m.attr("TINYINT") = make_shared(LogicalType::TINYINT); - m.attr("UTINYINT") = make_shared(LogicalType::UTINYINT); - m.attr("SMALLINT") = make_shared(LogicalType::SMALLINT); - m.attr("USMALLINT") = make_shared(LogicalType::USMALLINT); - m.attr("INTEGER") = make_shared(LogicalType::INTEGER); - m.attr("UINTEGER") = make_shared(LogicalType::UINTEGER); - m.attr("BIGINT") = make_shared(LogicalType::BIGINT); - m.attr("UBIGINT") = make_shared(LogicalType::UBIGINT); - m.attr("HUGEINT") = make_shared(LogicalType::HUGEINT); - m.attr("UHUGEINT") = make_shared(LogicalType::UHUGEINT); - m.attr("UUID") = make_shared(LogicalType::UUID); - m.attr("FLOAT") = make_shared(LogicalType::FLOAT); - m.attr("DOUBLE") = make_shared(LogicalType::DOUBLE); - m.attr("DATE") = make_shared(LogicalType::DATE); - - m.attr("TIMESTAMP") = make_shared(LogicalType::TIMESTAMP); - m.attr("TIMESTAMP_MS") = make_shared(LogicalType::TIMESTAMP_MS); - m.attr("TIMESTAMP_NS") = make_shared(LogicalType::TIMESTAMP_NS); - m.attr("TIMESTAMP_S") = make_shared(LogicalType::TIMESTAMP_S); - - m.attr("TIME") = make_shared(LogicalType::TIME); - - m.attr("TIME_TZ") = make_shared(LogicalType::TIME_TZ); - m.attr("TIMESTAMP_TZ") = make_shared(LogicalType::TIMESTAMP_TZ); - - m.attr("VARCHAR") = make_shared(LogicalType::VARCHAR); - - m.attr("BLOB") = make_shared(LogicalType::BLOB); - m.attr("BIT") = make_shared(LogicalType::BIT); - m.attr("INTERVAL") = make_shared(LogicalType::INTERVAL); + m.attr("SQLNULL") = make_refcounted(LogicalType::SQLNULL); + m.attr("BOOLEAN") = make_refcounted(LogicalType::BOOLEAN); + m.attr("TINYINT") = make_refcounted(LogicalType::TINYINT); + m.attr("UTINYINT") = make_refcounted(LogicalType::UTINYINT); + m.attr("SMALLINT") = make_refcounted(LogicalType::SMALLINT); + m.attr("USMALLINT") = make_refcounted(LogicalType::USMALLINT); + m.attr("INTEGER") = make_refcounted(LogicalType::INTEGER); + m.attr("UINTEGER") = make_refcounted(LogicalType::UINTEGER); + m.attr("BIGINT") = make_refcounted(LogicalType::BIGINT); + m.attr("UBIGINT") = make_refcounted(LogicalType::UBIGINT); + m.attr("HUGEINT") = make_refcounted(LogicalType::HUGEINT); + m.attr("UHUGEINT") = make_refcounted(LogicalType::UHUGEINT); + m.attr("UUID") = make_refcounted(LogicalType::UUID); + m.attr("FLOAT") = 
make_refcounted(LogicalType::FLOAT); + m.attr("DOUBLE") = make_refcounted(LogicalType::DOUBLE); + m.attr("DATE") = make_refcounted(LogicalType::DATE); + + m.attr("TIMESTAMP") = make_refcounted(LogicalType::TIMESTAMP); + m.attr("TIMESTAMP_MS") = make_refcounted(LogicalType::TIMESTAMP_MS); + m.attr("TIMESTAMP_NS") = make_refcounted(LogicalType::TIMESTAMP_NS); + m.attr("TIMESTAMP_S") = make_refcounted(LogicalType::TIMESTAMP_S); + + m.attr("TIME") = make_refcounted(LogicalType::TIME); + + m.attr("TIME_TZ") = make_refcounted(LogicalType::TIME_TZ); + m.attr("TIMESTAMP_TZ") = make_refcounted(LogicalType::TIMESTAMP_TZ); + + m.attr("VARCHAR") = make_refcounted(LogicalType::VARCHAR); + + m.attr("BLOB") = make_refcounted(LogicalType::BLOB); + m.attr("BIT") = make_refcounted(LogicalType::BIT); + m.attr("INTERVAL") = make_refcounted(LogicalType::INTERVAL); } void DuckDBPyTyping::Initialize(py::module_ &parent) { From 7990a3ea3affe30feb980f802bcd3d28ca611161 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Wed, 10 Apr 2024 13:31:10 +0200 Subject: [PATCH 072/201] bitpacking --- src/execution/perfect_aggregate_hashtable.cpp | 9 +++++---- .../duckdb/storage/buffer/buffer_pool.hpp | 2 +- .../compression/chimp/algorithm/bit_reader.hpp | 2 +- src/storage/buffer/buffer_pool.cpp | 13 ++++++++++--- src/storage/buffer/buffer_pool_reservation.cpp | 4 ++-- src/storage/compression/bitpacking.cpp | 16 ++++++++++------ 6 files changed, 29 insertions(+), 17 deletions(-) diff --git a/src/execution/perfect_aggregate_hashtable.cpp b/src/execution/perfect_aggregate_hashtable.cpp index da7c20192452..ba8e366e9414 100644 --- a/src/execution/perfect_aggregate_hashtable.cpp +++ b/src/execution/perfect_aggregate_hashtable.cpp @@ -64,7 +64,7 @@ static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value // we only need to handle non-null values here if (group_data.validity.RowIsValid(index)) { D_ASSERT(data[index] >= min_val); - uintptr_t adjusted_value = (data[index] - min_val) + 1; + auto adjusted_value = UnsafeNumericCast((data[index] - min_val) + 1); address_data[i] += adjusted_value << current_shift; } } @@ -72,7 +72,7 @@ static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value // no null values: we can directly compute the addresses for (idx_t i = 0; i < count; i++) { auto index = group_data.sel->get_index(i); - uintptr_t adjusted_value = (data[index] - min_val) + 1; + auto adjusted_value = UnsafeNumericCast((data[index] - min_val) + 1); address_data[i] += adjusted_value << current_shift; } } @@ -149,7 +149,7 @@ void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) } // move to the next aggregate payload_idx += input_count; - VectorOperations::AddInPlace(addresses, aggregate.payload_size, payload.size()); + VectorOperations::AddInPlace(addresses, UnsafeNumericCast(aggregate.payload_size), payload.size()); } } @@ -205,7 +205,8 @@ static void ReconstructGroupVectorTemplated(uint32_t group_values[], Value &min, validity_mask.SetInvalid(i); } else { // otherwise we add the value (minus 1) to the min value - data[i] = UnsafeNumericCast(min_data + group_index - 1); + data[i] = UnsafeNumericCast(UnsafeNumericCast(min_data) + + UnsafeNumericCast(group_index) - 1); } } } diff --git a/src/include/duckdb/storage/buffer/buffer_pool.hpp b/src/include/duckdb/storage/buffer/buffer_pool.hpp index 9b9b3fd78d5f..57762b63d557 100644 --- a/src/include/duckdb/storage/buffer/buffer_pool.hpp +++ b/src/include/duckdb/storage/buffer/buffer_pool.hpp @@ 
-48,7 +48,7 @@ class BufferPool { //! blocks can be evicted void SetLimit(idx_t limit, const char *exception_postscript); - void IncreaseUsedMemory(MemoryTag tag, idx_t size); + void UpdateUsedMemory(MemoryTag tag, int64_t size); idx_t GetUsedMemory() const; diff --git a/src/include/duckdb/storage/compression/chimp/algorithm/bit_reader.hpp b/src/include/duckdb/storage/compression/chimp/algorithm/bit_reader.hpp index 57512f27be5f..a2910e63e46a 100644 --- a/src/include/duckdb/storage/compression/chimp/algorithm/bit_reader.hpp +++ b/src/include/duckdb/storage/compression/chimp/algorithm/bit_reader.hpp @@ -150,7 +150,7 @@ struct BitReader { result = result << 8 | InnerReadByte(i); } result = result << remainder | InnerRead(remainder, bytes); - index += (bytes << 3) + remainder; + index += static_cast(bytes << 3) + remainder; return result; } diff --git a/src/storage/buffer/buffer_pool.cpp b/src/storage/buffer/buffer_pool.cpp index 0b427b227306..a7741b268a4d 100644 --- a/src/storage/buffer/buffer_pool.cpp +++ b/src/storage/buffer/buffer_pool.cpp @@ -66,9 +66,14 @@ bool BufferPool::AddToEvictionQueue(shared_ptr &handle) { return false; } -void BufferPool::IncreaseUsedMemory(MemoryTag tag, idx_t size) { - current_memory += size; - memory_usage_per_tag[uint8_t(tag)] += size; +void BufferPool::UpdateUsedMemory(MemoryTag tag, int64_t size) { + if (size < 0) { + current_memory -= UnsafeNumericCast(-size); + memory_usage_per_tag[uint8_t(tag)] -= UnsafeNumericCast(-size); + } else { + current_memory += UnsafeNumericCast(size); + memory_usage_per_tag[uint8_t(tag)] += UnsafeNumericCast(size); + } } idx_t BufferPool::GetUsedMemory() const { diff --git a/src/storage/buffer/buffer_pool_reservation.cpp b/src/storage/buffer/buffer_pool_reservation.cpp index f22a96ffa58f..783a98079fb6 100644 --- a/src/storage/buffer/buffer_pool_reservation.cpp +++ b/src/storage/buffer/buffer_pool_reservation.cpp @@ -23,8 +23,8 @@ BufferPoolReservation::~BufferPoolReservation() { } void BufferPoolReservation::Resize(idx_t new_size) { - int64_t delta = (int64_t)new_size - size; - pool.IncreaseUsedMemory(tag, delta); + auto delta = UnsafeNumericCast(new_size) - UnsafeNumericCast(size); + pool.UpdateUsedMemory(tag, delta); size = new_size; } diff --git a/src/storage/compression/bitpacking.cpp b/src/storage/compression/bitpacking.cpp index efdd03087bc8..03472f4cd5d3 100644 --- a/src/storage/compression/bitpacking.cpp +++ b/src/storage/compression/bitpacking.cpp @@ -5,6 +5,7 @@ #include "duckdb/common/operator/add.hpp" #include "duckdb/common/operator/multiply.hpp" #include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/function/compression/compression.hpp" #include "duckdb/function/compression_function.hpp" #include "duckdb/main/config.hpp" @@ -804,8 +805,10 @@ void BitpackingScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t T *target_ptr = result_data + result_offset + scanned; for (idx_t i = 0; i < to_scan; i++) { - target_ptr[i] = (static_cast(scan_state.current_group_offset + i) * scan_state.current_constant) + - scan_state.current_frame_of_reference; + T multiplier; + auto success = TryCast::Operation(scan_state.current_group_offset + i, multiplier); + D_ASSERT(success); + target_ptr[i] = (multiplier * scan_state.current_constant) + scan_state.current_frame_of_reference; } scanned += to_scan; @@ -890,16 +893,17 @@ void BitpackingFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t r } if (scan_state.current_group.mode ==

BitpackingMode::CONSTANT_DELTA) { + T multiplier; + auto cast = TryCast::Operation(scan_state.current_group_offset, multiplier); + D_ASSERT(cast); #ifdef DEBUG // overflow check T result; - bool multiply = TryMultiplyOperator::Operation(static_cast(scan_state.current_group_offset), - scan_state.current_constant, result); + bool multiply = TryMultiplyOperator::Operation(multiplier, scan_state.current_constant, result); bool add = TryAddOperator::Operation(result, scan_state.current_frame_of_reference, result); D_ASSERT(multiply && add); #endif - *current_result_ptr = (static_cast(scan_state.current_group_offset) * scan_state.current_constant) + - scan_state.current_frame_of_reference; + *current_result_ptr = (multiplier * scan_state.current_constant) + scan_state.current_frame_of_reference; return; } From e9d973aeb4586feebbad3e708cec30d72b5c387d Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 16:07:36 +0200 Subject: [PATCH 073/201] wip --- .../pythonpkg/scripts/connection_methods.json | 749 ------------------ tools/pythonpkg/src/pyconnection.cpp | 5 +- 2 files changed, 4 insertions(+), 750 deletions(-) diff --git a/tools/pythonpkg/scripts/connection_methods.json b/tools/pythonpkg/scripts/connection_methods.json index dedaef4cb967..c53f2b2d47dd 100644 --- a/tools/pythonpkg/scripts/connection_methods.json +++ b/tools/pythonpkg/scripts/connection_methods.json @@ -103,754 +103,5 @@ } ], "return": "DuckDBPyConnection" - }, - - { - "name": ["sqltype", "dtype", "type"], - "function": "Type", - "docs": "Create a type object by parsing the 'type_str' string", - "args": [ - { - "name": "type_str" - } - ], - "return": - }, - { - "name": "array_type", - "function": "ArrayType", - "docs": "Create an array type object of 'type'", - "args": [ - - ] - { - "name": "type" - .none(false) - }, - { - "name": "size" - .none(false - }, - }, - { - "name": "list_type", - "function": "ListType", - "docs": "Create a list type object of 'type'", - "args": [ - - ] - { - "name": "type" - .none(false - }, - }, - { - "name": "union_type", - "function": "UnionType", - "docs": "Create a union type object from 'members'", - "args": [ - - ] - { - "name": "members" - .none(false) - }, - }, - { - "name": "string_type", - "function": "StringType", - "docs": "Create a string type with an optional collation", - "args": [ - - ] - { - "name": "collation" - = string() - }, - }, - { - "name": "enum_type", - "function": "EnumType", - "docs": "Create an enum type of underlying 'type', consisting of the list of 'values'", - "args": [ - - ] - { - "name": "name" - }, - { - "name": "type" - }, - { - "name": "values" - }, - }, - { - "name": "decimal_type", - "function": "DecimalType", - "docs": "Create a decimal type with 'width' and 'scale'", - "args": [ - - ] - { - "name": "width" - }, - { - "name": "scale" - }, - }, - { - "name": ["struct_type", "row_type"], - "function": "StructType", - "docs": "Create a struct type object from 'fields'", - { - "name": "fields" - }, - }, - { - "name": "map_type", - "function": "MapType", - "docs": "Create a map type object from 'key_type' and 'value_type'", - "args": [ - - ] - { - "name": "key" - .none(false) - }, - { - "name": "value" - .none(false) - }, - }, - { - "name": "duplicate", - "function": "Cursor", - "docs": "Create a duplicate of the current connection", - }, - { - "name": "execute", - "function": "Execute", - "docs": "Execute the given SQL query, optionally using prepared statements with parameters set", - "args": [ - - ] - { - "name": "query" - }, - { - "name": "parameters" - = 
py::none() - }, - { - "name": "multiple_parameter_sets" - = false - }, - }, - { - "name": "executemany", - "function": "ExecuteMany", - "docs": "Execute the given prepared statement multiple times using the list of parameter sets in parameters", - "args": [ - - ] - { - "name": "query" - }, - { - "name": "parameters" - = py::none() - }, - }, - { - "name": "close", - "function": "Close", - "docs": "Close the connection" - }, - { - "name": "interrupt", - "function": "Interrupt", - "docs": "Interrupt pending operations" - }, - { - "name": "fetchone", - "function": "FetchOne", - "docs": "Fetch a single row from a result following execute" - }, - { - "name": "fetchmany", - "function": "FetchMany", - "docs": "Fetch the next set of rows from a result following execute" - { - "name": "size" - = 1 - }, - }, - { - "name": "fetchall", - "function": "FetchAll", - "docs": "Fetch all rows from a result following execute" - }, - { - "name": "fetchnumpy", - "function": "FetchNumpy", - "docs": "Fetch a result as list of NumPy arrays following execute" - }, - { - "name": "fetchdf", - "function": "FetchDF", - "docs": "Fetch a result as DataFrame following execute()", - py::kw_only(), - { - "name": "date_as_object" - = false - }, - }, - { - "name": "fetch_df", - "function": "FetchDF", - "docs": "Fetch a result as DataFrame following execute()", - py::kw_only(), - { - "name": "date_as_object" - = false - }, - }, - { - "name": "fetch_df_chunk", - "function": "FetchDFChunk", - "docs": "Fetch a chunk of the result as Data.Frame following execute()", - "args": [ - - ] - { - "name": "vectors_per_chunk" - = 1 - }, - py::kw_only(), - { - "name": "date_as_object" - = false - }, - }, - { - "name": "df", - "function": "FetchDF", - "docs": "Fetch a result as DataFrame following execute()", - py::kw_only(), - { - "name": "date_as_object" - = false - }, - }, - { - "name": "pl", - "function": "FetchPolars", - "docs": "Fetch a result as Polars DataFrame following execute()", - "args": [ - - ] - { - "name": "rows_per_batch" - = 1000000 - }, - }, - { - "name": "fetch_arrow_table", - "function": "FetchArrow", - "docs": "Fetch a result as Arrow table following execute()", - "args": [ - - ] - { - "name": "rows_per_batch" - = 1000000 - }, - }, - { - "name": "fetch_record_batch", - "function": "FetchRecordBatchReader", - "docs": "Fetch an Arrow RecordBatchReader following execute()", - "args": [ - - ] - { - "name": "rows_per_batch" - = 1000000 - }, - }, - { - "name": "arrow", - "function": "FetchArrow", - "docs": "Fetch a result as Arrow table following execute()", - "args": [ - - ] - { - "name": "rows_per_batch" - = 1000000 - }, - }, - { - "name": "torch", - "function": "FetchPyTorch", - "docs": "Fetch a result as dict of PyTorch Tensors following execute()" - - }, - { - "name": "tf", - "function": "FetchTF", - "docs": "Fetch a result as dict of TensorFlow Tensors following execute()" - - }, - { - "name": "begin", - "function": "Begin", - "docs": "Start a new transaction" - - }, - { - "name": "commit", - "function": "Commit", - "docs": "Commit changes performed within a transaction" - }, - { - "name": "rollback", - "function": "Rollback", - "docs": "Roll back changes performed within a transaction" - }, - { - "name": "append", - "function": "Append", - "docs": "Append the passed DataFrame to the named table", - "args": [ - - ] - { - "name": "table_name" - }, - { - "name": "df" - }, - py::kw_only(), - { - "name": "by_name" - = false - }, - }, - { - "name": "register", - "function": "RegisterPythonObject", - "docs": "Register the 
passed Python Object value for querying with a view", - "args": [ - - ] - { - "name": "view_name" - }, - { - "name": "python_object" - }, - }, - { - "name": "unregister", - "function": "UnregisterPythonObject", - "docs": "Unregister the view name", - "args": [ - - ] - { - "name": "view_name" - }, - }, - { - "name": "table", - "function": "Table", - "docs": "Create a relation object for the name'd table", - "args": [ - - ] - { - "name": "table_name" - }, - }, - { - "name": "view", - "function": "View", - "docs": "Create a relation object for the name'd view", - "args": [ - - ] - { - "name": "view_name" - }, - }, - { - "name": "values", - "function": "Values", - "docs": "Create a relation object from the passed values", - "args": [ - - ] - { - "name": "values" - }, - }, - { - "name": "table_function", - "function": "TableFunction", - "docs": "Create a relation object from the name'd table function with given parameters", - "args": [ - - ] - { - "name": "name" - }, - { - "name": "parameters" - = py::none() - }, - }, - { - "name": "read_json", - "function": "ReadJSON", - "docs": "Create a relation object from the JSON file in 'name'", - "args": [ - - ] - { - "name": "name" - }, - py::kw_only(), - { - "name": "columns" - = py::none() - }, - { - "name": "sample_size" - = py::none() - }, - { - "name": "maximum_depth" - = py::none() - }, - { - "name": "records" - = py::none() - }, - { - "name": "format" - = py::none() - }, - }, - { - "name": "extract_statements", - "function": "ExtractStatements", - "docs": "Parse the query string and extract the Statement object(s) produced", - "args": [ - - ] - { - "name": "query" - }, - }, - - { - "name": ["sql", "query", "from_query"], - "function": "RunQuery", - "docs": "Run a SQL query. If it is a SELECT statement, create a relation object from the given SQL query, otherwise run the query as-is.", - "args": [ - - ] - { - "name": "query" - }, - py::kw_only(), - { - "name": "alias" - = "" - }, - { - "name": "params" - = py::none( - }, - }, - - { - "name": ["read_csv", "from_csv_auto"], - "function": "ReadCSV", - "docs": "Create a relation object from the CSV file in 'name'", - "args": [ - - ] - { - "name": "name" - }, - py::kw_only(), - { - "name": "header" - = py::none() - }, - { - "name": "compression" - = py::none() - }, - { - "name": "sep" - = py::none() - }, - { - "name": "delimiter" - = py::none() - }, - { - "name": "dtype" - = py::none() - }, - { - "name": "na_values" - = py::none() - }, - { - "name": "skiprows" - = py::none() - }, - { - "name": "quotechar" - = py::none() - }, - { - "name": "escapechar" - = py::none() - }, - { - "name": "encoding" - = py::none() - }, - { - "name": "parallel" - = py::none() - }, - { - "name": "date_format" - = py::none() - }, - { - "name": "timestamp_format" - = py::none() - }, - { - "name": "sample_size" - = py::none() - }, - { - "name": "all_varchar" - = py::none() - }, - { - "name": "normalize_names" - = py::none() - }, - { - "name": "filename" - = py::none() - }, - { - "name": "null_padding" - = py::none() - }, - { - "name": "names" - = py::none( - }, - }, - { - "name": "from_df", - "function": "FromDF", - "docs": "Create a relation object from the Data.Frame in df", - "args": [ - - ] - { - "name": "df" - = py::none() - }, - }, - { - "name": "from_arrow", - "function": "FromArrow", - "docs": "Create a relation object from an Arrow object", - "args": [ - - ] - { - "name": "arrow_object" - }, - }, - - { - "name": ["from_parquet", "read_parquet"], - "function": "FromParquet", - "docs": "Create a relation object from the 
Parquet files in file_glob", - "args": [ - - ] - { - "name": "file_glob" - }, - { - "name": "binary_as_string" - = false - }, - py::kw_only(), - { - "name": "file_row_number" - = false - }, - { - "name": "filename" - = false - }, - { - "name": "hive_partitioning" - = false - }, - { - "name": "union_by_name" - = false - }, - { - "name": "compression" - = py::none( - }, - }, - { - "name": ["from_parquet", "read_parquet"], - "function": "FromParquets", - "docs": "Create a relation object from the Parquet files in file_globs", - "args": [ - - ] - { - "name": "file_globs" - }, - { - "name": "binary_as_string" - = false - }, - py::kw_only(), - { - "name": "file_row_number" - = false - }, - { - "name": "filename" - = false - }, - { - "name": "hive_partitioning" - = false - }, - { - "name": "union_by_name" - = false - }, - { - "name": "compression" - = py::none( - }, - }, - { - "name": "from_substrait", - "function": "FromSubstrait", - "docs": "Create a query object from protobuf plan", - "args": [ - - ] - { - "name": "proto" - }, - }, - { - "name": "get_substrait", - "function": "GetSubstrait", - "docs": "Serialize a query to protobuf", - "args": [ - - ] - { - "name": "query" - }, - py::kw_only(), - { - "name": "enable_optimizer" - = true - }, - }, - { - "name": "get_substrait_json", - "function": "GetSubstraitJSON", - "docs": "Serialize a query to protobuf on the JSON format", - "args": [ - - ] - { - "name": "query" - }, - py::kw_only(), - { - "name": "enable_optimizer" - = true - }, - }, - { - "name": "from_substrait_json", - "function": "FromSubstraitJSON", - "docs": "Create a query object from a JSON protobuf plan", - "args": [ - - ] - { - "name": "json" - }, - }, - { - "name": "get_table_names", - "function": "GetTableNames", - "docs": "Extract the required table names from a query", - "args": [ - - ] - { - "name": "query" - }, - }, - - { - "name": "install_extension", - "function": "InstallExtension", - "docs": "Install an extension by name", - "args": [ - - ] - { - "name": "extension" - }, - py::kw_only(), - { - "name": "force_install" - = false - }, - }, - { - "name": "load_extension", - "function": "LoadExtension", - "docs": "Load an installed extension", - "args": [ - - ] - { - "name": "extension" - }, } ] diff --git a/tools/pythonpkg/src/pyconnection.cpp b/tools/pythonpkg/src/pyconnection.cpp index f4cd9100a142..2de161bc3d25 100644 --- a/tools/pythonpkg/src/pyconnection.cpp +++ b/tools/pythonpkg/src/pyconnection.cpp @@ -119,6 +119,9 @@ py::object ArrowTableFromDataframe(const py::object &df) { } } +// NOTE: this function is generated by tools/pythonpkg/scripts/generate_connection_methods.py. +// Do not edit this function manually, your changes will be overwritten! 
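To make "generated" concrete: the script introduced over the next few commits turns each JSON method spec into one pybind11 m.def(...) line. A reduced sketch of that mapping, assuming the spec shape used in connection_methods.json (the real generator also handles keyword-only arguments, .none(false) annotations, and default-value remapping):

import json

spec = json.loads("""
{
  "name": "fetchmany",
  "function": "FetchMany",
  "docs": "Fetch the next set of rows from a result following execute",
  "args": [{"name": "size", "default": "1"}]
}
""")

def to_pybind_def(method):
    # One m.def entry: name, member-function pointer, docstring, then the args.
    parts = [
        f'm.def("{method["name"]}"',
        f'&DuckDBPyConnection::{method["function"]}',
        f'"{method["docs"]}"',
    ]
    for arg in method.get("args", []):
        part = f'py::arg("{arg["name"]}")'
        if "default" in arg:
            part += f' = {arg["default"]}'
        parts.append(part)
    return ", ".join(parts) + ");"

print(to_pybind_def(spec))
# m.def("fetchmany", &DuckDBPyConnection::FetchMany,
#       "Fetch the next set of rows from a result following execute", py::arg("size") = 1);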
+ static void InitializeConnectionMethods(py::class_> &m) { m.def("cursor", &DuckDBPyConnection::Cursor, "Create a duplicate of the current connection") .def("register_filesystem", &DuckDBPyConnection::RegisterFilesystem, "Register a fsspec compliant filesystem", @@ -266,7 +269,7 @@ static void InitializeConnectionMethods(py::class_GetFileSystem(); From 073fd1b184b8dbb7358fd1f139e5540f3d63d90d Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 16:11:17 +0200 Subject: [PATCH 074/201] start of script to generate the definitions inside the DuckDBPyConnection class --- .../scripts/generate_connection_methods.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tools/pythonpkg/scripts/generate_connection_methods.py diff --git a/tools/pythonpkg/scripts/generate_connection_methods.py b/tools/pythonpkg/scripts/generate_connection_methods.py new file mode 100644 index 000000000000..84bdc5a0ae59 --- /dev/null +++ b/tools/pythonpkg/scripts/generate_connection_methods.py @@ -0,0 +1,48 @@ +import os +import json + +os.chdir(os.path.dirname(__file__)) + +JSON_PATH = os.path.join("connection_methods.json") +PYCONNECTION_SOURCE = os.path.join("..", "src", "pyconnection.cpp") + +INITIALIZE_METHOD = ( + "static void InitializeConnectionMethods(py::class_> &m) {" +) +END_MARKER = "} // END_OF_CONNECTION_METHODS" + +# Read the PYCONNECTION_SOURCE file +with open(PYCONNECTION_SOURCE, 'r') as source_file: + source_code = source_file.readlines() + +# Locate the InitializeConnectionMethods function in it +start_index = -1 +end_index = -1 +for i, line in enumerate(source_code): + if line.startswith(INITIALIZE_METHOD): + start_index = i + elif line.startswith(END_MARKER): + end_index = i + +if start_index == -1 or end_index == -1: + raise ValueError("Couldn't find start or end marker in source file") + +start_section = source_code[: start_index + 1] +end_section = source_code[end_index:] + +# Generate the definition code from the json +# Read the JSON file +with open(JSON_PATH, 'r') as json_file: + connection_methods = json.load(json_file) + +regenerated_method = [] +regenerated_method.extend(['', '']) + +with_newlines = [x + '\n' for x in regenerated_method] +# Recreate the file content by concatenating all the pieces together + +new_content = start_section + with_newlines + end_section + +# Write out the modified PYCONNECTION_SOURCE file +with open(PYCONNECTION_SOURCE, 'w') as source_file: + source_file.write("".join(new_content)) From bbf679e0dc92729042c42c3dce9655fb8c2bea6f Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 16:34:44 +0200 Subject: [PATCH 075/201] generate the c++ (pybind) definition code --- .../scripts/generate_connection_methods.py | 53 +++++++++++++++++-- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/tools/pythonpkg/scripts/generate_connection_methods.py b/tools/pythonpkg/scripts/generate_connection_methods.py index 84bdc5a0ae59..1882e6e0afa6 100644 --- a/tools/pythonpkg/scripts/generate_connection_methods.py +++ b/tools/pythonpkg/scripts/generate_connection_methods.py @@ -29,16 +29,61 @@ start_section = source_code[: start_index + 1] end_section = source_code[end_index:] +# ---- Generate the definition code from the json ---- -# Generate the definition code from the json # Read the JSON file with open(JSON_PATH, 'r') as json_file: connection_methods = json.load(json_file) -regenerated_method = [] -regenerated_method.extend(['', '']) +body = [] +body.extend(['', '']) -with_newlines = [x + '\n' for x in regenerated_method] 
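The marker-based splice in the script above, reduced to its essence; the marker strings here are hypothetical placeholders, not the ones the script uses:

def splice(lines, start_marker, end_marker, new_body):
    # Keep everything up to and including the start marker and everything from
    # the end marker on; only the region in between is regenerated.
    start = next(i for i, l in enumerate(lines) if l.startswith(start_marker))
    end = next(i for i, l in enumerate(lines) if l.startswith(end_marker))
    return lines[: start + 1] + new_body + lines[end:]

source = ["header", "BEGIN_GENERATED", "old body", "END_GENERATED", "footer"]
print(splice(source, "BEGIN_GENERATED", "END_GENERATED", ["new body"]))
# ['header', 'BEGIN_GENERATED', 'new body', 'END_GENERATED', 'footer']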
+DEFAULT_ARGUMENT_MAP = {'True': 'true', 'False': 'false', 'None': 'py::none()'} + + +def map_default(val): + if val in DEFAULT_ARGUMENT_MAP: + return DEFAULT_ARGUMENT_MAP[val] + return val + + +for conn in connection_methods: + definition = f"m.def(\"{conn['name']}\"" + definition += ", " + definition += f"""&DuckDBPyConnection::{conn['function']}""" + definition += ", " + definition += f"\"{conn['docs']}\"" + if 'args' in conn: + definition += ", " + arguments = [] + for arg in conn['args']: + argument = f"py::arg(\"{arg['name']}\")" + # Add the default argument if present + if 'default' in arg: + default = map_default(arg['default']) + argument += f" = {default}" + arguments.append(argument) + definition += ', '.join(arguments) + if 'kwargs' in conn: + definition += ", " + definition += "py::kw_only(), " + keyword_arguments = [] + for kwarg in conn['kwargs']: + keyword_argument = f"py::arg(\"{kwarg['name']}\")" + # Add the default argument if present + if 'default' in arg: + default = map_default(arg['default']) + keyword_argument += f" = {default}" + keyword_arguments.append(keyword_argument) + definition += ', '.join(keyword_arguments) + definition += ");" + body.append(definition) + +# ---- End of generation code ---- + +with_newlines = [x + '\n' for x in body] +print(''.join(with_newlines)) +exit() # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section From c6800bd85c5927dded2719d401c30c15cf6aebed Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 16:58:08 +0200 Subject: [PATCH 076/201] extend the list of functions --- .../pythonpkg/scripts/connection_methods.json | 737 ++++++++++++++++++ .../scripts/generate_connection_methods.py | 63 +- 2 files changed, 772 insertions(+), 28 deletions(-) diff --git a/tools/pythonpkg/scripts/connection_methods.json b/tools/pythonpkg/scripts/connection_methods.json index c53f2b2d47dd..39e11bb0ba3c 100644 --- a/tools/pythonpkg/scripts/connection_methods.json +++ b/tools/pythonpkg/scripts/connection_methods.json @@ -103,5 +103,742 @@ } ], "return": "DuckDBPyConnection" + }, + { + "name": ["sqltype", "dtype", "type"], + "function": "Type", + "docs": "Create a type object by parsing the 'type_str' string", + "args": [ + { + "name": "type_str", + "type": "str" + } + ], + "return": "DuckDBPyType" + }, + { + "name": "array_type", + "function": "ArrayType", + "docs": "Create an array type object of 'type'", + "args": [ + { + "name": "type", + "type": "DuckDBPyType", + "allow_none": false + }, + { + "name": "size", + "type": "int" + } + ] + }, + { + "name": "list_type", + "function": "ListType", + "docs": "Create a list type object of 'type'", + "args": [ + { + "name": "type", + "type": "DuckDBPyType", + "allow_none": false + } + ] + }, + { + "name": "union_type", + "function": "UnionType", + "docs": "Create a union type object from 'members'", + "args": [ + { + "name": "members", + "type": "DuckDBPyType", + "allow_none": false + } + ] + }, + { + "name": "string_type", + "function": "StringType", + "docs": "Create a string type with an optional collation", + "args": [ + { + "name": "collation", + "type": "str", + "default": "\"\"" + } + ] + }, + { + "name": "enum_type", + "function": "EnumType", + "docs": "Create an enum type of underlying 'type', consisting of the list of 'values'", + "args": [ + { + "name": "name", + "type": "str" + }, + { + "name": "type", + "type": "DuckDBPyType" + }, + { + "name": "values", + "type": "List[Any]" + } + ] + }, + { + "name": "decimal_type", + 
"function": "DecimalType", + "docs": "Create a decimal type with 'width' and 'scale'", + "args": [ + { + "name": "width", + "type": "int" + }, + { + "name": "scale", + "type": "int" + } + ] + }, + { + "name": ["struct_type", "row_type"], + "function": "StructType", + "docs": "Create a struct type object from 'fields'", + "args": [ + { + "name": "fields", + "type": "Union[Dict[str, DuckDBPyType], List[str]]" + } + ] + }, + { + "name": "map_type", + "function": "MapType", + "docs": "Create a map type object from 'key_type' and 'value_type'", + "args": [ + { + "name": "key", + "allow_none": false + }, + { + "name": "value", + "allow_none": false + } + ] + }, + { + "name": "duplicate", + "function": "Cursor", + "docs": "Create a duplicate of the current connection" + }, + { + "name": "execute", + "function": "Execute", + "docs": "Execute the given SQL query, optionally using prepared statements with parameters set", + "args": [ + { + "name": "query" + }, + { + "name": "parameters", + "default": "None" + }, + { + "name": "multiple_parameter_sets", + "default": "false" + } + ] + }, + { + "name": "executemany", + "function": "ExecuteMany", + "docs": "Execute the given prepared statement multiple times using the list of parameter sets in parameters", + "args": [ + { + "name": "query" + }, + { + "name": "parameters", + "default": "None" + } + ] + }, + { + "name": "close", + "function": "Close", + "docs": "Close the connection" + }, + { + "name": "interrupt", + "function": "Interrupt", + "docs": "Interrupt pending operations" + }, + { + "name": "fetchone", + "function": "FetchOne", + "docs": "Fetch a single row from a result following execute" + }, + { + "name": "fetchmany", + "function": "FetchMany", + "docs": "Fetch the next set of rows from a result following execute", + "args": [ + { + "name": "size", + "default": "1" + } + ] + }, + { + "name": "fetchall", + "function": "FetchAll", + "docs": "Fetch all rows from a result following execute" + }, + { + "name": "fetchnumpy", + "function": "FetchNumpy", + "docs": "Fetch a result as list of NumPy arrays following execute" + }, + { + "name": "fetchdf", + "function": "FetchDF", + "docs": "Fetch a result as DataFrame following execute()", + "kwargs": [ + { + "name": "date_as_object", + "default": "false" + } + ] + }, + { + "name": "fetch_df", + "function": "FetchDF", + "docs": "Fetch a result as DataFrame following execute()", + "kwargs": [ + { + "name": "date_as_object", + "default": "false" + } + ] + }, + { + "name": "fetch_df_chunk", + "function": "FetchDFChunk", + "docs": "Fetch a chunk of the result as Data.Frame following execute()", + "args": [ + { + "name": "vectors_per_chunk", + "default": "1" + } + ], + "kwargs": [ + { + "name": "date_as_object", + "default": "false" + } + ] + }, + { + "name": "df", + "function": "FetchDF", + "docs": "Fetch a result as DataFrame following execute()", + "kwargs": [ + { + "name": "date_as_object", + "default": "false" + } + ] + }, + { + "name": "pl", + "function": "FetchPolars", + "docs": "Fetch a result as Polars DataFrame following execute()", + "args": [ + { + "name": "rows_per_batch", + "default": "1000000" + } + ] + }, + { + "name": "fetch_arrow_table", + "function": "FetchArrow", + "docs": "Fetch a result as Arrow table following execute()", + "args": [ + { + "name": "rows_per_batch", + "default": "1000000" + } + ] + }, + { + "name": "fetch_record_batch", + "function": "FetchRecordBatchReader", + "docs": "Fetch an Arrow RecordBatchReader following execute()", + "args": [ + { + "name": "rows_per_batch", + 
"default": "1000000" + } + ] + }, + { + "name": "arrow", + "function": "FetchArrow", + "docs": "Fetch a result as Arrow table following execute()", + "args": [ + { + "name": "rows_per_batch", + "default": "1000000" + } + ] + }, + { + "name": "torch", + "function": "FetchPyTorch", + "docs": "Fetch a result as dict of PyTorch Tensors following execute()" + + }, + { + "name": "tf", + "function": "FetchTF", + "docs": "Fetch a result as dict of TensorFlow Tensors following execute()" + + }, + { + "name": "begin", + "function": "Begin", + "docs": "Start a new transaction" + + }, + { + "name": "commit", + "function": "Commit", + "docs": "Commit changes performed within a transaction" + }, + { + "name": "rollback", + "function": "Rollback", + "docs": "Roll back changes performed within a transaction" + }, + { + "name": "append", + "function": "Append", + "docs": "Append the passed DataFrame to the named table", + "args": [ + { + "name": "table_name" + }, + { + "name": "df" + } + ], + "kwargs": [ + { + "name": "by_name", + "default": "false" + } + ] + }, + { + "name": "register", + "function": "RegisterPythonObject", + "docs": "Register the passed Python Object value for querying with a view", + "args": [ + { + "name": "view_name" + }, + { + "name": "python_object" + } + ] + }, + { + "name": "unregister", + "function": "UnregisterPythonObject", + "docs": "Unregister the view name", + "args": [ + { + "name": "view_name" + } + ] + }, + { + "name": "table", + "function": "Table", + "docs": "Create a relation object for the name'd table", + "args": [ + { + "name": "table_name" + } + ] + }, + { + "name": "view", + "function": "View", + "docs": "Create a relation object for the name'd view", + "args": [ + { + "name": "view_name" + } + ] + }, + { + "name": "values", + "function": "Values", + "docs": "Create a relation object from the passed values", + "args": [ + { + "name": "values" + } + ] + }, + { + "name": "table_function", + "function": "TableFunction", + "docs": "Create a relation object from the name'd table function with given parameters", + "args": [ + { + "name": "name" + }, + { + "name": "parameters", + "default": "None" + } + ] + }, + { + "name": "read_json", + "function": "ReadJSON", + "docs": "Create a relation object from the JSON file in 'name'", + "args": [ + { + "name": "name" + } + ], + "kwargs": [ + { + "name": "columns", + "default": "None" + }, + { + "name": "sample_size", + "default": "None" + }, + { + "name": "maximum_depth", + "default": "None" + }, + { + "name": "records", + "default": "None" + }, + { + "name": "format", + "default": "None" + } + ] + }, + { + "name": "extract_statements", + "function": "ExtractStatements", + "docs": "Parse the query string and extract the Statement object(s) produced", + "args": [ + { + "name": "query" + } + ] + }, + { + "name": ["sql", "query", "from_query"], + "function": "RunQuery", + "docs": "Run a SQL query. 
If it is a SELECT statement, create a relation object from the given SQL query, otherwise run the query as-is.", + "args": [ + { + "name": "query" + } + ], + "kwargs": [ + { + "name": "alias", + "default": "" + }, + { + "name": "params", + "default": "None" + } + ] + }, + + { + "name": ["read_csv", "from_csv_auto"], + "function": "ReadCSV", + "docs": "Create a relation object from the CSV file in 'name'", + "args": [ + { + "name": "name" + } + ], + "kwargs": [ + { + "name": "header", + "default": "None" + }, + { + "name": "compression", + "default": "None" + }, + { + "name": "sep", + "default": "None" + }, + { + "name": "delimiter", + "default": "None" + }, + { + "name": "dtype", + "default": "None" + }, + { + "name": "na_values", + "default": "None" + }, + { + "name": "skiprows", + "default": "None" + }, + { + "name": "quotechar", + "default": "None" + }, + { + "name": "escapechar", + "default": "None" + }, + { + "name": "encoding", + "default": "None" + }, + { + "name": "parallel", + "default": "None" + }, + { + "name": "date_format", + "default": "None" + }, + { + "name": "timestamp_format", + "default": "None" + }, + { + "name": "sample_size", + "default": "None" + }, + { + "name": "all_varchar", + "default": "None" + }, + { + "name": "normalize_names", + "default": "None" + }, + { + "name": "filename", + "default": "None" + }, + { + "name": "null_padding", + "default": "None" + }, + { + "name": "names", + "default": "None" + } + ] + }, + { + "name": "from_df", + "function": "FromDF", + "docs": "Create a relation object from the Data.Frame in df", + "args": [ + { + "name": "df", + "default": "None" + } + ] + }, + { + "name": "from_arrow", + "function": "FromArrow", + "docs": "Create a relation object from an Arrow object", + "args": [ + { + "name": "arrow_object" + } + ] + }, + { + "name": ["from_parquet", "read_parquet"], + "function": "FromParquet", + "docs": "Create a relation object from the Parquet files in file_glob", + "args": [ + { + "name": "file_glob" + }, + { + "name": "binary_as_string", + "default": "false" + } + ], + "kwargs": [ + { + "name": "file_row_number", + "default": "false" + }, + { + "name": "filename", + "default": "false" + }, + { + "name": "hive_partitioning", + "default": "false" + }, + { + "name": "union_by_name", + "default": "false" + }, + { + "name": "compression", + "default": "None" + } + ] + }, + { + "name": ["from_parquet", "read_parquet"], + "function": "FromParquets", + "docs": "Create a relation object from the Parquet files in file_globs", + "args": [ + { + "name": "file_globs" + }, + { + "name": "binary_as_string", + "default": "false" + } + ], + "kwargs": [ + { + "name": "file_row_number", + "default": "false" + }, + { + "name": "filename", + "default": "false" + }, + { + "name": "hive_partitioning", + "default": "false" + }, + { + "name": "union_by_name", + "default": "false" + }, + { + "name": "compression", + "default": "None" + } + ] + }, + { + "name": "from_substrait", + "function": "FromSubstrait", + "docs": "Create a query object from protobuf plan", + "args": [ + { + "name": "proto" + } + ] + }, + { + "name": "get_substrait", + "function": "GetSubstrait", + "docs": "Serialize a query to protobuf", + "args": [ + { + "name": "query" + } + ], + "kwargs": [ + { + "name": "enable_optimizer", + "default": "True" + } + ] + }, + { + "name": "get_substrait_json", + "function": "GetSubstraitJSON", + "docs": "Serialize a query to protobuf on the JSON format", + "args": [ + { + "name": "query" + } + ], + "kwargs": [ + { + "name": "enable_optimizer", 
+ "default": "True" + } + ] + }, + { + "name": "from_substrait_json", + "function": "FromSubstraitJSON", + "docs": "Create a query object from a JSON protobuf plan", + "args": [ + { + "name": "json" + } + ] + }, + { + "name": "get_table_names", + "function": "GetTableNames", + "docs": "Extract the required table names from a query", + "args": [ + { + "name": "query" + } + ] + }, + { + "name": "install_extension", + "function": "InstallExtension", + "docs": "Install an extension by name", + "args": [ + { + "name": "extension" + } + ], + "kwargs": [ + { + "name": "force_install", + "default": "false" + } + ] + }, + { + "name": "load_extension", + "function": "LoadExtension", + "docs": "Load an installed extension", + "args": [ + { + "name": "extension" + } + ] } ] diff --git a/tools/pythonpkg/scripts/generate_connection_methods.py b/tools/pythonpkg/scripts/generate_connection_methods.py index 1882e6e0afa6..a251c90d4cf9 100644 --- a/tools/pythonpkg/scripts/generate_connection_methods.py +++ b/tools/pythonpkg/scripts/generate_connection_methods.py @@ -48,36 +48,43 @@ def map_default(val): for conn in connection_methods: - definition = f"m.def(\"{conn['name']}\"" - definition += ", " - definition += f"""&DuckDBPyConnection::{conn['function']}""" - definition += ", " - definition += f"\"{conn['docs']}\"" - if 'args' in conn: + if isinstance(conn['name'], list): + names = conn['name'] + else: + names = [conn['name']] + for name in names: + definition = f"m.def(\"{name}\"" definition += ", " - arguments = [] - for arg in conn['args']: - argument = f"py::arg(\"{arg['name']}\")" - # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) - argument += f" = {default}" - arguments.append(argument) - definition += ', '.join(arguments) - if 'kwargs' in conn: + definition += f"""&DuckDBPyConnection::{conn['function']}""" definition += ", " - definition += "py::kw_only(), " - keyword_arguments = [] - for kwarg in conn['kwargs']: - keyword_argument = f"py::arg(\"{kwarg['name']}\")" - # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) - keyword_argument += f" = {default}" - keyword_arguments.append(keyword_argument) - definition += ', '.join(keyword_arguments) - definition += ");" - body.append(definition) + definition += f"\"{conn['docs']}\"" + if 'args' in conn: + definition += ", " + arguments = [] + for arg in conn['args']: + argument = f"py::arg(\"{arg['name']}\")" + # TODO: add '.none(false)' if required (add 'allow_none' to the JSON) + # Add the default argument if present + if 'default' in arg: + default = map_default(arg['default']) + argument += f" = {default}" + arguments.append(argument) + definition += ', '.join(arguments) + if 'kwargs' in conn: + definition += ", " + definition += "py::kw_only(), " + keyword_arguments = [] + for kwarg in conn['kwargs']: + keyword_argument = f"py::arg(\"{kwarg['name']}\")" + # TODO: add '.none(false)' if required (add 'allow_none' to the JSON) + # Add the default argument if present + if 'default' in arg: + default = map_default(arg['default']) + keyword_argument += f" = {default}" + keyword_arguments.append(keyword_argument) + definition += ', '.join(keyword_arguments) + definition += ");" + body.append(definition) # ---- End of generation code ---- From eb87a36c1e38b6f3e7dcba85d25a17ac67a12d63 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 17:03:08 +0200 Subject: [PATCH 077/201] more default remapping --- 
tools/pythonpkg/scripts/generate_connection_methods.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tools/pythonpkg/scripts/generate_connection_methods.py b/tools/pythonpkg/scripts/generate_connection_methods.py index a251c90d4cf9..c4c73159337a 100644 --- a/tools/pythonpkg/scripts/generate_connection_methods.py +++ b/tools/pythonpkg/scripts/generate_connection_methods.py @@ -36,9 +36,13 @@ connection_methods = json.load(json_file) body = [] -body.extend(['', '']) -DEFAULT_ARGUMENT_MAP = {'True': 'true', 'False': 'false', 'None': 'py::none()'} +DEFAULT_ARGUMENT_MAP = { + 'True': 'true', + 'False': 'false', + 'None': 'py::none()', + 'PythonUDFType.NATIVE': 'PythonUDFType::NATIVE', +} def map_default(val): From 78254624b8f7515905736d0b422705b28504ccee Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 17:10:20 +0200 Subject: [PATCH 078/201] generate the isnone(false) --- .../pythonpkg/scripts/connection_methods.json | 6 +- .../scripts/generate_connection_methods.py | 74 +++++++++---------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/tools/pythonpkg/scripts/connection_methods.json b/tools/pythonpkg/scripts/connection_methods.json index 39e11bb0ba3c..2fc427ca039d 100644 --- a/tools/pythonpkg/scripts/connection_methods.json +++ b/tools/pythonpkg/scripts/connection_methods.json @@ -468,7 +468,7 @@ { "name": "table", "function": "Table", - "docs": "Create a relation object for the name'd table", + "docs": "Create a relation object for the named table", "args": [ { "name": "table_name" @@ -478,7 +478,7 @@ { "name": "view", "function": "View", - "docs": "Create a relation object for the name'd view", + "docs": "Create a relation object for the named view", "args": [ { "name": "view_name" @@ -498,7 +498,7 @@ { "name": "table_function", "function": "TableFunction", - "docs": "Create a relation object from the name'd table function with given parameters", + "docs": "Create a relation object from the named table function with given parameters", "args": [ { "name": "name" diff --git a/tools/pythonpkg/scripts/generate_connection_methods.py b/tools/pythonpkg/scripts/generate_connection_methods.py index c4c73159337a..f28470986d59 100644 --- a/tools/pythonpkg/scripts/generate_connection_methods.py +++ b/tools/pythonpkg/scripts/generate_connection_methods.py @@ -50,45 +50,45 @@ def map_default(val): return DEFAULT_ARGUMENT_MAP[val] return val - -for conn in connection_methods: - if isinstance(conn['name'], list): - names = conn['name'] - else: - names = [conn['name']] - for name in names: - definition = f"m.def(\"{name}\"" +def create_arguments(arguments) -> list: + result = [] + for arg in arguments: + argument = f"py::arg(\"{arg['name']}\")" + if 'allow_none' in arg: + value = str(arg['allow_none']).lower() + argument += f".none({value})" + # Add the default argument if present + if 'default' in arg: + default = map_default(arg['default']) + argument += f" = {default}" + result.append(argument) + return result + +def create_definition(name, method) -> str: + definition = f"m.def(\"{name}\"" + definition += ", " + definition += f"""&DuckDBPyConnection::{method['function']}""" + definition += ", " + definition += f"\"{method['docs']}\"" + if 'args' in method: definition += ", " - definition += f"""&DuckDBPyConnection::{conn['function']}""" + arguments = create_arguments(method['args']) + definition += ', '.join(arguments) + if 'kwargs' in method: definition += ", " - definition += f"\"{conn['docs']}\"" - if 'args' in conn: - definition += ", " - 
arguments = [] - for arg in conn['args']: - argument = f"py::arg(\"{arg['name']}\")" - # TODO: add '.none(false)' if required (add 'allow_none' to the JSON) - # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) - argument += f" = {default}" - arguments.append(argument) - definition += ', '.join(arguments) - if 'kwargs' in conn: - definition += ", " - definition += "py::kw_only(), " - keyword_arguments = [] - for kwarg in conn['kwargs']: - keyword_argument = f"py::arg(\"{kwarg['name']}\")" - # TODO: add '.none(false)' if required (add 'allow_none' to the JSON) - # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) - keyword_argument += f" = {default}" - keyword_arguments.append(keyword_argument) - definition += ', '.join(keyword_arguments) - definition += ");" - body.append(definition) + definition += "py::kw_only(), " + arguments = create_arguments(method['kwargs']) + definition += ', '.join(arguments) + definition += ");" + return definition + +for method in connection_methods: + if isinstance(method['name'], list): + names = method['name'] + else: + names = [method['name']] + for name in names: + body.append(create_definition(name, method)) # ---- End of generation code ---- From 8684102b873c326ae7714c2ece43f2ceea32078e Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 17:22:21 +0200 Subject: [PATCH 079/201] correctly generating the pybind definition code --- .../pythonpkg/scripts/connection_methods.json | 2 +- .../scripts/generate_connection_methods.py | 9 +- tools/pythonpkg/src/pyconnection.cpp | 287 ++++++++++-------- 3 files changed, 160 insertions(+), 138 deletions(-) diff --git a/tools/pythonpkg/scripts/connection_methods.json b/tools/pythonpkg/scripts/connection_methods.json index 2fc427ca039d..f944e17625f0 100644 --- a/tools/pythonpkg/scripts/connection_methods.json +++ b/tools/pythonpkg/scripts/connection_methods.json @@ -563,7 +563,7 @@ "kwargs": [ { "name": "alias", - "default": "" + "default": "\"\"" }, { "name": "params", diff --git a/tools/pythonpkg/scripts/generate_connection_methods.py b/tools/pythonpkg/scripts/generate_connection_methods.py index f28470986d59..efcf13304f0d 100644 --- a/tools/pythonpkg/scripts/generate_connection_methods.py +++ b/tools/pythonpkg/scripts/generate_connection_methods.py @@ -42,6 +42,8 @@ 'False': 'false', 'None': 'py::none()', 'PythonUDFType.NATIVE': 'PythonUDFType::NATIVE', + 'PythonExceptionHandling.DEFAULT': 'PythonExceptionHandling::FORWARD_ERROR', + 'FunctionNullHandling.DEFAULT': 'FunctionNullHandling::DEFAULT_NULL_HANDLING', } @@ -50,6 +52,7 @@ def map_default(val): return DEFAULT_ARGUMENT_MAP[val] return val + def create_arguments(arguments) -> list: result = [] for arg in arguments: @@ -64,6 +67,7 @@ def create_arguments(arguments) -> list: result.append(argument) return result + def create_definition(name, method) -> str: definition = f"m.def(\"{name}\"" definition += ", " @@ -82,6 +86,7 @@ def create_definition(name, method) -> str: definition += ");" return definition + for method in connection_methods: if isinstance(method['name'], list): names = method['name'] @@ -92,9 +97,7 @@ def create_definition(name, method) -> str: # ---- End of generation code ---- -with_newlines = [x + '\n' for x in body] -print(''.join(with_newlines)) -exit() +with_newlines = ['\t' + x + '\n' for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section diff --git 
a/tools/pythonpkg/src/pyconnection.cpp b/tools/pythonpkg/src/pyconnection.cpp index 2de161bc3d25..39cf30451f5d 100644 --- a/tools/pythonpkg/src/pyconnection.cpp +++ b/tools/pythonpkg/src/pyconnection.cpp @@ -123,152 +123,168 @@ py::object ArrowTableFromDataframe(const py::object &df) { // Do not edit this function manually, your changes will be overwritten! static void InitializeConnectionMethods(py::class_> &m) { - m.def("cursor", &DuckDBPyConnection::Cursor, "Create a duplicate of the current connection") - .def("register_filesystem", &DuckDBPyConnection::RegisterFilesystem, "Register a fsspec compliant filesystem", - py::arg("filesystem")) - .def("unregister_filesystem", &DuckDBPyConnection::UnregisterFilesystem, "Unregister a filesystem", - py::arg("name")) - .def("list_filesystems", &DuckDBPyConnection::ListFilesystems, - "List registered filesystems, including builtin ones") - .def("filesystem_is_registered", &DuckDBPyConnection::FileSystemIsRegistered, - "Check if a filesystem with the provided name is currently registered", py::arg("name")); - + m.def("cursor", &DuckDBPyConnection::Cursor, "Create a duplicate of the current connection"); + m.def("register_filesystem", &DuckDBPyConnection::RegisterFilesystem, "Register a fsspec compliant filesystem", + py::arg("filesystem")); + m.def("unregister_filesystem", &DuckDBPyConnection::UnregisterFilesystem, "Unregister a filesystem", + py::arg("name")); + m.def("list_filesystems", &DuckDBPyConnection::ListFilesystems, + "List registered filesystems, including builtin ones"); + m.def("filesystem_is_registered", &DuckDBPyConnection::FileSystemIsRegistered, + "Check if a filesystem with the provided name is currently registered", py::arg("name")); m.def("create_function", &DuckDBPyConnection::RegisterScalarUDF, "Create a DuckDB function out of the passing in Python function so it can be used in queries", py::arg("name"), py::arg("function"), py::arg("parameters") = py::none(), py::arg("return_type") = py::none(), - py::kw_only(), py::arg("type") = PythonUDFType::NATIVE, py::arg("null_handling") = 0, - py::arg("exception_handling") = 0, py::arg("side_effects") = false); - + py::kw_only(), py::arg("type") = PythonUDFType::NATIVE, + py::arg("null_handling") = FunctionNullHandling::DEFAULT_NULL_HANDLING, + py::arg("exception_handling") = PythonExceptionHandling::FORWARD_ERROR, py::arg("side_effects") = false); m.def("remove_function", &DuckDBPyConnection::UnregisterUDF, "Remove a previously created function", py::arg("name")); - - DefineMethod({"sqltype", "dtype", "type"}, m, &DuckDBPyConnection::Type, - "Create a type object by parsing the 'type_str' string", py::arg("type_str")); - + m.def("sqltype", &DuckDBPyConnection::Type, "Create a type object by parsing the 'type_str' string", + py::arg("type_str")); + m.def("dtype", &DuckDBPyConnection::Type, "Create a type object by parsing the 'type_str' string", + py::arg("type_str")); + m.def("type", &DuckDBPyConnection::Type, "Create a type object by parsing the 'type_str' string", + py::arg("type_str")); m.def("array_type", &DuckDBPyConnection::ArrayType, "Create an array type object of 'type'", - py::arg("type").none(false), py::arg("size").none(false)); + py::arg("type").none(false), py::arg("size")); m.def("list_type", &DuckDBPyConnection::ListType, "Create a list type object of 'type'", py::arg("type").none(false)); m.def("union_type", &DuckDBPyConnection::UnionType, "Create a union type object from 'members'", - py::arg("members").none(false)) - .def("string_type", 
&DuckDBPyConnection::StringType, "Create a string type with an optional collation", - py::arg("collation") = string()) - .def("enum_type", &DuckDBPyConnection::EnumType, - "Create an enum type of underlying 'type', consisting of the list of 'values'", py::arg("name"), - py::arg("type"), py::arg("values")) - .def("decimal_type", &DuckDBPyConnection::DecimalType, "Create a decimal type with 'width' and 'scale'", - py::arg("width"), py::arg("scale")); - DefineMethod({"struct_type", "row_type"}, m, &DuckDBPyConnection::StructType, - "Create a struct type object from 'fields'", py::arg("fields")); + py::arg("members").none(false)); + m.def("string_type", &DuckDBPyConnection::StringType, "Create a string type with an optional collation", + py::arg("collation") = ""); + m.def("enum_type", &DuckDBPyConnection::EnumType, + "Create an enum type of underlying 'type', consisting of the list of 'values'", py::arg("name"), + py::arg("type"), py::arg("values")); + m.def("decimal_type", &DuckDBPyConnection::DecimalType, "Create a decimal type with 'width' and 'scale'", + py::arg("width"), py::arg("scale")); + m.def("struct_type", &DuckDBPyConnection::StructType, "Create a struct type object from 'fields'", + py::arg("fields")); + m.def("row_type", &DuckDBPyConnection::StructType, "Create a struct type object from 'fields'", py::arg("fields")); m.def("map_type", &DuckDBPyConnection::MapType, "Create a map type object from 'key_type' and 'value_type'", - py::arg("key").none(false), py::arg("value").none(false)) - .def("duplicate", &DuckDBPyConnection::Cursor, "Create a duplicate of the current connection") - .def("execute", &DuckDBPyConnection::Execute, - "Execute the given SQL query, optionally using prepared statements with parameters set", py::arg("query"), - py::arg("parameters") = py::none(), py::arg("multiple_parameter_sets") = false) - .def("executemany", &DuckDBPyConnection::ExecuteMany, - "Execute the given prepared statement multiple times using the list of parameter sets in parameters", - py::arg("query"), py::arg("parameters") = py::none()) - .def("close", &DuckDBPyConnection::Close, "Close the connection") - .def("interrupt", &DuckDBPyConnection::Interrupt, "Interrupt pending operations") - .def("fetchone", &DuckDBPyConnection::FetchOne, "Fetch a single row from a result following execute") - .def("fetchmany", &DuckDBPyConnection::FetchMany, "Fetch the next set of rows from a result following execute", - py::arg("size") = 1) - .def("fetchall", &DuckDBPyConnection::FetchAll, "Fetch all rows from a result following execute") - .def("fetchnumpy", &DuckDBPyConnection::FetchNumpy, "Fetch a result as list of NumPy arrays following execute") - .def("fetchdf", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), - py::arg("date_as_object") = false) - .def("fetch_df", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), - py::arg("date_as_object") = false) - .def("fetch_df_chunk", &DuckDBPyConnection::FetchDFChunk, - "Fetch a chunk of the result as Data.Frame following execute()", py::arg("vectors_per_chunk") = 1, - py::kw_only(), py::arg("date_as_object") = false) - .def("df", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), - py::arg("date_as_object") = false) - .def("pl", &DuckDBPyConnection::FetchPolars, "Fetch a result as Polars DataFrame following execute()", - py::arg("rows_per_batch") = 1000000) - .def("fetch_arrow_table", &DuckDBPyConnection::FetchArrow, "Fetch a 
result as Arrow table following execute()", - py::arg("rows_per_batch") = 1000000) - .def("fetch_record_batch", &DuckDBPyConnection::FetchRecordBatchReader, - "Fetch an Arrow RecordBatchReader following execute()", py::arg("rows_per_batch") = 1000000) - .def("arrow", &DuckDBPyConnection::FetchArrow, "Fetch a result as Arrow table following execute()", - py::arg("rows_per_batch") = 1000000) - .def("torch", &DuckDBPyConnection::FetchPyTorch, - "Fetch a result as dict of PyTorch Tensors following execute()") - .def("tf", &DuckDBPyConnection::FetchTF, "Fetch a result as dict of TensorFlow Tensors following execute()") - .def("begin", &DuckDBPyConnection::Begin, "Start a new transaction") - .def("commit", &DuckDBPyConnection::Commit, "Commit changes performed within a transaction") - .def("rollback", &DuckDBPyConnection::Rollback, "Roll back changes performed within a transaction") - .def("append", &DuckDBPyConnection::Append, "Append the passed DataFrame to the named table", - py::arg("table_name"), py::arg("df"), py::kw_only(), py::arg("by_name") = false) - .def("register", &DuckDBPyConnection::RegisterPythonObject, - "Register the passed Python Object value for querying with a view", py::arg("view_name"), - py::arg("python_object")) - .def("unregister", &DuckDBPyConnection::UnregisterPythonObject, "Unregister the view name", - py::arg("view_name")) - .def("table", &DuckDBPyConnection::Table, "Create a relation object for the name'd table", - py::arg("table_name")) - .def("view", &DuckDBPyConnection::View, "Create a relation object for the name'd view", py::arg("view_name")) - .def("values", &DuckDBPyConnection::Values, "Create a relation object from the passed values", - py::arg("values")) - .def("table_function", &DuckDBPyConnection::TableFunction, - "Create a relation object from the name'd table function with given parameters", py::arg("name"), - py::arg("parameters") = py::none()) - .def("read_json", &DuckDBPyConnection::ReadJSON, "Create a relation object from the JSON file in 'name'", - py::arg("name"), py::kw_only(), py::arg("columns") = py::none(), py::arg("sample_size") = py::none(), - py::arg("maximum_depth") = py::none(), py::arg("records") = py::none(), py::arg("format") = py::none()) - .def("extract_statements", &DuckDBPyConnection::ExtractStatements, - "Parse the query string and extract the Statement object(s) produced", py::arg("query")); - - DefineMethod({"sql", "query", "from_query"}, m, &DuckDBPyConnection::RunQuery, - "Run a SQL query. 
If it is a SELECT statement, create a relation object from the given SQL query, " - "otherwise run the query as-is.", - py::arg("query"), py::kw_only(), py::arg("alias") = "", py::arg("params") = py::none()); - - DefineMethod({"read_csv", "from_csv_auto"}, m, &DuckDBPyConnection::ReadCSV, - "Create a relation object from the CSV file in 'name'", py::arg("name"), py::kw_only(), - py::arg("header") = py::none(), py::arg("compression") = py::none(), py::arg("sep") = py::none(), - py::arg("delimiter") = py::none(), py::arg("dtype") = py::none(), py::arg("na_values") = py::none(), - py::arg("skiprows") = py::none(), py::arg("quotechar") = py::none(), - py::arg("escapechar") = py::none(), py::arg("encoding") = py::none(), py::arg("parallel") = py::none(), - py::arg("date_format") = py::none(), py::arg("timestamp_format") = py::none(), - py::arg("sample_size") = py::none(), py::arg("all_varchar") = py::none(), - py::arg("normalize_names") = py::none(), py::arg("filename") = py::none(), - py::arg("null_padding") = py::none(), py::arg("names") = py::none()); - + py::arg("key").none(false), py::arg("value").none(false)); + m.def("duplicate", &DuckDBPyConnection::Cursor, "Create a duplicate of the current connection"); + m.def("execute", &DuckDBPyConnection::Execute, + "Execute the given SQL query, optionally using prepared statements with parameters set", py::arg("query"), + py::arg("parameters") = py::none(), py::arg("multiple_parameter_sets") = false); + m.def("executemany", &DuckDBPyConnection::ExecuteMany, + "Execute the given prepared statement multiple times using the list of parameter sets in parameters", + py::arg("query"), py::arg("parameters") = py::none()); + m.def("close", &DuckDBPyConnection::Close, "Close the connection"); + m.def("interrupt", &DuckDBPyConnection::Interrupt, "Interrupt pending operations"); + m.def("fetchone", &DuckDBPyConnection::FetchOne, "Fetch a single row from a result following execute"); + m.def("fetchmany", &DuckDBPyConnection::FetchMany, "Fetch the next set of rows from a result following execute", + py::arg("size") = 1); + m.def("fetchall", &DuckDBPyConnection::FetchAll, "Fetch all rows from a result following execute"); + m.def("fetchnumpy", &DuckDBPyConnection::FetchNumpy, "Fetch a result as list of NumPy arrays following execute"); + m.def("fetchdf", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), + py::arg("date_as_object") = false); + m.def("fetch_df", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), + py::arg("date_as_object") = false); + m.def("fetch_df_chunk", &DuckDBPyConnection::FetchDFChunk, + "Fetch a chunk of the result as Data.Frame following execute()", py::arg("vectors_per_chunk") = 1, + py::kw_only(), py::arg("date_as_object") = false); + m.def("df", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), + py::arg("date_as_object") = false); + m.def("pl", &DuckDBPyConnection::FetchPolars, "Fetch a result as Polars DataFrame following execute()", + py::arg("rows_per_batch") = 1000000); + m.def("fetch_arrow_table", &DuckDBPyConnection::FetchArrow, "Fetch a result as Arrow table following execute()", + py::arg("rows_per_batch") = 1000000); + m.def("fetch_record_batch", &DuckDBPyConnection::FetchRecordBatchReader, + "Fetch an Arrow RecordBatchReader following execute()", py::arg("rows_per_batch") = 1000000); + m.def("arrow", &DuckDBPyConnection::FetchArrow, "Fetch a result as Arrow table following execute()", + 
py::arg("rows_per_batch") = 1000000); + m.def("torch", &DuckDBPyConnection::FetchPyTorch, "Fetch a result as dict of PyTorch Tensors following execute()"); + m.def("tf", &DuckDBPyConnection::FetchTF, "Fetch a result as dict of TensorFlow Tensors following execute()"); + m.def("begin", &DuckDBPyConnection::Begin, "Start a new transaction"); + m.def("commit", &DuckDBPyConnection::Commit, "Commit changes performed within a transaction"); + m.def("rollback", &DuckDBPyConnection::Rollback, "Roll back changes performed within a transaction"); + m.def("append", &DuckDBPyConnection::Append, "Append the passed DataFrame to the named table", + py::arg("table_name"), py::arg("df"), py::kw_only(), py::arg("by_name") = false); + m.def("register", &DuckDBPyConnection::RegisterPythonObject, + "Register the passed Python Object value for querying with a view", py::arg("view_name"), + py::arg("python_object")); + m.def("unregister", &DuckDBPyConnection::UnregisterPythonObject, "Unregister the view name", py::arg("view_name")); + m.def("table", &DuckDBPyConnection::Table, "Create a relation object for the named table", py::arg("table_name")); + m.def("view", &DuckDBPyConnection::View, "Create a relation object for the named view", py::arg("view_name")); + m.def("values", &DuckDBPyConnection::Values, "Create a relation object from the passed values", py::arg("values")); + m.def("table_function", &DuckDBPyConnection::TableFunction, + "Create a relation object from the named table function with given parameters", py::arg("name"), + py::arg("parameters") = py::none()); + m.def("read_json", &DuckDBPyConnection::ReadJSON, "Create a relation object from the JSON file in 'name'", + py::arg("name"), py::kw_only(), py::arg("columns") = py::none(), py::arg("sample_size") = py::none(), + py::arg("maximum_depth") = py::none(), py::arg("records") = py::none(), py::arg("format") = py::none()); + m.def("extract_statements", &DuckDBPyConnection::ExtractStatements, + "Parse the query string and extract the Statement object(s) produced", py::arg("query")); + m.def("sql", &DuckDBPyConnection::RunQuery, + "Run a SQL query. If it is a SELECT statement, create a relation object from the given SQL query, otherwise " + "run the query as-is.", + py::arg("query"), py::kw_only(), py::arg("alias") = "", py::arg("params") = py::none()); + m.def("query", &DuckDBPyConnection::RunQuery, + "Run a SQL query. If it is a SELECT statement, create a relation object from the given SQL query, otherwise " + "run the query as-is.", + py::arg("query"), py::kw_only(), py::arg("alias") = "", py::arg("params") = py::none()); + m.def("from_query", &DuckDBPyConnection::RunQuery, + "Run a SQL query. 
If it is a SELECT statement, create a relation object from the given SQL query, otherwise " + "run the query as-is.", + py::arg("query"), py::kw_only(), py::arg("alias") = "", py::arg("params") = py::none()); + m.def("read_csv", &DuckDBPyConnection::ReadCSV, "Create a relation object from the CSV file in 'name'", + py::arg("name"), py::kw_only(), py::arg("header") = py::none(), py::arg("compression") = py::none(), + py::arg("sep") = py::none(), py::arg("delimiter") = py::none(), py::arg("dtype") = py::none(), + py::arg("na_values") = py::none(), py::arg("skiprows") = py::none(), py::arg("quotechar") = py::none(), + py::arg("escapechar") = py::none(), py::arg("encoding") = py::none(), py::arg("parallel") = py::none(), + py::arg("date_format") = py::none(), py::arg("timestamp_format") = py::none(), + py::arg("sample_size") = py::none(), py::arg("all_varchar") = py::none(), + py::arg("normalize_names") = py::none(), py::arg("filename") = py::none(), + py::arg("null_padding") = py::none(), py::arg("names") = py::none()); + m.def("from_csv_auto", &DuckDBPyConnection::ReadCSV, "Create a relation object from the CSV file in 'name'", + py::arg("name"), py::kw_only(), py::arg("header") = py::none(), py::arg("compression") = py::none(), + py::arg("sep") = py::none(), py::arg("delimiter") = py::none(), py::arg("dtype") = py::none(), + py::arg("na_values") = py::none(), py::arg("skiprows") = py::none(), py::arg("quotechar") = py::none(), + py::arg("escapechar") = py::none(), py::arg("encoding") = py::none(), py::arg("parallel") = py::none(), + py::arg("date_format") = py::none(), py::arg("timestamp_format") = py::none(), + py::arg("sample_size") = py::none(), py::arg("all_varchar") = py::none(), + py::arg("normalize_names") = py::none(), py::arg("filename") = py::none(), + py::arg("null_padding") = py::none(), py::arg("names") = py::none()); m.def("from_df", &DuckDBPyConnection::FromDF, "Create a relation object from the Data.Frame in df", - py::arg("df") = py::none()) - .def("from_arrow", &DuckDBPyConnection::FromArrow, "Create a relation object from an Arrow object", - py::arg("arrow_object")); - - DefineMethod({"from_parquet", "read_parquet"}, m, &DuckDBPyConnection::FromParquet, - "Create a relation object from the Parquet files in file_glob", py::arg("file_glob"), - py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, - py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, - py::arg("compression") = py::none()); - DefineMethod({"from_parquet", "read_parquet"}, m, &DuckDBPyConnection::FromParquets, - "Create a relation object from the Parquet files in file_globs", py::arg("file_globs"), - py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, - py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, - py::arg("compression") = py::none()); - + py::arg("df") = py::none()); + m.def("from_arrow", &DuckDBPyConnection::FromArrow, "Create a relation object from an Arrow object", + py::arg("arrow_object")); + m.def("from_parquet", &DuckDBPyConnection::FromParquet, + "Create a relation object from the Parquet files in file_glob", py::arg("file_glob"), + py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, + py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, + py::arg("compression") = py::none()); + m.def("read_parquet", &DuckDBPyConnection::FromParquet, + "Create a 
relation object from the Parquet files in file_glob", py::arg("file_glob"), + py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, + py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, + py::arg("compression") = py::none()); + m.def("from_parquet", &DuckDBPyConnection::FromParquets, + "Create a relation object from the Parquet files in file_globs", py::arg("file_globs"), + py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, + py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, + py::arg("compression") = py::none()); + m.def("read_parquet", &DuckDBPyConnection::FromParquets, + "Create a relation object from the Parquet files in file_globs", py::arg("file_globs"), + py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, + py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, + py::arg("compression") = py::none()); m.def("from_substrait", &DuckDBPyConnection::FromSubstrait, "Create a query object from protobuf plan", - py::arg("proto")) - .def("get_substrait", &DuckDBPyConnection::GetSubstrait, "Serialize a query to protobuf", py::arg("query"), - py::kw_only(), py::arg("enable_optimizer") = true) - .def("get_substrait_json", &DuckDBPyConnection::GetSubstraitJSON, - "Serialize a query to protobuf on the JSON format", py::arg("query"), py::kw_only(), - py::arg("enable_optimizer") = true) - .def("from_substrait_json", &DuckDBPyConnection::FromSubstraitJSON, - "Create a query object from a JSON protobuf plan", py::arg("json")) - .def("get_table_names", &DuckDBPyConnection::GetTableNames, "Extract the required table names from a query", - py::arg("query")) - .def_property_readonly("description", &DuckDBPyConnection::GetDescription, - "Get result set attributes, mainly column names") - .def_property_readonly("rowcount", &DuckDBPyConnection::GetRowcount, "Get result set row count") - .def("install_extension", &DuckDBPyConnection::InstallExtension, "Install an extension by name", - py::arg("extension"), py::kw_only(), py::arg("force_install") = false) - .def("load_extension", &DuckDBPyConnection::LoadExtension, "Load an installed extension", py::arg("extension")); + py::arg("proto")); + m.def("get_substrait", &DuckDBPyConnection::GetSubstrait, "Serialize a query to protobuf", py::arg("query"), + py::kw_only(), py::arg("enable_optimizer") = true); + m.def("get_substrait_json", &DuckDBPyConnection::GetSubstraitJSON, + "Serialize a query to protobuf on the JSON format", py::arg("query"), py::kw_only(), + py::arg("enable_optimizer") = true); + m.def("from_substrait_json", &DuckDBPyConnection::FromSubstraitJSON, + "Create a query object from a JSON protobuf plan", py::arg("json")); + m.def("get_table_names", &DuckDBPyConnection::GetTableNames, "Extract the required table names from a query", + py::arg("query")); + m.def("install_extension", &DuckDBPyConnection::InstallExtension, "Install an extension by name", + py::arg("extension"), py::kw_only(), py::arg("force_install") = false); + m.def("load_extension", &DuckDBPyConnection::LoadExtension, "Load an installed extension", py::arg("extension")); } // END_OF_CONNECTION_METHODS void DuckDBPyConnection::UnregisterFilesystem(const py::str &name) { @@ -396,6 +412,9 @@ void DuckDBPyConnection::Initialize(py::handle &m) { .def("__exit__", &DuckDBPyConnection::Exit, py::arg("exc_type"), py::arg("exc"), 
py::arg("traceback")); InitializeConnectionMethods(connection_module); + connection_module.def_property_readonly("description", &DuckDBPyConnection::GetDescription, + "Get result set attributes, mainly column names"); + connection_module.def_property_readonly("rowcount", &DuckDBPyConnection::GetRowcount, "Get result set row count"); PyDateTime_IMPORT; // NOLINT DuckDBPyConnection::ImportCache(); } From d698881eba0f51acff6b5751ab585c7e67be18d9 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 17:29:14 +0200 Subject: [PATCH 080/201] add markers to the stubs --- tools/pythonpkg/duckdb-stubs/__init__.pyi | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tools/pythonpkg/duckdb-stubs/__init__.pyi b/tools/pythonpkg/duckdb-stubs/__init__.pyi index 1f1efa86078f..b575bd039660 100644 --- a/tools/pythonpkg/duckdb-stubs/__init__.pyi +++ b/tools/pythonpkg/duckdb-stubs/__init__.pyi @@ -251,6 +251,19 @@ def FunctionExpression(function: str, *cols: Expression) -> Expression: ... class DuckDBPyConnection: def __init__(self, *args, **kwargs) -> None: ... + def __enter__(self) -> DuckDBPyConnection: ... + def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: ... + @property + def description(self) -> Optional[List[Any]]: ... + @property + def rowcount(self) -> int: ... + + + + # NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_stubs.py. + # Do not edit this section manually, your changes will be overwritten! + + # START OF CONNECTION METHODS def append(self, table_name: str, df: pandas.DataFrame) -> DuckDBPyConnection: ... def arrow(self, rows_per_batch: int = ...) -> pyarrow.lib.Table: ... def begin(self) -> DuckDBPyConnection: ... @@ -386,12 +399,7 @@ class DuckDBPyConnection: def list_type(self, type: DuckDBPyType) -> DuckDBPyType: ... def array_type(self, type: DuckDBPyType, size: int) -> DuckDBPyType: ... def map_type(self, key: DuckDBPyType, value: DuckDBPyType) -> DuckDBPyType: ... - def __enter__(self) -> DuckDBPyConnection: ... - def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: ... - @property - def description(self) -> Optional[List[Any]]: ... - @property - def rowcount(self) -> int: ... + # END OF CONNECTION METHODS class DuckDBPyRelation: def close(self) -> None: ... 
From fbc53d0293cce5c65cec54844848382c2ee9cc96 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 20:34:50 +0200 Subject: [PATCH 081/201] add return types and argument types --- .../pythonpkg/scripts/connection_methods.json | 443 +++++++++++------- .../scripts/generate_connection_stubs.py | 88 ++++ 2 files changed, 357 insertions(+), 174 deletions(-) create mode 100644 tools/pythonpkg/scripts/generate_connection_stubs.py diff --git a/tools/pythonpkg/scripts/connection_methods.json b/tools/pythonpkg/scripts/connection_methods.json index f944e17625f0..f0d204c3af94 100644 --- a/tools/pythonpkg/scripts/connection_methods.json +++ b/tools/pythonpkg/scripts/connection_methods.json @@ -14,7 +14,8 @@ "name": "filesystem", "type": "str" } - ] + ], + "return": "None" }, { "name": "unregister_filesystem", @@ -25,7 +26,8 @@ "name": "name", "type": "str" } - ] + ], + "return": "None" }, { "name": "list_filesystems", @@ -39,7 +41,8 @@ "docs": "Check if a filesystem with the provided name is currently registered", "args": [ { - "name": "name" + "name": "name", + "type": "str" } ], "return": "bool" @@ -130,7 +133,8 @@ "name": "size", "type": "int" } - ] + ], + "return": "DuckDBPyType" }, { "name": "list_type", @@ -142,7 +146,8 @@ "type": "DuckDBPyType", "allow_none": false } - ] + ], + "return": "DuckDBPyType" }, { "name": "union_type", @@ -154,7 +159,8 @@ "type": "DuckDBPyType", "allow_none": false } - ] + ], + "return": "DuckDBPyType" }, { "name": "string_type", @@ -166,7 +172,8 @@ "type": "str", "default": "\"\"" } - ] + ], + "return": "DuckDBPyType" }, { "name": "enum_type", @@ -185,7 +192,8 @@ "name": "values", "type": "List[Any]" } - ] + ], + "return": "DuckDBPyType" }, { "name": "decimal_type", @@ -200,7 +208,8 @@ "name": "scale", "type": "int" } - ] + ], + "return": "DuckDBPyType" }, { "name": ["struct_type", "row_type"], @@ -211,7 +220,8 @@ "name": "fields", "type": "Union[Dict[str, DuckDBPyType], List[str]]" } - ] + ], + "return": "DuckDBPyType" }, { "name": "map_type", @@ -220,18 +230,22 @@ "args": [ { "name": "key", - "allow_none": false + "allow_none": false, + "type": "DuckDBPyType" }, { "name": "value", - "allow_none": false + "allow_none": false, + "type": "DuckDBPyType" } - ] + ], + "return": "DuckDBPyType" }, { "name": "duplicate", "function": "Cursor", - "docs": "Create a duplicate of the current connection" + "docs": "Create a duplicate of the current connection", + "return": "DuckDBPyConnection" }, { "name": "execute", @@ -239,17 +253,21 @@ "docs": "Execute the given SQL query, optionally using prepared statements with parameters set", "args": [ { - "name": "query" + "name": "query", + "type": "object" }, { "name": "parameters", - "default": "None" + "default": "None", + "type": "object" }, { "name": "multiple_parameter_sets", - "default": "false" + "default": "false", + "type": "bool" } - ] + ], + "return": "DuckDBPyConnection" }, { "name": "executemany", @@ -257,28 +275,34 @@ "docs": "Execute the given prepared statement multiple times using the list of parameter sets in parameters", "args": [ { - "name": "query" + "name": "query", + "type": "object" }, { "name": "parameters", - "default": "None" + "default": "None", + "type": "object" } - ] + ], + "return": "DuckDBPyConnection" }, { "name": "close", "function": "Close", - "docs": "Close the connection" + "docs": "Close the connection", + "return": "None" }, { "name": "interrupt", "function": "Interrupt", - "docs": "Interrupt pending operations" + "docs": "Interrupt pending operations", + "return": "None" }, { "name": 
"fetchone", "function": "FetchOne", - "docs": "Fetch a single row from a result following execute" + "docs": "Fetch a single row from a result following execute", + "return": "Optional[tuple]" }, { "name": "fetchmany", @@ -287,69 +311,56 @@ "args": [ { "name": "size", - "default": "1" + "default": "1", + "type": "int" } - ] + ], + "return": "List[Any]" }, { "name": "fetchall", "function": "FetchAll", - "docs": "Fetch all rows from a result following execute" + "docs": "Fetch all rows from a result following execute", + "return": "List[Any]" }, { "name": "fetchnumpy", "function": "FetchNumpy", - "docs": "Fetch a result as list of NumPy arrays following execute" - }, - { - "name": "fetchdf", - "function": "FetchDF", - "docs": "Fetch a result as DataFrame following execute()", - "kwargs": [ - { - "name": "date_as_object", - "default": "false" - } - ] + "docs": "Fetch a result as list of NumPy arrays following execute", + "return": "dict" }, { - "name": "fetch_df", + "name": ["fetchdf", "fetch_df", "df"], "function": "FetchDF", "docs": "Fetch a result as DataFrame following execute()", "kwargs": [ { "name": "date_as_object", - "default": "false" + "default": "false", + "type": "bool" } - ] + ], + "return": "pandas.DataFrame" }, { "name": "fetch_df_chunk", "function": "FetchDFChunk", - "docs": "Fetch a chunk of the result as Data.Frame following execute()", + "docs": "Fetch a chunk of the result as DataFrame following execute()", "args": [ { "name": "vectors_per_chunk", - "default": "1" + "default": "1", + "type": "int" } ], "kwargs": [ { "name": "date_as_object", - "default": "false" + "default": "false", + "type": "bool" } - ] - }, - { - "name": "df", - "function": "FetchDF", - "docs": "Fetch a result as DataFrame following execute()", - "kwargs": [ - { - "name": "date_as_object", - "default": "false" - } - ] + ], + "return": "pandas.DataFrame" }, { "name": "pl", @@ -358,20 +369,24 @@ "args": [ { "name": "rows_per_batch", - "default": "1000000" + "default": "1000000", + "type": "int" } - ] + ], + "return": "polars.DataFrame" }, { - "name": "fetch_arrow_table", + "name": ["fetch_arrow_table", "arrow"], "function": "FetchArrow", "docs": "Fetch a result as Arrow table following execute()", "args": [ { "name": "rows_per_batch", - "default": "1000000" + "default": "1000000", + "type": "int" } - ] + ], + "return": "pyarrow.lib.Table" }, { "name": "fetch_record_batch", @@ -380,48 +395,41 @@ "args": [ { "name": "rows_per_batch", - "default": "1000000" - } - ] - }, - { - "name": "arrow", - "function": "FetchArrow", - "docs": "Fetch a result as Arrow table following execute()", - "args": [ - { - "name": "rows_per_batch", - "default": "1000000" + "default": "1000000", + "type": "int" } - ] + ], + "return": "pyarrow.lib.RecordBatchReader" }, { "name": "torch", "function": "FetchPyTorch", - "docs": "Fetch a result as dict of PyTorch Tensors following execute()" - + "docs": "Fetch a result as dict of PyTorch Tensors following execute()", + "return": "dict" }, { "name": "tf", "function": "FetchTF", - "docs": "Fetch a result as dict of TensorFlow Tensors following execute()" - + "docs": "Fetch a result as dict of TensorFlow Tensors following execute()", + "return": "dict" }, { "name": "begin", "function": "Begin", - "docs": "Start a new transaction" - + "docs": "Start a new transaction", + "return": "DuckDBPyConnection" }, { "name": "commit", "function": "Commit", - "docs": "Commit changes performed within a transaction" + "docs": "Commit changes performed within a transaction", + "return": 
"DuckDBPyConnection" }, { "name": "rollback", "function": "Rollback", - "docs": "Roll back changes performed within a transaction" + "docs": "Roll back changes performed within a transaction", + "return": "DuckDBPyConnection" }, { "name": "append", @@ -429,18 +437,22 @@ "docs": "Append the passed DataFrame to the named table", "args": [ { - "name": "table_name" + "name": "table_name", + "type": "str" }, { - "name": "df" + "name": "df", + "type": "pandas.DataFrame" } ], "kwargs": [ { "name": "by_name", - "default": "false" + "default": "false", + "type": "bool" } - ] + ], + "return": "DuckDBPyConnection" }, { "name": "register", @@ -448,12 +460,15 @@ "docs": "Register the passed Python Object value for querying with a view", "args": [ { - "name": "view_name" + "name": "view_name", + "type": "str" }, { - "name": "python_object" + "name": "python_object", + "type": "object" } - ] + ], + "return": "DuckDBPyConnection" }, { "name": "unregister", @@ -461,9 +476,11 @@ "docs": "Unregister the view name", "args": [ { - "name": "view_name" + "name": "view_name", + "type": "str" } - ] + ], + "return": "DuckDBPyConnection" }, { "name": "table", @@ -471,9 +488,11 @@ "docs": "Create a relation object for the named table", "args": [ { - "name": "table_name" + "name": "table_name", + "type": "str" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": "view", @@ -481,9 +500,11 @@ "docs": "Create a relation object for the named view", "args": [ { - "name": "view_name" + "name": "view_name", + "type": "str" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": "values", @@ -491,9 +512,11 @@ "docs": "Create a relation object from the passed values", "args": [ { - "name": "values" + "name": "values", + "type": "List[Any]" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": "table_function", @@ -501,13 +524,16 @@ "docs": "Create a relation object from the named table function with given parameters", "args": [ { - "name": "name" + "name": "name", + "type": "str" }, { "name": "parameters", - "default": "None" + "default": "None", + "type": "object" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": "read_json", @@ -515,31 +541,38 @@ "docs": "Create a relation object from the JSON file in 'name'", "args": [ { - "name": "name" + "name": "name", + "type": "str" } ], "kwargs": [ { "name": "columns", - "default": "None" + "default": "None", + "type": "Optional[Dict[str,str]]" }, { "name": "sample_size", - "default": "None" + "default": "None", + "type": "Optional[int]" }, { "name": "maximum_depth", - "default": "None" + "default": "None", + "type": "Optional[int]" }, { "name": "records", - "default": "None" + "default": "None", + "type": "Optional[str]" }, { "name": "format", - "default": "None" + "default": "None", + "type": "Optional[str]" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": "extract_statements", @@ -547,9 +580,11 @@ "docs": "Parse the query string and extract the Statement object(s) produced", "args": [ { - "name": "query" + "name": "query", + "type": "str" } - ] + ], + "return": "List[Statement]" }, { "name": ["sql", "query", "from_query"], @@ -557,119 +592,144 @@ "docs": "Run a SQL query. 
If it is a SELECT statement, create a relation object from the given SQL query, otherwise run the query as-is.", "args": [ { - "name": "query" + "name": "query", + "type": "str" } ], "kwargs": [ { "name": "alias", - "default": "\"\"" + "default": "\"\"", + "type": "str" }, { "name": "params", - "default": "None" + "default": "None", + "type": "object" } - ] + ], + "return": "DuckDBPyRelation" }, - { "name": ["read_csv", "from_csv_auto"], "function": "ReadCSV", "docs": "Create a relation object from the CSV file in 'name'", "args": [ { - "name": "name" + "name": "path_or_buffer", + "type": "Union[str, StringIO, TextIOBase]" } ], "kwargs": [ { "name": "header", - "default": "None" + "default": "None", + "type": "Optional[bool | int]" }, { "name": "compression", - "default": "None" + "default": "None", + "type": "Optional[str]" }, { "name": "sep", - "default": "None" + "default": "None", + "type": "Optional[str]" }, { "name": "delimiter", - "default": "None" + "default": "None", + "type": "Optional[str]" }, { "name": "dtype", - "default": "None" + "default": "None", + "type": "Optional[Dict[str, str] | List[str]]" }, { "name": "na_values", - "default": "None" + "default": "None", + "type": "Optional[str]" }, { "name": "skiprows", - "default": "None" + "default": "None", + "type": "Optional[int]" }, { "name": "quotechar", - "default": "None" + "default": "None", + "type": "Optional[str]" }, { "name": "escapechar", - "default": "None" + "default": "None", + "type": "Optional[str]" }, { "name": "encoding", - "default": "None" + "default": "None", + "type": "Optional[str]" }, { "name": "parallel", - "default": "None" + "default": "None", + "type": "Optional[bool]" }, { "name": "date_format", - "default": "None" + "default": "None", + "type": "Optional[str]" }, { "name": "timestamp_format", - "default": "None" + "default": "None", + "type": "Optional[str]" }, { "name": "sample_size", - "default": "None" + "default": "None", + "type": "Optional[int]" }, { "name": "all_varchar", - "default": "None" + "default": "None", + "type": "Optional[bool]" }, { "name": "normalize_names", - "default": "None" + "default": "None", + "type": "Optional[bool]" }, { "name": "filename", - "default": "None" + "default": "None", + "type": "Optional[bool]" }, { "name": "null_padding", - "default": "None" + "default": "None", + "type": "Optional[bool]" }, { "name": "names", - "default": "None" + "default": "None", + "type": "Optional[List[str]]" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": "from_df", "function": "FromDF", - "docs": "Create a relation object from the Data.Frame in df", + "docs": "Create a relation object from the DataFrame in df", "args": [ { "name": "df", - "default": "None" + "type": "pandas.DataFrame" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": "from_arrow", @@ -677,9 +737,11 @@ "docs": "Create a relation object from an Arrow object", "args": [ { - "name": "arrow_object" + "name": "arrow_object", + "type": "object" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": ["from_parquet", "read_parquet"], @@ -687,35 +749,43 @@ "docs": "Create a relation object from the Parquet files in file_glob", "args": [ { - "name": "file_glob" + "name": "file_glob", + "type": "str" }, { "name": "binary_as_string", - "default": "false" + "default": "false", + "type": "bool" } ], "kwargs": [ { "name": "file_row_number", - "default": "false" + "default": "false", + "type": "bool" }, { "name": "filename", - "default": "false" + "default": "false", + "type": "bool" }, { "name": "hive_partitioning", - 
"default": "false" + "default": "false", + "type": "bool" }, { "name": "union_by_name", - "default": "false" + "default": "false", + "type": "bool" }, { "name": "compression", - "default": "None" + "default": "None", + "type": "str" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": ["from_parquet", "read_parquet"], @@ -723,35 +793,43 @@ "docs": "Create a relation object from the Parquet files in file_globs", "args": [ { - "name": "file_globs" + "name": "file_globs", + "type": "str" }, { "name": "binary_as_string", - "default": "false" + "default": "false", + "type": "bool" } ], "kwargs": [ { "name": "file_row_number", - "default": "false" + "default": "false", + "type": "bool" }, { "name": "filename", - "default": "false" + "default": "false", + "type": "bool" }, { "name": "hive_partitioning", - "default": "false" + "default": "false", + "type": "bool" }, { "name": "union_by_name", - "default": "false" + "default": "false", + "type": "bool" }, { "name": "compression", - "default": "None" + "default": "None", + "type": "str" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": "from_substrait", @@ -759,9 +837,11 @@ "docs": "Create a query object from protobuf plan", "args": [ { - "name": "proto" + "name": "proto", + "type": "str" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": "get_substrait", @@ -769,15 +849,18 @@ "docs": "Serialize a query to protobuf", "args": [ { - "name": "query" + "name": "query", + "type": "str" } ], "kwargs": [ { "name": "enable_optimizer", - "default": "True" + "default": "True", + "type": "bool" } - ] + ], + "return": "str" }, { "name": "get_substrait_json", @@ -785,15 +868,18 @@ "docs": "Serialize a query to protobuf on the JSON format", "args": [ { - "name": "query" + "name": "query", + "type": "str" } ], "kwargs": [ { "name": "enable_optimizer", - "default": "True" + "default": "True", + "type": "bool" } - ] + ], + "return": "str" }, { "name": "from_substrait_json", @@ -801,9 +887,11 @@ "docs": "Create a query object from a JSON protobuf plan", "args": [ { - "name": "json" + "name": "json", + "type": "str" } - ] + ], + "return": "DuckDBPyRelation" }, { "name": "get_table_names", @@ -811,9 +899,11 @@ "docs": "Extract the required table names from a query", "args": [ { - "name": "query" + "name": "query", + "type": "str" } - ] + ], + "return": "List[str]" }, { "name": "install_extension", @@ -821,15 +911,18 @@ "docs": "Install an extension by name", "args": [ { - "name": "extension" + "name": "extension", + "type": "str" } ], "kwargs": [ { "name": "force_install", - "default": "false" + "default": "false", + "type": "bool" } - ] + ], + "return": "None" }, { "name": "load_extension", @@ -837,8 +930,10 @@ "docs": "Load an installed extension", "args": [ { - "name": "extension" + "name": "extension", + "type": "str" } - ] + ], + "return": "None" } ] diff --git a/tools/pythonpkg/scripts/generate_connection_stubs.py b/tools/pythonpkg/scripts/generate_connection_stubs.py new file mode 100644 index 000000000000..2d6683be8970 --- /dev/null +++ b/tools/pythonpkg/scripts/generate_connection_stubs.py @@ -0,0 +1,88 @@ +import os +import json + +os.chdir(os.path.dirname(__file__)) + +JSON_PATH = os.path.join("connection_methods.json") +DUCKDB_STUBS_FILE = os.path.join("..", "duckdb-stubs", "__init__.pyi") + +START_MARKER = " # START OF CONNECTION METHODS" +END_MARKER = " # END OF CONNECTION METHODS" + +# Read the DUCKDB_STUBS_FILE file +with open(DUCKDB_STUBS_FILE, 'r') as source_file: + source_code = source_file.readlines() + +# Locate the 
InitializeConnectionMethods function in it +start_index = -1 +end_index = -1 +for i, line in enumerate(source_code): + if line.startswith(START_MARKER): + # TODO: handle the case where the start marker appears multiple times + start_index = i + elif line.startswith(END_MARKER): + # TODO: ditto ^ + end_index = i + +if start_index == -1 or end_index == -1: + raise ValueError("Couldn't find start or end marker in source file") + +start_section = source_code[: start_index + 1] +end_section = source_code[end_index:] +# ---- Generate the definition code from the json ---- + +# Read the JSON file +with open(JSON_PATH, 'r') as json_file: + connection_methods = json.load(json_file) + +body = [] + + +def create_arguments(arguments) -> list: + result = [] + for arg in arguments: + argument = f"{arg['name']}: {arg['type']}" + # Add the default argument if present + if 'default' in arg: + default = arg['default'] + argument += f" = {default}" + result.append(argument) + return result + + +def create_definition(name, method) -> str: + print(method) + definition = f"def {name}(self" + if 'args' in method: + definition += ", " + arguments = create_arguments(method['args']) + definition += ', '.join(arguments) + if 'kwargs' in method: + definition += ", **kwargs" + definition += ")" + definition += f" -> {method['return']}: ..." + return definition + + +for method in connection_methods: + if isinstance(method['name'], list): + names = method['name'] + else: + names = [method['name']] + for name in names: + body.append(create_definition(name, method)) + +# ---- End of generation code ---- + +with_newlines = ['\t' + x + '\n' for x in body] +# Recreate the file content by concatenating all the pieces together + +new_content = start_section + with_newlines + end_section + +print(with_newlines) + +exit() + +# Write out the modified DUCKDB_STUBS_FILE file +with open(DUCKDB_STUBS_FILE, 'w') as source_file: + source_file.write("".join(new_content)) From 1a061d48e5c532a97cdf9c858a75345aebd379df Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 20:50:15 +0200 Subject: [PATCH 082/201] generating the stubs for DuckDBPyConnection --- tools/pythonpkg/duckdb-stubs/__init__.pyi | 181 ++++++------------ .../pythonpkg/scripts/connection_methods.json | 30 +-- .../scripts/generate_connection_stubs.py | 11 +- tools/pythonpkg/src/pyconnection.cpp | 17 +- 4 files changed, 87 insertions(+), 152 deletions(-) diff --git a/tools/pythonpkg/duckdb-stubs/__init__.pyi b/tools/pythonpkg/duckdb-stubs/__init__.pyi index b575bd039660..a39a8b1451ca 100644 --- a/tools/pythonpkg/duckdb-stubs/__init__.pyi +++ b/tools/pythonpkg/duckdb-stubs/__init__.pyi @@ -264,141 +264,72 @@ class DuckDBPyConnection: # Do not edit this section manually, your changes will be overwritten! # START OF CONNECTION METHODS - def append(self, table_name: str, df: pandas.DataFrame) -> DuckDBPyConnection: ... - def arrow(self, rows_per_batch: int = ...) -> pyarrow.lib.Table: ... - def begin(self) -> DuckDBPyConnection: ... - def close(self) -> None: ... - def commit(self) -> DuckDBPyConnection: ... def cursor(self) -> DuckDBPyConnection: ... - def df(self) -> pandas.DataFrame: ... - def duplicate(self) -> DuckDBPyConnection: ... - def execute(self, query: str, parameters: object = ..., multiple_parameter_sets: bool = ...) -> DuckDBPyConnection: ... - def executemany(self, query: str, parameters: object = ...) -> DuckDBPyConnection: ... - def fetch_arrow_table(self, rows_per_batch: int = ...) -> pyarrow.lib.Table: ... 
- def fetch_df(self, *args, **kwargs) -> pandas.DataFrame: ... - def fetch_df_chunk(self, *args, **kwargs) -> pandas.DataFrame: ... - def fetch_record_batch(self, rows_per_batch: int = ...) -> pyarrow.lib.RecordBatchReader: ... - def fetchall(self) -> List[Any]: ... - def fetchdf(self, *args, **kwargs) -> pandas.DataFrame: ... - def fetchmany(self, size: int = ...) -> List[Any]: ... - def fetchnumpy(self) -> dict: ... - def fetchone(self) -> Optional[tuple]: ... - def from_arrow(self, arrow_object: object) -> DuckDBPyRelation: ... - def read_json( - self, - file_name: str, - columns: Optional[Dict[str,str]] = None, - sample_size: Optional[int] = None, - maximum_depth: Optional[int] = None, - records: Optional[str] = None, - format: Optional[str] = None - ) -> DuckDBPyRelation: ... - def read_csv( - self, - path_or_buffer: Union[str, StringIO, TextIOBase], - header: Optional[bool | int] = None, - compression: Optional[str] = None, - sep: Optional[str] = None, - delimiter: Optional[str] = None, - dtype: Optional[Dict[str, str] | List[str]] = None, - na_values: Optional[str] = None, - skiprows: Optional[int] = None, - quotechar: Optional[str] = None, - escapechar: Optional[str] = None, - encoding: Optional[str] = None, - parallel: Optional[bool] = None, - date_format: Optional[str] = None, - timestamp_format: Optional[str] = None, - sample_size: Optional[int] = None, - all_varchar: Optional[bool] = None, - normalize_names: Optional[bool] = None, - filename: Optional[bool] = None, - null_padding: Optional[bool] = None, - names: Optional[List[str]] = None - ) -> DuckDBPyRelation: ... - def from_csv_auto( - self, - path_or_buffer: Union[str, StringIO, TextIOBase], - header: Optional[bool | int] = None, - compression: Optional[str] = None, - sep: Optional[str] = None, - delimiter: Optional[str] = None, - dtype: Optional[Dict[str, str] | List[str]] = None, - na_values: Optional[str] = None, - skiprows: Optional[int] = None, - quotechar: Optional[str] = None, - escapechar: Optional[str] = None, - encoding: Optional[str] = None, - parallel: Optional[bool] = None, - date_format: Optional[str] = None, - timestamp_format: Optional[str] = None, - sample_size: Optional[int] = None, - all_varchar: Optional[bool] = None, - normalize_names: Optional[bool] = None, - filename: Optional[bool] = None, - null_padding: Optional[bool] = None, - names: Optional[List[str]] = None - ) -> DuckDBPyRelation: ... - def from_df(self, df: pandas.DataFrame = ...) -> DuckDBPyRelation: ... - @overload - def read_parquet(self, file_glob: str, binary_as_string: bool = ..., *, file_row_number: bool = ..., filename: bool = ..., hive_partitioning: bool = ..., union_by_name: bool = ...) -> DuckDBPyRelation: ... - @overload - def read_parquet(self, file_globs: List[str], binary_as_string: bool = ..., *, file_row_number: bool = ..., filename: bool = ..., hive_partitioning: bool = ..., union_by_name: bool = ...) -> DuckDBPyRelation: ... - @overload - def from_parquet(self, file_glob: str, binary_as_string: bool = ..., *, file_row_number: bool = ..., filename: bool = ..., hive_partitioning: bool = ..., union_by_name: bool = ...) -> DuckDBPyRelation: ... - @overload - def from_parquet(self, file_globs: List[str], binary_as_string: bool = ..., *, file_row_number: bool = ..., filename: bool = ..., hive_partitioning: bool = ..., union_by_name: bool = ...) -> DuckDBPyRelation: ... - def from_substrait(self, proto: bytes) -> DuckDBPyRelation: ... - def get_substrait(self, query: str) -> DuckDBPyRelation: ... 
- def get_substrait_json(self, query: str) -> DuckDBPyRelation: ... - def from_substrait_json(self, json: str) -> DuckDBPyRelation: ... - def get_table_names(self, query: str) -> Set[str]: ... - def install_extension(self, *args, **kwargs) -> None: ... - def interrupt(self) -> None: ... - def list_filesystems(self) -> List[Any]: ... + def register_filesystem(self, filesystem: str) -> None: ... + def unregister_filesystem(self, name: str) -> None: ... + def list_filesystems(self) -> list: ... def filesystem_is_registered(self, name: str) -> bool: ... - def load_extension(self, extension: str) -> None: ... - def pl(self, rows_per_batch: int = ..., connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... - def torch(self, connection: DuckDBPyConnection = ...) -> dict: ... - def tf(self, connection: DuckDBPyConnection = ...) -> dict: ... - - def from_query(self, query: str, **kwargs) -> DuckDBPyRelation: ... - def extract_statements(self, query: str) -> List[Statement]: ... - def query(self, query: str, **kwargs) -> DuckDBPyRelation: ... - def sql(self, query: str, **kwargs) -> DuckDBPyRelation: ... - - def register(self, view_name: str, python_object: object) -> DuckDBPyConnection: ... + def create_function(self, name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, **kwargs) -> DuckDBPyConnection: ... def remove_function(self, name: str) -> DuckDBPyConnection: ... - def create_function( - self, - name: str, - func: Callable, - parameters: Optional[List[DuckDBPyType]] = None, - return_type: Optional[DuckDBPyType] = None, - type: Optional[PythonUDFType] = PythonUDFType.NATIVE, - null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, - exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, - side_effects: Optional[bool] = False) -> DuckDBPyConnection: ... - def register_filesystem(self, filesystem: fsspec.AbstractFileSystem) -> None: ... - def rollback(self) -> DuckDBPyConnection: ... - def table(self, table_name: str) -> DuckDBPyRelation: ... - def table_function(self, name: str, parameters: object = ...) -> DuckDBPyRelation: ... - def unregister(self, view_name: str) -> DuckDBPyConnection: ... - def unregister_filesystem(self, name: str) -> None: ... - def values(self, values: object) -> DuckDBPyRelation: ... - def view(self, view_name: str) -> DuckDBPyRelation: ... def sqltype(self, type_str: str) -> DuckDBPyType: ... def dtype(self, type_str: str) -> DuckDBPyType: ... def type(self, type_str: str) -> DuckDBPyType: ... - def struct_type(self, fields: Union[Dict[str, DuckDBPyType], List[str]]) -> DuckDBPyType: ... - def row_type(self, fields: Union[Dict[str, DuckDBPyType], List[str]]) -> DuckDBPyType: ... - def union_type(self, members: Union[Dict[str, DuckDBPyType], List[str]]) -> DuckDBPyType: ... + def array_type(self, type: DuckDBPyType, size: int) -> DuckDBPyType: ... + def list_type(self, type: DuckDBPyType) -> DuckDBPyType: ... + def union_type(self, members: DuckDBPyType) -> DuckDBPyType: ... def string_type(self, collation: str = "") -> DuckDBPyType: ... def enum_type(self, name: str, type: DuckDBPyType, values: List[Any]) -> DuckDBPyType: ... def decimal_type(self, width: int, scale: int) -> DuckDBPyType: ... - def list_type(self, type: DuckDBPyType) -> DuckDBPyType: ... - def array_type(self, type: DuckDBPyType, size: int) -> DuckDBPyType: ... + def struct_type(self, fields: Union[Dict[str, DuckDBPyType], List[str]]) -> DuckDBPyType: ... 
+ def row_type(self, fields: Union[Dict[str, DuckDBPyType], List[str]]) -> DuckDBPyType: ... def map_type(self, key: DuckDBPyType, value: DuckDBPyType) -> DuckDBPyType: ... + def duplicate(self) -> DuckDBPyConnection: ... + def execute(self, query: object, parameters: object = None, multiple_parameter_sets: bool = False) -> DuckDBPyConnection: ... + def executemany(self, query: object, parameters: object = None) -> DuckDBPyConnection: ... + def close(self) -> None: ... + def interrupt(self) -> None: ... + def fetchone(self) -> Optional[tuple]: ... + def fetchmany(self, size: int = 1) -> List[Any]: ... + def fetchall(self) -> List[Any]: ... + def fetchnumpy(self) -> dict: ... + def fetchdf(self, **kwargs) -> pandas.DataFrame: ... + def fetch_df(self, **kwargs) -> pandas.DataFrame: ... + def df(self, **kwargs) -> pandas.DataFrame: ... + def fetch_df_chunk(self, vectors_per_chunk: int = 1, **kwargs) -> pandas.DataFrame: ... + def pl(self, rows_per_batch: int = 1000000) -> polars.DataFrame: ... + def fetch_arrow_table(self, rows_per_batch: int = 1000000) -> pyarrow.lib.Table: ... + def arrow(self, rows_per_batch: int = 1000000) -> pyarrow.lib.Table: ... + def fetch_record_batch(self, rows_per_batch: int = 1000000) -> pyarrow.lib.RecordBatchReader: ... + def torch(self) -> dict: ... + def tf(self) -> dict: ... + def begin(self) -> DuckDBPyConnection: ... + def commit(self) -> DuckDBPyConnection: ... + def rollback(self) -> DuckDBPyConnection: ... + def append(self, table_name: str, df: pandas.DataFrame, **kwargs) -> DuckDBPyConnection: ... + def register(self, view_name: str, python_object: object) -> DuckDBPyConnection: ... + def unregister(self, view_name: str) -> DuckDBPyConnection: ... + def table(self, table_name: str) -> DuckDBPyRelation: ... + def view(self, view_name: str) -> DuckDBPyRelation: ... + def values(self, values: List[Any]) -> DuckDBPyRelation: ... + def table_function(self, name: str, parameters: object = None) -> DuckDBPyRelation: ... + def read_json(self, name: str, **kwargs) -> DuckDBPyRelation: ... + def extract_statements(self, query: str) -> List[Statement]: ... + def sql(self, query: str, **kwargs) -> DuckDBPyRelation: ... + def query(self, query: str, **kwargs) -> DuckDBPyRelation: ... + def from_query(self, query: str, **kwargs) -> DuckDBPyRelation: ... + def read_csv(self, path_or_buffer: Union[str, StringIO, TextIOBase], **kwargs) -> DuckDBPyRelation: ... + def from_csv_auto(self, path_or_buffer: Union[str, StringIO, TextIOBase], **kwargs) -> DuckDBPyRelation: ... + def from_df(self, df: pandas.DataFrame) -> DuckDBPyRelation: ... + def from_arrow(self, arrow_object: object) -> DuckDBPyRelation: ... + def from_parquet(self, file_glob: str, binary_as_string: bool = False, **kwargs) -> DuckDBPyRelation: ... + def read_parquet(self, file_glob: str, binary_as_string: bool = False, **kwargs) -> DuckDBPyRelation: ... + def from_substrait(self, proto: str) -> DuckDBPyRelation: ... + def get_substrait(self, query: str, **kwargs) -> str: ... + def get_substrait_json(self, query: str, **kwargs) -> str: ... + def from_substrait_json(self, json: str) -> DuckDBPyRelation: ... + def get_table_names(self, query: str) -> List[str]: ... + def install_extension(self, extension: str, **kwargs) -> None: ... + def load_extension(self, extension: str) -> None: ... 
# END OF CONNECTION METHODS class DuckDBPyRelation: diff --git a/tools/pythonpkg/scripts/connection_methods.json b/tools/pythonpkg/scripts/connection_methods.json index f0d204c3af94..3184459c7615 100644 --- a/tools/pythonpkg/scripts/connection_methods.json +++ b/tools/pythonpkg/scripts/connection_methods.json @@ -263,7 +263,7 @@ }, { "name": "multiple_parameter_sets", - "default": "false", + "default": "False", "type": "bool" } ], @@ -336,7 +336,7 @@ "kwargs": [ { "name": "date_as_object", - "default": "false", + "default": "False", "type": "bool" } ], @@ -356,7 +356,7 @@ "kwargs": [ { "name": "date_as_object", - "default": "false", + "default": "False", "type": "bool" } ], @@ -448,7 +448,7 @@ "kwargs": [ { "name": "by_name", - "default": "false", + "default": "False", "type": "bool" } ], @@ -754,29 +754,29 @@ }, { "name": "binary_as_string", - "default": "false", + "default": "False", "type": "bool" } ], "kwargs": [ { "name": "file_row_number", - "default": "false", + "default": "False", "type": "bool" }, { "name": "filename", - "default": "false", + "default": "False", "type": "bool" }, { "name": "hive_partitioning", - "default": "false", + "default": "False", "type": "bool" }, { "name": "union_by_name", - "default": "false", + "default": "False", "type": "bool" }, { @@ -798,29 +798,29 @@ }, { "name": "binary_as_string", - "default": "false", + "default": "False", "type": "bool" } ], "kwargs": [ { "name": "file_row_number", - "default": "false", + "default": "False", "type": "bool" }, { "name": "filename", - "default": "false", + "default": "False", "type": "bool" }, { "name": "hive_partitioning", - "default": "false", + "default": "False", "type": "bool" }, { "name": "union_by_name", - "default": "false", + "default": "False", "type": "bool" }, { @@ -918,7 +918,7 @@ "kwargs": [ { "name": "force_install", - "default": "false", + "default": "False", "type": "bool" } ], diff --git a/tools/pythonpkg/scripts/generate_connection_stubs.py b/tools/pythonpkg/scripts/generate_connection_stubs.py index 2d6683be8970..cacddd3b0ced 100644 --- a/tools/pythonpkg/scripts/generate_connection_stubs.py +++ b/tools/pythonpkg/scripts/generate_connection_stubs.py @@ -64,25 +64,30 @@ def create_definition(name, method) -> str: return definition +# We have "duplicate" methods, which are overloaded +# maybe we should add @overload to these instead, but this is easier +written_methods = set() + for method in connection_methods: if isinstance(method['name'], list): names = method['name'] else: names = [method['name']] for name in names: + if name in written_methods: + continue body.append(create_definition(name, method)) + written_methods.add(name) # ---- End of generation code ---- -with_newlines = ['\t' + x + '\n' for x in body] +with_newlines = [' ' + x + '\n' for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section print(with_newlines) -exit() - # Write out the modified DUCKDB_STUBS_FILE file with open(DUCKDB_STUBS_FILE, 'w') as source_file: source_file.write("".join(new_content)) diff --git a/tools/pythonpkg/src/pyconnection.cpp b/tools/pythonpkg/src/pyconnection.cpp index 39cf30451f5d..7e5697bcbfc6 100644 --- a/tools/pythonpkg/src/pyconnection.cpp +++ b/tools/pythonpkg/src/pyconnection.cpp @@ -182,19 +182,19 @@ static void InitializeConnectionMethods(py::class_ Date: Wed, 10 Apr 2024 22:09:39 +0200 Subject: [PATCH 083/201] generating the connection wrapper code, placed in __init__.py --- 
tools/pythonpkg/duckdb-stubs/__init__.pyi | 2 - tools/pythonpkg/duckdb/__init__.py | 457 +++++++++++++++--- tools/pythonpkg/duckdb_python.cpp | 3 +- .../generate_connection_wrapper_methods.py | 142 ++++++ 4 files changed, 531 insertions(+), 73 deletions(-) create mode 100644 tools/pythonpkg/scripts/generate_connection_wrapper_methods.py diff --git a/tools/pythonpkg/duckdb-stubs/__init__.pyi b/tools/pythonpkg/duckdb-stubs/__init__.pyi index a39a8b1451ca..748e1102c290 100644 --- a/tools/pythonpkg/duckdb-stubs/__init__.pyi +++ b/tools/pythonpkg/duckdb-stubs/__init__.pyi @@ -258,8 +258,6 @@ class DuckDBPyConnection: @property def rowcount(self) -> int: ... - - # NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_stubs.py. # Do not edit this section manually, your changes will be overwritten! diff --git a/tools/pythonpkg/duckdb/__init__.py b/tools/pythonpkg/duckdb/__init__.py index 8f01da43bc4f..f4f6af797c6e 100644 --- a/tools/pythonpkg/duckdb/__init__.py +++ b/tools/pythonpkg/duckdb/__init__.py @@ -40,83 +40,402 @@ "CaseExpression", ]) -# ---- Wrap the connection methods - -def is_dunder_method(method_name: str) -> bool: - if len(method_name) < 4: - return False - return method_name[:2] == '__' and method_name[:-3:-1] == '__' - -# Takes the function to execute on the 'connection' -def create_wrapper(func): - def _wrapper(*args, **kwargs): - connection = duckdb.connect(':default:') - if 'connection' in kwargs: - connection = kwargs.pop('connection') - return func(connection, *args, **kwargs) - return _wrapper - -# Takes the name of a DuckDBPyConnection function to wrap (copying signature, docs, etc) -# The 'func' is what gets executed when the function is called -def create_connection_wrapper(name, func): - # Define a decorator function that forwards attribute lookup to the default connection - return functools.wraps(getattr(DuckDBPyConnection, name))(create_wrapper(func)) - # These are overloaded twice, we define them inside of C++ so pybind can deal with it -EXCLUDED_METHODS = [ +_exported_symbols.extend([ 'df', 'arrow' -] -_exported_symbols.extend(EXCLUDED_METHODS) +]) from .duckdb import ( df, arrow ) -methods = [method for method in dir(DuckDBPyConnection) if not is_dunder_method(method) and method not in EXCLUDED_METHODS] -for method_name in methods: - def create_method_wrapper(method_name): - def call_method(conn, *args, **kwargs): - return getattr(conn, method_name)(*args, **kwargs) - return call_method - wrapper_function = create_connection_wrapper(method_name, create_method_wrapper(method_name)) - globals()[method_name] = wrapper_function # Define the wrapper function in the module namespace - _exported_symbols.append(method_name) - - -# Specialized "wrapper" methods - -SPECIAL_METHODS = [ - 'project', - 'distinct', - 'write_csv', - 'aggregate', - 'alias', - 'filter', - 'limit', - 'order', - 'query_df' -] - -for method_name in SPECIAL_METHODS: - def create_method_wrapper(name): - def _closure(name=name): - mapping = { - 'alias': 'set_alias', - 'query_df': 'query' - } - def call_method(con, df, *args, **kwargs): - if name in mapping: - mapped_name = mapping[name] - else: - mapped_name = name - return getattr(con.from_df(df), mapped_name)(*args, **kwargs) - return call_method - return _closure(name) - - wrapper_function = create_wrapper(create_method_wrapper(method_name)) - globals()[method_name] = wrapper_function # Define the wrapper function in the module namespace - _exported_symbols.append(method_name) +def __get_connection__(**kwargs): + if 'connection' 
in kwargs: + return kwargs.pop('connection') + else: + return duckdb.connect(":default:") + +# NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_methods.py. +# Do not edit this section manually, your changes will be overwritten! + +# START OF CONNECTION WRAPPER + +def cursor(**kwargs): + conn = __get_connection__(*kwargs) + return conn.cursor(**kwargs) +_exported_symbols.append('cursor') + +def register_filesystem(filesystem, **kwargs): + conn = __get_connection__(*kwargs) + return conn.register_filesystem(filesystem, **kwargs) +_exported_symbols.append('register_filesystem') + +def unregister_filesystem(name, **kwargs): + conn = __get_connection__(*kwargs) + return conn.unregister_filesystem(name, **kwargs) +_exported_symbols.append('unregister_filesystem') + +def list_filesystems(**kwargs): + conn = __get_connection__(*kwargs) + return conn.list_filesystems(**kwargs) +_exported_symbols.append('list_filesystems') + +def filesystem_is_registered(name, **kwargs): + conn = __get_connection__(*kwargs) + return conn.filesystem_is_registered(name, **kwargs) +_exported_symbols.append('filesystem_is_registered') + +def create_function(name, function, parameters = None, return_type = None, **kwargs): + conn = __get_connection__(*kwargs) + return conn.create_function(name, function, parameters, return_type, **kwargs) +_exported_symbols.append('create_function') + +def remove_function(name, **kwargs): + conn = __get_connection__(*kwargs) + return conn.remove_function(name, **kwargs) +_exported_symbols.append('remove_function') + +def sqltype(type_str, **kwargs): + conn = __get_connection__(*kwargs) + return conn.sqltype(type_str, **kwargs) +_exported_symbols.append('sqltype') + +def dtype(type_str, **kwargs): + conn = __get_connection__(*kwargs) + return conn.dtype(type_str, **kwargs) +_exported_symbols.append('dtype') + +def type(type_str, **kwargs): + conn = __get_connection__(*kwargs) + return conn.type(type_str, **kwargs) +_exported_symbols.append('type') + +def array_type(type, size, **kwargs): + conn = __get_connection__(*kwargs) + return conn.array_type(type, size, **kwargs) +_exported_symbols.append('array_type') + +def list_type(type, **kwargs): + conn = __get_connection__(*kwargs) + return conn.list_type(type, **kwargs) +_exported_symbols.append('list_type') + +def union_type(members, **kwargs): + conn = __get_connection__(*kwargs) + return conn.union_type(members, **kwargs) +_exported_symbols.append('union_type') + +def string_type(collation = "", **kwargs): + conn = __get_connection__(*kwargs) + return conn.string_type(collation, **kwargs) +_exported_symbols.append('string_type') + +def enum_type(name, type, values, **kwargs): + conn = __get_connection__(*kwargs) + return conn.enum_type(name, type, values, **kwargs) +_exported_symbols.append('enum_type') + +def decimal_type(width, scale, **kwargs): + conn = __get_connection__(*kwargs) + return conn.decimal_type(width, scale, **kwargs) +_exported_symbols.append('decimal_type') + +def struct_type(fields, **kwargs): + conn = __get_connection__(*kwargs) + return conn.struct_type(fields, **kwargs) +_exported_symbols.append('struct_type') + +def row_type(fields, **kwargs): + conn = __get_connection__(*kwargs) + return conn.row_type(fields, **kwargs) +_exported_symbols.append('row_type') + +def map_type(key, value, **kwargs): + conn = __get_connection__(*kwargs) + return conn.map_type(key, value, **kwargs) +_exported_symbols.append('map_type') + +def duplicate(**kwargs): + conn = __get_connection__(*kwargs) 
+ return conn.duplicate(**kwargs) +_exported_symbols.append('duplicate') + +def execute(query, parameters = None, multiple_parameter_sets = False, **kwargs): + conn = __get_connection__(*kwargs) + return conn.execute(query, parameters, multiple_parameter_sets, **kwargs) +_exported_symbols.append('execute') + +def executemany(query, parameters = None, **kwargs): + conn = __get_connection__(*kwargs) + return conn.executemany(query, parameters, **kwargs) +_exported_symbols.append('executemany') + +def close(**kwargs): + conn = __get_connection__(*kwargs) + return conn.close(**kwargs) +_exported_symbols.append('close') + +def interrupt(**kwargs): + conn = __get_connection__(*kwargs) + return conn.interrupt(**kwargs) +_exported_symbols.append('interrupt') + +def fetchone(**kwargs): + conn = __get_connection__(*kwargs) + return conn.fetchone(**kwargs) +_exported_symbols.append('fetchone') + +def fetchmany(size = 1, **kwargs): + conn = __get_connection__(*kwargs) + return conn.fetchmany(size, **kwargs) +_exported_symbols.append('fetchmany') + +def fetchall(**kwargs): + conn = __get_connection__(*kwargs) + return conn.fetchall(**kwargs) +_exported_symbols.append('fetchall') + +def fetchnumpy(**kwargs): + conn = __get_connection__(*kwargs) + return conn.fetchnumpy(**kwargs) +_exported_symbols.append('fetchnumpy') + +def fetchdf(**kwargs): + conn = __get_connection__(*kwargs) + return conn.fetchdf(**kwargs) +_exported_symbols.append('fetchdf') + +def fetch_df(**kwargs): + conn = __get_connection__(*kwargs) + return conn.fetch_df(**kwargs) +_exported_symbols.append('fetch_df') + +def df(**kwargs): + conn = __get_connection__(*kwargs) + return conn.df(**kwargs) +_exported_symbols.append('df') + +def fetch_df_chunk(vectors_per_chunk = 1, **kwargs): + conn = __get_connection__(*kwargs) + return conn.fetch_df_chunk(vectors_per_chunk, **kwargs) +_exported_symbols.append('fetch_df_chunk') + +def pl(rows_per_batch = 1000000, **kwargs): + conn = __get_connection__(*kwargs) + return conn.pl(rows_per_batch, **kwargs) +_exported_symbols.append('pl') + +def fetch_arrow_table(rows_per_batch = 1000000, **kwargs): + conn = __get_connection__(*kwargs) + return conn.fetch_arrow_table(rows_per_batch, **kwargs) +_exported_symbols.append('fetch_arrow_table') + +def arrow(rows_per_batch = 1000000, **kwargs): + conn = __get_connection__(*kwargs) + return conn.arrow(rows_per_batch, **kwargs) +_exported_symbols.append('arrow') + +def fetch_record_batch(rows_per_batch = 1000000, **kwargs): + conn = __get_connection__(*kwargs) + return conn.fetch_record_batch(rows_per_batch, **kwargs) +_exported_symbols.append('fetch_record_batch') + +def torch(**kwargs): + conn = __get_connection__(*kwargs) + return conn.torch(**kwargs) +_exported_symbols.append('torch') + +def tf(**kwargs): + conn = __get_connection__(*kwargs) + return conn.tf(**kwargs) +_exported_symbols.append('tf') + +def begin(**kwargs): + conn = __get_connection__(*kwargs) + return conn.begin(**kwargs) +_exported_symbols.append('begin') + +def commit(**kwargs): + conn = __get_connection__(*kwargs) + return conn.commit(**kwargs) +_exported_symbols.append('commit') + +def rollback(**kwargs): + conn = __get_connection__(*kwargs) + return conn.rollback(**kwargs) +_exported_symbols.append('rollback') + +def append(table_name, df, **kwargs): + conn = __get_connection__(*kwargs) + return conn.append(table_name, df, **kwargs) +_exported_symbols.append('append') + +def register(view_name, python_object, **kwargs): + conn = __get_connection__(*kwargs) + return 
conn.register(view_name, python_object, **kwargs) +_exported_symbols.append('register') + +def unregister(view_name, **kwargs): + conn = __get_connection__(*kwargs) + return conn.unregister(view_name, **kwargs) +_exported_symbols.append('unregister') + +def table(table_name, **kwargs): + conn = __get_connection__(*kwargs) + return conn.table(table_name, **kwargs) +_exported_symbols.append('table') + +def view(view_name, **kwargs): + conn = __get_connection__(*kwargs) + return conn.view(view_name, **kwargs) +_exported_symbols.append('view') + +def values(values, **kwargs): + conn = __get_connection__(*kwargs) + return conn.values(values, **kwargs) +_exported_symbols.append('values') + +def table_function(name, parameters = None, **kwargs): + conn = __get_connection__(*kwargs) + return conn.table_function(name, parameters, **kwargs) +_exported_symbols.append('table_function') + +def read_json(name, **kwargs): + conn = __get_connection__(*kwargs) + return conn.read_json(name, **kwargs) +_exported_symbols.append('read_json') + +def extract_statements(query, **kwargs): + conn = __get_connection__(*kwargs) + return conn.extract_statements(query, **kwargs) +_exported_symbols.append('extract_statements') + +def sql(query, **kwargs): + conn = __get_connection__(*kwargs) + return conn.sql(query, **kwargs) +_exported_symbols.append('sql') + +def query(query, **kwargs): + conn = __get_connection__(*kwargs) + return conn.query(query, **kwargs) +_exported_symbols.append('query') + +def from_query(query, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_query(query, **kwargs) +_exported_symbols.append('from_query') + +def read_csv(path_or_buffer, **kwargs): + conn = __get_connection__(*kwargs) + return conn.read_csv(path_or_buffer, **kwargs) +_exported_symbols.append('read_csv') + +def from_csv_auto(path_or_buffer, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_csv_auto(path_or_buffer, **kwargs) +_exported_symbols.append('from_csv_auto') + +def from_df(df, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_df(df, **kwargs) +_exported_symbols.append('from_df') + +def from_arrow(arrow_object, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_arrow(arrow_object, **kwargs) +_exported_symbols.append('from_arrow') + +def from_parquet(file_glob, binary_as_string = False, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_parquet(file_glob, binary_as_string, **kwargs) +_exported_symbols.append('from_parquet') + +def read_parquet(file_glob, binary_as_string = False, **kwargs): + conn = __get_connection__(*kwargs) + return conn.read_parquet(file_glob, binary_as_string, **kwargs) +_exported_symbols.append('read_parquet') + +def from_substrait(proto, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_substrait(proto, **kwargs) +_exported_symbols.append('from_substrait') + +def get_substrait(query, **kwargs): + conn = __get_connection__(*kwargs) + return conn.get_substrait(query, **kwargs) +_exported_symbols.append('get_substrait') + +def get_substrait_json(query, **kwargs): + conn = __get_connection__(*kwargs) + return conn.get_substrait_json(query, **kwargs) +_exported_symbols.append('get_substrait_json') + +def from_substrait_json(json, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_substrait_json(json, **kwargs) +_exported_symbols.append('from_substrait_json') + +def get_table_names(query, **kwargs): + conn = __get_connection__(*kwargs) + return conn.get_table_names(query, **kwargs) 
+_exported_symbols.append('get_table_names') + +def install_extension(extension, **kwargs): + conn = __get_connection__(*kwargs) + return conn.install_extension(extension, **kwargs) +_exported_symbols.append('install_extension') + +def load_extension(extension, **kwargs): + conn = __get_connection__(*kwargs) + return conn.load_extension(extension, **kwargs) +_exported_symbols.append('load_extension') + +def project(df, *args, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_df(df).project(*args, **kwargs) +_exported_symbols.append('project') + +def distinct(df, *args, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_df(df).distinct(*args, **kwargs) +_exported_symbols.append('distinct') + +def write_csv(df, *args, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_df(df).write_csv(*args, **kwargs) +_exported_symbols.append('write_csv') + +def aggregate(df, *args, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_df(df).aggregate(*args, **kwargs) +_exported_symbols.append('aggregate') + +def alias(df, *args, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_df(df).set_alias(*args, **kwargs) +_exported_symbols.append('alias') + +def filter(df, *args, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_df(df).filter(*args, **kwargs) +_exported_symbols.append('filter') + +def limit(df, *args, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_df(df).limit(*args, **kwargs) +_exported_symbols.append('limit') + +def order(df, *args, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_df(df).order(*args, **kwargs) +_exported_symbols.append('order') + +def query_df(df, *args, **kwargs): + conn = __get_connection__(*kwargs) + return conn.from_df(df).query(*args, **kwargs) +_exported_symbols.append('query_df') +# END OF CONNECTION WRAPPER # Enums from .duckdb import ( diff --git a/tools/pythonpkg/duckdb_python.cpp b/tools/pythonpkg/duckdb_python.cpp index 1289c5743833..bb2ee522d72c 100644 --- a/tools/pythonpkg/duckdb_python.cpp +++ b/tools/pythonpkg/duckdb_python.cpp @@ -71,8 +71,7 @@ static py::list PyTokenize(const string &query) { } static void InitializeConnectionMethods(py::module_ &m) { - // We define these "wrapper" methods inside of C++ because they are overloaded - // every other wrapper method is defined inside of __init__.py + // We define these "wrapper" methods manually because they are overloaded m.def( "arrow", [](idx_t rows_per_batch, shared_ptr conn) -> duckdb::pyarrow::Table { diff --git a/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py b/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py new file mode 100644 index 000000000000..b76ee249e510 --- /dev/null +++ b/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py @@ -0,0 +1,142 @@ +import os +import json + +os.chdir(os.path.dirname(__file__)) + +JSON_PATH = os.path.join("connection_methods.json") +DUCKDB_INIT_FILE = os.path.join("..", "duckdb", "__init__.py") + +START_MARKER = "# START OF CONNECTION WRAPPER" +END_MARKER = "# END OF CONNECTION WRAPPER" + +# Read the DUCKDB_INIT_FILE file +with open(DUCKDB_INIT_FILE, 'r') as source_file: + source_code = source_file.readlines() + +start_index = -1 +end_index = -1 +for i, line in enumerate(source_code): + if line.startswith(START_MARKER): + # TODO: handle the case where the start marker appears multiple times + start_index = i + elif line.startswith(END_MARKER): + # TODO: ditto ^ + end_index = i + +if start_index == -1 or 
end_index == -1: + raise ValueError("Couldn't find start or end marker in source file") + +start_section = source_code[: start_index + 1] +end_section = source_code[end_index:] +# ---- Generate the definition code from the json ---- + +# Read the JSON file +with open(JSON_PATH, 'r') as json_file: + connection_methods = json.load(json_file) + +# Artificial "wrapper" methods on pandas.DataFrames +SPECIAL_METHODS = [ + {'name': 'project', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, + {'name': 'distinct', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, + {'name': 'write_csv', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'None'}, + {'name': 'aggregate', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, + {'name': 'alias', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, + {'name': 'filter', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, + {'name': 'limit', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, + {'name': 'order', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, + {'name': 'query_df', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, +] + +connection_methods.extend(SPECIAL_METHODS) + +body = [] + +SPECIAL_METHOD_NAMES = [x['name'] for x in SPECIAL_METHODS] + + +def generate_arguments(name, method) -> str: + arguments = [] + if name in SPECIAL_METHOD_NAMES: + # We add 'df' to these methods because they operate on a DataFrame + arguments.append('df') + + if 'args' in method: + for arg in method['args']: + res = arg['name'] + if 'default' in arg: + res += f" = {arg['default']}" + arguments.append(res) + arguments.append('**kwargs') + return ', '.join(arguments) + + +def generate_parameters(method) -> str: + arguments = [] + if 'args' in method: + for arg in method['args']: + arguments.append(f"{arg['name']}") + arguments.append('**kwargs') + return ', '.join(arguments) + + +def generate_function_call(name, method) -> str: + function_call = '' + if name in SPECIAL_METHOD_NAMES: + function_call += 'from_df(df).' 
+ + REMAPPED_FUNCTIONS = {'alias': 'set_alias', 'query_df': 'query'} + if name in REMAPPED_FUNCTIONS: + function_name = REMAPPED_FUNCTIONS[name] + else: + function_name = name + function_call += function_name + return function_call + + +def create_definition(name, method) -> str: + print(method) + arguments = generate_arguments(name, method) + parameters = generate_parameters(method) + function_call = generate_function_call(name, method) + + func = f""" +def {name}({arguments}): + conn = __get_connection__(*kwargs) + return conn.{function_call}({parameters}) +_exported_symbols.append('{name}') +""" + return func + + +# We have "duplicate" methods, which are overloaded +written_methods = set() + +for method in connection_methods: + if isinstance(method['name'], list): + names = method['name'] + else: + names = [method['name']] + + # Artificially add 'connection' keyword argument + if 'kwargs' not in method: + method['kwargs'] = [] + method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection'}) + + for name in names: + if name in written_methods: + continue + body.append(create_definition(name, method)) + written_methods.add(name) + +# ---- End of generation code ---- + +with_newlines = body +# Recreate the file content by concatenating all the pieces together + +new_content = start_section + with_newlines + end_section + +print(''.join(with_newlines)) + +# Write out the modified DUCKDB_INIT_FILE file +with open(DUCKDB_INIT_FILE, 'w') as source_file: + source_file.write("".join(new_content)) From a2c33c959c89f20281d475b8272793f52e3a849e Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 22:49:19 +0200 Subject: [PATCH 084/201] pass all tests --- tools/pythonpkg/duckdb/__init__.py | 397 ++++++++++++++---- .../generate_connection_wrapper_methods.py | 26 +- .../tests/fast/api/test_duckdb_connection.py | 7 + 3 files changed, 336 insertions(+), 94 deletions(-) diff --git a/tools/pythonpkg/duckdb/__init__.py b/tools/pythonpkg/duckdb/__init__.py index f4f6af797c6e..29252c856881 100644 --- a/tools/pythonpkg/duckdb/__init__.py +++ b/tools/pythonpkg/duckdb/__init__.py @@ -50,391 +50,610 @@ arrow ) -def __get_connection__(**kwargs): - if 'connection' in kwargs: - return kwargs.pop('connection') - else: - return duckdb.connect(":default:") - # NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_methods.py. # Do not edit this section manually, your changes will be overwritten! 
# START OF CONNECTION WRAPPER def cursor(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.cursor(**kwargs) _exported_symbols.append('cursor') def register_filesystem(filesystem, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.register_filesystem(filesystem, **kwargs) _exported_symbols.append('register_filesystem') def unregister_filesystem(name, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.unregister_filesystem(name, **kwargs) _exported_symbols.append('unregister_filesystem') def list_filesystems(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.list_filesystems(**kwargs) _exported_symbols.append('list_filesystems') def filesystem_is_registered(name, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.filesystem_is_registered(name, **kwargs) _exported_symbols.append('filesystem_is_registered') def create_function(name, function, parameters = None, return_type = None, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.create_function(name, function, parameters, return_type, **kwargs) _exported_symbols.append('create_function') def remove_function(name, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.remove_function(name, **kwargs) _exported_symbols.append('remove_function') def sqltype(type_str, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.sqltype(type_str, **kwargs) _exported_symbols.append('sqltype') def dtype(type_str, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.dtype(type_str, **kwargs) _exported_symbols.append('dtype') def type(type_str, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.type(type_str, **kwargs) _exported_symbols.append('type') def array_type(type, size, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.array_type(type, size, **kwargs) _exported_symbols.append('array_type') def list_type(type, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.list_type(type, **kwargs) _exported_symbols.append('list_type') def union_type(members, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.union_type(members, **kwargs) _exported_symbols.append('union_type') 
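Each generated wrapper now resolves its connection inline: an explicit 'connection' keyword argument wins, otherwise the shared ':default:' connection is used. The helper spelling this replaces, conn = __get_connection__(*kwargs), unpacked the dict's keys as positional arguments, so it would fail as soon as a keyword was actually passed. If the repetition were ever factored out again, a helper taking the kwargs dict itself avoids that pitfall. A minimal sketch, not part of this patch (the name _resolve_connection is illustrative):

    def _resolve_connection(kwargs):
        # Caller-supplied connection wins; pop it so the remaining
        # kwargs can be forwarded to the wrapped method untouched.
        if 'connection' in kwargs:
            return kwargs.pop('connection')
        return duckdb.connect(":default:")

    def sql(query, **kwargs):
        conn = _resolve_connection(kwargs)  # pop mutates kwargs in place
        return conn.sql(query, **kwargs)
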
def string_type(collation = "", **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.string_type(collation, **kwargs) _exported_symbols.append('string_type') def enum_type(name, type, values, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.enum_type(name, type, values, **kwargs) _exported_symbols.append('enum_type') def decimal_type(width, scale, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.decimal_type(width, scale, **kwargs) _exported_symbols.append('decimal_type') def struct_type(fields, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.struct_type(fields, **kwargs) _exported_symbols.append('struct_type') def row_type(fields, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.row_type(fields, **kwargs) _exported_symbols.append('row_type') def map_type(key, value, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.map_type(key, value, **kwargs) _exported_symbols.append('map_type') def duplicate(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.duplicate(**kwargs) _exported_symbols.append('duplicate') def execute(query, parameters = None, multiple_parameter_sets = False, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.execute(query, parameters, multiple_parameter_sets, **kwargs) _exported_symbols.append('execute') def executemany(query, parameters = None, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.executemany(query, parameters, **kwargs) _exported_symbols.append('executemany') def close(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.close(**kwargs) _exported_symbols.append('close') def interrupt(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.interrupt(**kwargs) _exported_symbols.append('interrupt') def fetchone(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.fetchone(**kwargs) _exported_symbols.append('fetchone') def fetchmany(size = 1, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.fetchmany(size, **kwargs) _exported_symbols.append('fetchmany') def fetchall(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = 
kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.fetchall(**kwargs) _exported_symbols.append('fetchall') def fetchnumpy(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.fetchnumpy(**kwargs) _exported_symbols.append('fetchnumpy') def fetchdf(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.fetchdf(**kwargs) _exported_symbols.append('fetchdf') def fetch_df(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.fetch_df(**kwargs) _exported_symbols.append('fetch_df') -def df(**kwargs): - conn = __get_connection__(*kwargs) - return conn.df(**kwargs) -_exported_symbols.append('df') - def fetch_df_chunk(vectors_per_chunk = 1, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.fetch_df_chunk(vectors_per_chunk, **kwargs) _exported_symbols.append('fetch_df_chunk') def pl(rows_per_batch = 1000000, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.pl(rows_per_batch, **kwargs) _exported_symbols.append('pl') def fetch_arrow_table(rows_per_batch = 1000000, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.fetch_arrow_table(rows_per_batch, **kwargs) _exported_symbols.append('fetch_arrow_table') -def arrow(rows_per_batch = 1000000, **kwargs): - conn = __get_connection__(*kwargs) - return conn.arrow(rows_per_batch, **kwargs) -_exported_symbols.append('arrow') - def fetch_record_batch(rows_per_batch = 1000000, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.fetch_record_batch(rows_per_batch, **kwargs) _exported_symbols.append('fetch_record_batch') def torch(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.torch(**kwargs) _exported_symbols.append('torch') def tf(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.tf(**kwargs) _exported_symbols.append('tf') def begin(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.begin(**kwargs) _exported_symbols.append('begin') def commit(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.commit(**kwargs) _exported_symbols.append('commit') def rollback(**kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.rollback(**kwargs) _exported_symbols.append('rollback') def append(table_name, df, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in 
kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.append(table_name, df, **kwargs) _exported_symbols.append('append') def register(view_name, python_object, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.register(view_name, python_object, **kwargs) _exported_symbols.append('register') def unregister(view_name, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.unregister(view_name, **kwargs) _exported_symbols.append('unregister') def table(table_name, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.table(table_name, **kwargs) _exported_symbols.append('table') def view(view_name, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.view(view_name, **kwargs) _exported_symbols.append('view') def values(values, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.values(values, **kwargs) _exported_symbols.append('values') def table_function(name, parameters = None, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.table_function(name, parameters, **kwargs) _exported_symbols.append('table_function') def read_json(name, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.read_json(name, **kwargs) _exported_symbols.append('read_json') def extract_statements(query, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.extract_statements(query, **kwargs) _exported_symbols.append('extract_statements') def sql(query, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.sql(query, **kwargs) _exported_symbols.append('sql') def query(query, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.query(query, **kwargs) _exported_symbols.append('query') def from_query(query, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_query(query, **kwargs) _exported_symbols.append('from_query') def read_csv(path_or_buffer, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.read_csv(path_or_buffer, **kwargs) _exported_symbols.append('read_csv') def from_csv_auto(path_or_buffer, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_csv_auto(path_or_buffer, **kwargs) 
_exported_symbols.append('from_csv_auto') def from_df(df, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_df(df, **kwargs) _exported_symbols.append('from_df') def from_arrow(arrow_object, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_arrow(arrow_object, **kwargs) _exported_symbols.append('from_arrow') def from_parquet(file_glob, binary_as_string = False, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_parquet(file_glob, binary_as_string, **kwargs) _exported_symbols.append('from_parquet') def read_parquet(file_glob, binary_as_string = False, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.read_parquet(file_glob, binary_as_string, **kwargs) _exported_symbols.append('read_parquet') def from_substrait(proto, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_substrait(proto, **kwargs) _exported_symbols.append('from_substrait') def get_substrait(query, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.get_substrait(query, **kwargs) _exported_symbols.append('get_substrait') def get_substrait_json(query, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.get_substrait_json(query, **kwargs) _exported_symbols.append('get_substrait_json') def from_substrait_json(json, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_substrait_json(json, **kwargs) _exported_symbols.append('from_substrait_json') def get_table_names(query, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.get_table_names(query, **kwargs) _exported_symbols.append('get_table_names') def install_extension(extension, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.install_extension(extension, **kwargs) _exported_symbols.append('install_extension') def load_extension(extension, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.load_extension(extension, **kwargs) _exported_symbols.append('load_extension') def project(df, *args, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_df(df).project(*args, **kwargs) _exported_symbols.append('project') def distinct(df, *args, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = 
duckdb.connect(":default:") return conn.from_df(df).distinct(*args, **kwargs) _exported_symbols.append('distinct') def write_csv(df, *args, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_df(df).write_csv(*args, **kwargs) _exported_symbols.append('write_csv') def aggregate(df, *args, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_df(df).aggregate(*args, **kwargs) _exported_symbols.append('aggregate') def alias(df, *args, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_df(df).set_alias(*args, **kwargs) _exported_symbols.append('alias') def filter(df, *args, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_df(df).filter(*args, **kwargs) _exported_symbols.append('filter') def limit(df, *args, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_df(df).limit(*args, **kwargs) _exported_symbols.append('limit') def order(df, *args, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_df(df).order(*args, **kwargs) _exported_symbols.append('order') def query_df(df, *args, **kwargs): - conn = __get_connection__(*kwargs) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") return conn.from_df(df).query(*args, **kwargs) _exported_symbols.append('query_df') + +def description(**kwargs): + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") + return conn.description +_exported_symbols.append('description') + +def rowcount(**kwargs): + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") + return conn.rowcount +_exported_symbols.append('rowcount') # END OF CONNECTION WRAPPER # Enums diff --git a/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py b/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py index b76ee249e510..c890bc8bf2da 100644 --- a/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py +++ b/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py @@ -47,11 +47,18 @@ {'name': 'query_df', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, ] +READONLY_PROPERTIES = [ + {'name': 'description', 'return': 'str'}, + {'name': 'rowcount', 'return': 'int'}, +] + connection_methods.extend(SPECIAL_METHODS) +connection_methods.extend(READONLY_PROPERTIES) body = [] SPECIAL_METHOD_NAMES = [x['name'] for x in SPECIAL_METHODS] +READONLY_PROPERTY_NAMES = [x['name'] for x in READONLY_PROPERTIES] def generate_arguments(name, method) -> str: @@ -70,13 +77,16 @@ def generate_arguments(name, method) -> str: return ', '.join(arguments) -def generate_parameters(method) -> str: +def generate_parameters(name, method) -> str: + if name in READONLY_PROPERTY_NAMES: + return '' arguments = [] if 'args' in method: for arg in method['args']: arguments.append(f"{arg['name']}") 
arguments.append('**kwargs') - return ', '.join(arguments) + result = ', '.join(arguments) + return '(' + result + ')' def generate_function_call(name, method) -> str: @@ -96,13 +106,16 @@ def generate_function_call(name, method) -> str: def create_definition(name, method) -> str: print(method) arguments = generate_arguments(name, method) - parameters = generate_parameters(method) + parameters = generate_parameters(name, method) function_call = generate_function_call(name, method) func = f""" def {name}({arguments}): - conn = __get_connection__(*kwargs) - return conn.{function_call}({parameters}) + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") + return conn.{function_call}{parameters} _exported_symbols.append('{name}') """ return func @@ -125,6 +138,9 @@ def {name}({arguments}): for name in names: if name in written_methods: continue + if name in ['arrow', 'df']: + # These methods are ambiguous and are handled in C++ code instead + continue body.append(create_definition(name, method)) written_methods.add(name) diff --git a/tools/pythonpkg/tests/fast/api/test_duckdb_connection.py b/tools/pythonpkg/tests/fast/api/test_duckdb_connection.py index c9c26a867ed2..4d1afb7b3bf4 100644 --- a/tools/pythonpkg/tests/fast/api/test_duckdb_connection.py +++ b/tools/pythonpkg/tests/fast/api/test_duckdb_connection.py @@ -84,6 +84,13 @@ def test_duplicate(self): with pytest.raises(duckdb.CatalogException): dup_conn.table("tbl").fetchall() + def test_readonly_properties(self): + duckdb.execute("select 42") + description = duckdb.description() + rowcount = duckdb.rowcount() + assert description == [('42', 'NUMBER', None, None, None, None, None)] + assert rowcount == -1 + def test_execute(self): assert [([4, 2],)] == duckdb.execute("select [4,2]").fetchall() From 6353eebdeb9e1a695e4ba642415cf8fa1195892f Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 23:03:33 +0200 Subject: [PATCH 085/201] cleanup --- .../scripts/connection_wrapper_methods.json | 100 ++++++++++++++++++ .../generate_connection_wrapper_methods.py | 44 ++++---- .../scripts/generate_function_definition.py | 3 - 3 files changed, 118 insertions(+), 29 deletions(-) create mode 100644 tools/pythonpkg/scripts/connection_wrapper_methods.json delete mode 100644 tools/pythonpkg/scripts/generate_function_definition.py diff --git a/tools/pythonpkg/scripts/connection_wrapper_methods.json b/tools/pythonpkg/scripts/connection_wrapper_methods.json new file mode 100644 index 000000000000..2be14ae79a4f --- /dev/null +++ b/tools/pythonpkg/scripts/connection_wrapper_methods.json @@ -0,0 +1,100 @@ +[ + { + "name": "project", + "args": [ + { + "name": "*args", + "type": "Any" + } + ], + "return": "DuckDBPyRelation" + }, + { + "name": "distinct", + "args": [ + { + "name": "*args", + "type": "Any" + } + ], + "return": "DuckDBPyRelation" + }, + { + "name": "write_csv", + "args": [ + { + "name": "*args", + "type": "Any" + } + ], + "return": "None" + }, + { + "name": "aggregate", + "args": [ + { + "name": "*args", + "type": "Any" + } + ], + "return": "DuckDBPyRelation" + }, + { + "name": "alias", + "args": [ + { + "name": "*args", + "type": "Any" + } + ], + "return": "DuckDBPyRelation" + }, + { + "name": "filter", + "args": [ + { + "name": "*args", + "type": "Any" + } + ], + "return": "DuckDBPyRelation" + }, + { + "name": "limit", + "args": [ + { + "name": "*args", + "type": "Any" + } + ], + "return": "DuckDBPyRelation" + }, + { + "name": "order", + "args": [ + { + "name": "*args", + "type": "Any" + } 
+ ], + "return": "DuckDBPyRelation" + }, + { + "name": "query_df", + "args": [ + { + "name": "*args", + "type": "Any" + } + ], + "return": "DuckDBPyRelation" + }, + { + "name": "description", + "return": "str" + }, + { + "name": "rowcount", + "return": "int" + } +] diff --git a/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py b/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py index c890bc8bf2da..e716765f6816 100644 --- a/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py +++ b/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py @@ -4,6 +4,7 @@ os.chdir(os.path.dirname(__file__)) JSON_PATH = os.path.join("connection_methods.json") +WRAPPER_JSON_PATH = os.path.join("connection_wrapper_methods.json") DUCKDB_INIT_FILE = os.path.join("..", "duckdb", "__init__.py") START_MARKER = "# START OF CONNECTION WRAPPER" @@ -30,35 +31,25 @@ end_section = source_code[end_index:] # ---- Generate the definition code from the json ---- +methods = [] + # Read the JSON file with open(JSON_PATH, 'r') as json_file: connection_methods = json.load(json_file) -# Artificial "wrapper" methods on pandas.DataFrames -SPECIAL_METHODS = [ - {'name': 'project', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, - {'name': 'distinct', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, - {'name': 'write_csv', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'None'}, - {'name': 'aggregate', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, - {'name': 'alias', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, - {'name': 'filter', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, - {'name': 'limit', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, - {'name': 'order', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, - {'name': 'query_df', 'args': [{'name': "*args", 'type': 'Any'}], 'return': 'DuckDBPyRelation'}, -] - -READONLY_PROPERTIES = [ - {'name': 'description', 'return': 'str'}, - {'name': 'rowcount', 'return': 'int'}, -] - -connection_methods.extend(SPECIAL_METHODS) -connection_methods.extend(READONLY_PROPERTIES) +with open(WRAPPER_JSON_PATH, 'r') as json_file: + wrapper_methods = json.load(json_file) -body = [] +methods.extend(connection_methods) +methods.extend(wrapper_methods) -SPECIAL_METHOD_NAMES = [x['name'] for x in SPECIAL_METHODS] -READONLY_PROPERTY_NAMES = [x['name'] for x in READONLY_PROPERTIES] +# On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke +# that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) +READONLY_PROPERTY_NAMES = ['description', 'rowcount'] + +# These methods are not directly DuckDBPyConnection methods, +# they first call 'from_df' and then call a method on the created DuckDBPyRelation +SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] def generate_arguments(name, method) -> str: @@ -89,7 +80,7 @@ def generate_parameters(name, method) -> str: return '(' + result + ')' -def generate_function_call(name, method) -> str: +def generate_function_call(name) -> str: function_call = '' if name in SPECIAL_METHOD_NAMES: function_call += 'from_df(df).' 
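At this point the wrapper generator mirrors the stub generator's flow: read __init__.py, locate the START/END markers, keep everything outside them, and splice in definitions built from connection_methods.json plus connection_wrapper_methods.json. A minimal sketch of the splice step (illustrative names; first-match lookup, where the script loops over all lines and keeps the last match):

    def splice(lines, generated, start_marker, end_marker):
        start = next(i for i, l in enumerate(lines) if l.startswith(start_marker))
        end = next(i for i, l in enumerate(lines) if l.startswith(end_marker))
        # Keep the start marker itself, drop the stale generated body,
        # and keep the end marker and everything after it.
        return lines[:start + 1] + generated + lines[end:]
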
@@ -107,7 +98,7 @@ def create_definition(name, method) -> str: print(method) arguments = generate_arguments(name, method) parameters = generate_parameters(name, method) - function_call = generate_function_call(name, method) + function_call = generate_function_call(name) func = f""" def {name}({arguments}): @@ -124,7 +115,8 @@ def {name}({arguments}): # We have "duplicate" methods, which are overloaded written_methods = set() -for method in connection_methods: +body = [] +for method in methods: if isinstance(method['name'], list): names = method['name'] else: diff --git a/tools/pythonpkg/scripts/generate_function_definition.py b/tools/pythonpkg/scripts/generate_function_definition.py deleted file mode 100644 index aca68675eb43..000000000000 --- a/tools/pythonpkg/scripts/generate_function_definition.py +++ /dev/null @@ -1,3 +0,0 @@ -class FunctionDefinition: - def __init__(self): - pass From 2ff91ad07b0cf63e32f46e319baf052156fb5f3b Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 23:19:49 +0200 Subject: [PATCH 086/201] add the right arguments for the (DuckDBPyRelation) functions we're wrapping --- tools/pythonpkg/duckdb/__init__.py | 32 ++++++------- .../scripts/connection_wrapper_methods.json | 48 +++++++++++-------- 2 files changed, 44 insertions(+), 36 deletions(-) diff --git a/tools/pythonpkg/duckdb/__init__.py b/tools/pythonpkg/duckdb/__init__.py index 29252c856881..5d2d56e2dd4e 100644 --- a/tools/pythonpkg/duckdb/__init__.py +++ b/tools/pythonpkg/duckdb/__init__.py @@ -567,20 +567,20 @@ def load_extension(extension, **kwargs): return conn.load_extension(extension, **kwargs) _exported_symbols.append('load_extension') -def project(df, *args, **kwargs): +def project(df, project_expr, **kwargs): if 'connection' in kwargs: conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") - return conn.from_df(df).project(*args, **kwargs) + return conn.from_df(df).project(project_expr, **kwargs) _exported_symbols.append('project') -def distinct(df, *args, **kwargs): +def distinct(df, **kwargs): if 'connection' in kwargs: conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") - return conn.from_df(df).distinct(*args, **kwargs) + return conn.from_df(df).distinct(**kwargs) _exported_symbols.append('distinct') def write_csv(df, *args, **kwargs): @@ -591,52 +591,52 @@ def write_csv(df, *args, **kwargs): return conn.from_df(df).write_csv(*args, **kwargs) _exported_symbols.append('write_csv') -def aggregate(df, *args, **kwargs): +def aggregate(df, aggr_expr, group_expr = "", **kwargs): if 'connection' in kwargs: conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") - return conn.from_df(df).aggregate(*args, **kwargs) + return conn.from_df(df).aggregate(aggr_expr, group_expr, **kwargs) _exported_symbols.append('aggregate') -def alias(df, *args, **kwargs): +def alias(df, alias, **kwargs): if 'connection' in kwargs: conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") - return conn.from_df(df).set_alias(*args, **kwargs) + return conn.from_df(df).set_alias(alias, **kwargs) _exported_symbols.append('alias') -def filter(df, *args, **kwargs): +def filter(df, filter_expr, **kwargs): if 'connection' in kwargs: conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") - return conn.from_df(df).filter(*args, **kwargs) + return conn.from_df(df).filter(filter_expr, **kwargs) _exported_symbols.append('filter') -def limit(df, *args, **kwargs): +def limit(df, n, offset = 0, **kwargs): if 'connection' in kwargs: conn = 
kwargs.pop('connection') else: conn = duckdb.connect(":default:") - return conn.from_df(df).limit(*args, **kwargs) + return conn.from_df(df).limit(n, offset, **kwargs) _exported_symbols.append('limit') -def order(df, *args, **kwargs): +def order(df, order_expr, **kwargs): if 'connection' in kwargs: conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") - return conn.from_df(df).order(*args, **kwargs) + return conn.from_df(df).order(order_expr, **kwargs) _exported_symbols.append('order') -def query_df(df, *args, **kwargs): +def query_df(df, virtual_table_name, sql_query, **kwargs): if 'connection' in kwargs: conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") - return conn.from_df(df).query(*args, **kwargs) + return conn.from_df(df).query(virtual_table_name, sql_query, **kwargs) _exported_symbols.append('query_df') def description(**kwargs): diff --git a/tools/pythonpkg/scripts/connection_wrapper_methods.json b/tools/pythonpkg/scripts/connection_wrapper_methods.json index 2be14ae79a4f..65ac4afd47f7 100644 --- a/tools/pythonpkg/scripts/connection_wrapper_methods.json +++ b/tools/pythonpkg/scripts/connection_wrapper_methods.json @@ -3,20 +3,14 @@ "name": "project", "args": [ { - "name": "*args", - "type": "Any" + "name": "project_expr", + "type": "str" } ], "return": "DuckDBPyRelation" }, { "name": "distinct", - "args": [ - { - "name": "*args", - "type": "Any" - } - ], "return": "DuckDBPyRelation" }, { @@ -33,8 +27,13 @@ "name": "aggregate", "args": [ { - "name": "*args", - "type": "Any" + "name": "aggr_expr", + "type": "str" + }, + { + "name": "group_expr", + "type" : "str", + "default": "\"\"" } ], "return": "DuckDBPyRelation" @@ -43,8 +42,8 @@ "name": "alias", "args": [ { - "name": "*args", - "type": "Any" + "name": "alias", + "type": "str" } ], "return": "DuckDBPyRelation" @@ -53,8 +52,8 @@ "name": "filter", "args": [ { - "name": "*args", - "type": "Any" + "name": "filter_expr", + "type": "str" } ], "return": "DuckDBPyRelation" @@ -63,8 +62,13 @@ "name": "limit", "args": [ { - "name": "*args", - "type": "Any" + "name": "n", + "type": "int" + }, + { + "name": "offset", + "type": "int", + "default": "0" } ], "return": "DuckDBPyRelation" @@ -73,8 +77,8 @@ "name": "order", "args": [ { - "name": "*args", - "type": "Any" + "name": "order_expr", + "type": "str" } ], "return": "DuckDBPyRelation" @@ -83,8 +87,12 @@ "name": "query_df", "args": [ { - "name": "*args", - "type": "Any" + "name": "virtual_table_name", + "type": "str" + }, + { + "name": "sql_query", + "type": "str" } ], "return": "DuckDBPyRelation" From c94e308be0b3e86895b2b3954e82c7e2d054fe44 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 23:40:33 +0200 Subject: [PATCH 087/201] generating the stubs for the connection wrappers too --- tools/pythonpkg/duckdb-stubs/__init__.pyi | 227 +++++++----------- .../scripts/connection_wrapper_methods.json | 2 +- 2 files changed, 84 insertions(+), 145 deletions(-) diff --git a/tools/pythonpkg/duckdb-stubs/__init__.pyi b/tools/pythonpkg/duckdb-stubs/__init__.pyi index 748e1102c290..0a5587e628be 100644 --- a/tools/pythonpkg/duckdb-stubs/__init__.pyi +++ b/tools/pythonpkg/duckdb-stubs/__init__.pyi @@ -572,149 +572,88 @@ class token_type: # stubgen override - this gets removed by stubgen but it shouldn't def __members__(self) -> object: ... -def aggregate(df: pandas.DataFrame, aggr_expr: str, group_expr: str = ..., connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... 
-def alias(df: pandas.DataFrame, alias: str, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def connect(database: str = ..., read_only: bool = ..., config: dict = ...) -> DuckDBPyConnection: ... -def distinct(df: pandas.DataFrame, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def filter(df: pandas.DataFrame, filter_expr: str, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_substrait_json(jsonm: str, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def limit(df: pandas.DataFrame, n: int, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def order(df: pandas.DataFrame, order_expr: str, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def project(df: pandas.DataFrame, project_expr: str, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def write_csv(df: pandas.DataFrame, file_name: str, connection: DuckDBPyConnection = ...) -> None: ... -def read_json( - file_name: str, - columns: Optional[Dict[str,str]] = None, - sample_size: Optional[int] = None, - maximum_depth: Optional[int] = None, - format: Optional[str] = None, - records: Optional[str] = None, - connection: DuckDBPyConnection = ... -) -> DuckDBPyRelation: ... -def read_csv( - path_or_buffer: Union[str, StringIO, TextIOBase], - header: Optional[bool | int] = None, - compression: Optional[str] = None, - sep: Optional[str] = None, - delimiter: Optional[str] = None, - dtype: Optional[Dict[str, str] | List[str]] = None, - na_values: Optional[str] = None, - skiprows: Optional[int] = None, - quotechar: Optional[str] = None, - escapechar: Optional[str] = None, - encoding: Optional[str] = None, - parallel: Optional[bool] = None, - date_format: Optional[str] = None, - timestamp_format: Optional[str] = None, - sample_size: Optional[int] = None, - all_varchar: Optional[bool] = None, - normalize_names: Optional[bool] = None, - filename: Optional[bool] = None, - connection: DuckDBPyConnection = ... -) -> DuckDBPyRelation: ... -def from_csv_auto( - name: str, - header: Optional[bool | int] = None, - compression: Optional[str] = None, - sep: Optional[str] = None, - delimiter: Optional[str] = None, - dtype: Optional[Dict[str, str] | List[str]] = None, - na_values: Optional[str] = None, - skiprows: Optional[int] = None, - quotechar: Optional[str] = None, - escapechar: Optional[str] = None, - encoding: Optional[str] = None, - parallel: Optional[bool] = None, - date_format: Optional[str] = None, - timestamp_format: Optional[str] = None, - sample_size: Optional[int] = None, - all_varchar: Optional[bool] = None, - normalize_names: Optional[bool] = None, - filename: Optional[bool] = None, - null_padding: Optional[bool] = None, - connection: DuckDBPyConnection = ... -) -> DuckDBPyRelation: ... - -def append(table_name: str, df: pandas.DataFrame, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def arrow(rows_per_batch: int = ..., connection: DuckDBPyConnection = ...) -> pyarrow.lib.Table: ... -def begin(connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def close(connection: DuckDBPyConnection = ...) -> None: ... -def commit(connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def cursor(connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def df(connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... -def description(connection: DuckDBPyConnection = ...) -> Optional[List[Any]]: ... -def rowcount(connection: DuckDBPyConnection = ...) -> int: ... 
-def duplicate(connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def execute(query: str, parameters: object = ..., multiple_parameter_sets: bool = ..., connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def executemany(query: str, parameters: object = ..., connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def fetch_arrow_table(rows_per_batch: int = ..., connection: DuckDBPyConnection = ...) -> pyarrow.lib.Table: ... -def fetch_df(*args, connection: DuckDBPyConnection = ..., **kwargs) -> pandas.DataFrame: ... -def fetch_df_chunk(*args, connection: DuckDBPyConnection = ..., **kwargs) -> pandas.DataFrame: ... -def fetch_record_batch(rows_per_batch: int = ..., connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... -def fetchall(connection: DuckDBPyConnection = ...) -> List[Any]: ... -def fetchdf(*args, connection: DuckDBPyConnection = ..., **kwargs) -> pandas.DataFrame: ... -def fetchmany(size: int = ..., connection: DuckDBPyConnection = ...) -> List[Any]: ... -def fetchnumpy(connection: DuckDBPyConnection = ...) -> dict: ... -def fetchone(connection: DuckDBPyConnection = ...) -> Optional[tuple]: ... -def from_arrow(arrow_object: object, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_df(df: pandas.DataFrame = ..., connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -@overload -def read_parquet(file_glob: str, binary_as_string: bool = ..., *, file_row_number: bool = ..., filename: bool = ..., hive_partitioning: bool = ..., union_by_name: bool = ..., connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -@overload -def read_parquet(file_globs: List[str], binary_as_string: bool = ..., *, file_row_number: bool = ..., filename: bool = ..., hive_partitioning: bool = ..., union_by_name: bool = ..., connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -@overload -def from_parquet(file_glob: str, binary_as_string: bool = ..., *, file_row_number: bool = ..., filename: bool = ..., hive_partitioning: bool = ..., union_by_name: bool = ..., connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -@overload -def from_parquet(file_globs: List[str], binary_as_string: bool = ..., *, file_row_number: bool = ..., filename: bool = ..., hive_partitioning: bool = ..., union_by_name: bool = ..., connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_substrait(proto: bytes, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def get_substrait(query: str, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def get_substrait_json(query: str, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def get_table_names(query: str, connection: DuckDBPyConnection = ...) -> Set[str]: ... -def install_extension(*args, connection: DuckDBPyConnection = ..., **kwargs) -> None: ... -def interrupt(connection: DuckDBPyConnection = ...) -> None: ... -def list_filesystems(connection: DuckDBPyConnection = ...) -> List[Any]: ... -def filesystem_is_registered(name: str, connection: DuckDBPyConnection = ...) -> bool: ... -def load_extension(extension: str, connection: DuckDBPyConnection = ...) -> None: ... -def pl(rows_per_batch: int = ..., connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... -def torch(connection: DuckDBPyConnection = ...) -> dict: ... -def tf(self, connection: DuckDBPyConnection = ...) -> dict: ... -def register(view_name: str, python_object: object, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... 
-def remove_function(name: str, connection : DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def create_function( - name: str, - func: Callable, - parameters: Optional[List[DuckDBPyType]] = None, - return_type: Optional[DuckDBPyType] = None, - type: Optional[PythonUDFType] = PythonUDFType.NATIVE, - null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, - exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, - side_effects: Optional[bool] = False, - connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def register_filesystem(filesystem: fsspec.AbstractFileSystem, connection: DuckDBPyConnection = ...) -> None: ... -def rollback(connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... - -def query(query: str, connection: DuckDBPyConnection = ..., **kwargs) -> DuckDBPyRelation: ... -def sql(query: str, connection: DuckDBPyConnection = ..., **kwargs) -> DuckDBPyRelation: ... -def from_query(query: str, connection: DuckDBPyConnection = ..., **kwargs) -> DuckDBPyRelation: ... -def extract_statements(self, query: str, connection: DuckDBPyConnection = ...) -> List[Statement]: ... - -def table(table_name: str, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def table_function(name: str, parameters: object = ..., connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def unregister(view_name: str, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def query_df(df: pandas.DataFrame, virtual_table_name: str, sql_query: str, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def unregister_filesystem(name: str, connection: DuckDBPyConnection = ...) -> None: ... def tokenize(query: str) -> List[Any]: ... -def values(values: object, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def view(view_name: str, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def sqltype(type_str: str, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def dtype(type_str: str, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def type(type_str: str, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def struct_type(fields: Union[Dict[str, DuckDBPyType], List[str]], connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def row_type(fields: Union[Dict[str, DuckDBPyType], List[str]], connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def union_type(members: Union[Dict[str, DuckDBPyType], List[str]], connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def string_type(collation: str = "", connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def enum_type(name: str, type: DuckDBPyType, values: List[Any], connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def decimal_type(width: int, scale: int, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def array_type(type: DuckDBPyType, size: int, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def list_type(type: DuckDBPyType, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def map_type(key: DuckDBPyType, value: DuckDBPyType, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... + +# NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py. +# Do not edit this section manually, your changes will be overwritten! + +# START OF CONNECTION WRAPPER +def cursor(**kwargs) -> DuckDBPyConnection: ... +def register_filesystem(filesystem: str, **kwargs) -> None: ... 
+def unregister_filesystem(name: str, **kwargs) -> None: ... +def list_filesystems(**kwargs) -> list: ... +def filesystem_is_registered(name: str, **kwargs) -> bool: ... +def create_function(name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, **kwargs) -> DuckDBPyConnection: ... +def remove_function(name: str, **kwargs) -> DuckDBPyConnection: ... +def sqltype(type_str: str, **kwargs) -> DuckDBPyType: ... +def dtype(type_str: str, **kwargs) -> DuckDBPyType: ... +def type(type_str: str, **kwargs) -> DuckDBPyType: ... +def array_type(type: DuckDBPyType, size: int, **kwargs) -> DuckDBPyType: ... +def list_type(type: DuckDBPyType, **kwargs) -> DuckDBPyType: ... +def union_type(members: DuckDBPyType, **kwargs) -> DuckDBPyType: ... +def string_type(collation: str = "", **kwargs) -> DuckDBPyType: ... +def enum_type(name: str, type: DuckDBPyType, values: List[Any], **kwargs) -> DuckDBPyType: ... +def decimal_type(width: int, scale: int, **kwargs) -> DuckDBPyType: ... +def struct_type(fields: Union[Dict[str, DuckDBPyType], List[str]], **kwargs) -> DuckDBPyType: ... +def row_type(fields: Union[Dict[str, DuckDBPyType], List[str]], **kwargs) -> DuckDBPyType: ... +def map_type(key: DuckDBPyType, value: DuckDBPyType, **kwargs) -> DuckDBPyType: ... +def duplicate(**kwargs) -> DuckDBPyConnection: ... +def execute(query: object, parameters: object = None, multiple_parameter_sets: bool = False, **kwargs) -> DuckDBPyConnection: ... +def executemany(query: object, parameters: object = None, **kwargs) -> DuckDBPyConnection: ... +def close(**kwargs) -> None: ... +def interrupt(**kwargs) -> None: ... +def fetchone(**kwargs) -> Optional[tuple]: ... +def fetchmany(size: int = 1, **kwargs) -> List[Any]: ... +def fetchall(**kwargs) -> List[Any]: ... +def fetchnumpy(**kwargs) -> dict: ... +def fetchdf(**kwargs) -> pandas.DataFrame: ... +def fetch_df(**kwargs) -> pandas.DataFrame: ... +def df(**kwargs) -> pandas.DataFrame: ... +def fetch_df_chunk(vectors_per_chunk: int = 1, **kwargs) -> pandas.DataFrame: ... +def pl(rows_per_batch: int = 1000000, **kwargs) -> polars.DataFrame: ... +def fetch_arrow_table(rows_per_batch: int = 1000000, **kwargs) -> pyarrow.lib.Table: ... +def arrow(rows_per_batch: int = 1000000, **kwargs) -> pyarrow.lib.Table: ... +def fetch_record_batch(rows_per_batch: int = 1000000, **kwargs) -> pyarrow.lib.RecordBatchReader: ... +def torch(**kwargs) -> dict: ... +def tf(**kwargs) -> dict: ... +def begin(**kwargs) -> DuckDBPyConnection: ... +def commit(**kwargs) -> DuckDBPyConnection: ... +def rollback(**kwargs) -> DuckDBPyConnection: ... +def append(table_name: str, df: pandas.DataFrame, **kwargs) -> DuckDBPyConnection: ... +def register(view_name: str, python_object: object, **kwargs) -> DuckDBPyConnection: ... +def unregister(view_name: str, **kwargs) -> DuckDBPyConnection: ... +def table(table_name: str, **kwargs) -> DuckDBPyRelation: ... +def view(view_name: str, **kwargs) -> DuckDBPyRelation: ... +def values(values: List[Any], **kwargs) -> DuckDBPyRelation: ... +def table_function(name: str, parameters: object = None, **kwargs) -> DuckDBPyRelation: ... +def read_json(name: str, **kwargs) -> DuckDBPyRelation: ... +def extract_statements(query: str, **kwargs) -> List[Statement]: ... +def sql(query: str, **kwargs) -> DuckDBPyRelation: ... +def query(query: str, **kwargs) -> DuckDBPyRelation: ... +def from_query(query: str, **kwargs) -> DuckDBPyRelation: ... 
+def read_csv(path_or_buffer: Union[str, StringIO, TextIOBase], **kwargs) -> DuckDBPyRelation: ... +def from_csv_auto(path_or_buffer: Union[str, StringIO, TextIOBase], **kwargs) -> DuckDBPyRelation: ... +def from_df(df: pandas.DataFrame, **kwargs) -> DuckDBPyRelation: ... +def from_arrow(arrow_object: object, **kwargs) -> DuckDBPyRelation: ... +def from_parquet(file_glob: str, binary_as_string: bool = False, **kwargs) -> DuckDBPyRelation: ... +def read_parquet(file_glob: str, binary_as_string: bool = False, **kwargs) -> DuckDBPyRelation: ... +def from_substrait(proto: str, **kwargs) -> DuckDBPyRelation: ... +def get_substrait(query: str, **kwargs) -> str: ... +def get_substrait_json(query: str, **kwargs) -> str: ... +def from_substrait_json(json: str, **kwargs) -> DuckDBPyRelation: ... +def get_table_names(query: str, **kwargs) -> List[str]: ... +def install_extension(extension: str, **kwargs) -> None: ... +def load_extension(extension: str, **kwargs) -> None: ... +def project(df: pandas.DataFrame, project_expr: str, **kwargs) -> DuckDBPyRelation: ... +def distinct(df: pandas.DataFrame, **kwargs) -> DuckDBPyRelation: ... +def write_csv(df: pandas.DataFrame, *args: Any, **kwargs) -> None: ... +def aggregate(df: pandas.DataFrame, aggr_expr: str, group_expr: str = "", **kwargs) -> DuckDBPyRelation: ... +def alias(df: pandas.DataFrame, alias: str, **kwargs) -> DuckDBPyRelation: ... +def filter(df: pandas.DataFrame, filter_expr: str, **kwargs) -> DuckDBPyRelation: ... +def limit(df: pandas.DataFrame, n: int, offset: int = 0, **kwargs) -> DuckDBPyRelation: ... +def order(df: pandas.DataFrame, order_expr: str, **kwargs) -> DuckDBPyRelation: ... +def query_df(df: pandas.DataFrame, virtual_table_name: str, sql_query: str, **kwargs) -> DuckDBPyRelation: ... +def description(**kwargs) -> Optional[List[Any]]: ... +def rowcount(**kwargs) -> int: ... 
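# Illustrative usage of the wrappers stubbed above (an aside, not generator
# output): every signature accepts **kwargs, and at runtime an optional
# 'connection' keyword selects the connection, falling back to the shared
# default connection when omitted.
#
#     import duckdb, pandas
#     df = pandas.DataFrame({"a": [1, 2, 3]})
#     duckdb.limit(df, 2)                               # default connection
#     duckdb.limit(df, 2, connection=duckdb.connect())  # explicit connection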
+# END OF CONNECTION WRAPPER diff --git a/tools/pythonpkg/scripts/connection_wrapper_methods.json b/tools/pythonpkg/scripts/connection_wrapper_methods.json index 65ac4afd47f7..74f4cf31eba1 100644 --- a/tools/pythonpkg/scripts/connection_wrapper_methods.json +++ b/tools/pythonpkg/scripts/connection_wrapper_methods.json @@ -99,7 +99,7 @@ }, { "name": "description", - "return": "str" + "return": "Optional[List[Any]]" }, { "name": "rowcount", From 5ee33e0ca8cfa054ab20bf927999980eb7b85246 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Apr 2024 23:41:45 +0200 Subject: [PATCH 088/201] generation script --- .../generate_connection_wrapper_stubs.py | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py diff --git a/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py b/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py new file mode 100644 index 000000000000..f20b967318f3 --- /dev/null +++ b/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py @@ -0,0 +1,116 @@ +import os +import json + +os.chdir(os.path.dirname(__file__)) + +JSON_PATH = os.path.join("connection_methods.json") +WRAPPER_JSON_PATH = os.path.join("connection_wrapper_methods.json") +DUCKDB_STUBS_FILE = os.path.join("..", "duckdb-stubs", "__init__.pyi") + +START_MARKER = "# START OF CONNECTION WRAPPER" +END_MARKER = "# END OF CONNECTION WRAPPER" + +# Read the DUCKDB_STUBS_FILE file +with open(DUCKDB_STUBS_FILE, 'r') as source_file: + source_code = source_file.readlines() + +start_index = -1 +end_index = -1 +for i, line in enumerate(source_code): + if line.startswith(START_MARKER): + # TODO: handle the case where the start marker appears multiple times + start_index = i + elif line.startswith(END_MARKER): + # TODO: ditto ^ + end_index = i + +if start_index == -1 or end_index == -1: + raise ValueError("Couldn't find start or end marker in source file") + +start_section = source_code[: start_index + 1] +end_section = source_code[end_index:] +# ---- Generate the definition code from the json ---- + +methods = [] + +# Read the JSON file +with open(JSON_PATH, 'r') as json_file: + connection_methods = json.load(json_file) + +with open(WRAPPER_JSON_PATH, 'r') as json_file: + wrapper_methods = json.load(json_file) + +methods.extend(connection_methods) +methods.extend(wrapper_methods) + +# On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke +# that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) +READONLY_PROPERTY_NAMES = ['description', 'rowcount'] + +# These methods are not directly DuckDBPyConnection methods, +# they first call 'from_df' and then call a method on the created DuckDBPyRelation +SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] + + +def create_arguments(arguments) -> list: + result = [] + for arg in arguments: + argument = f"{arg['name']}: {arg['type']}" + # Add the default argument if present + if 'default' in arg: + default = arg['default'] + argument += f" = {default}" + result.append(argument) + return result + + +def create_definition(name, method) -> str: + print(method) + definition = f"def {name}(" + arguments = [] + if name in SPECIAL_METHOD_NAMES: + arguments.append('df: pandas.DataFrame') + if 'args' in method: + arguments.extend(create_arguments(method['args'])) + if 'kwargs' in method: + arguments.append("**kwargs") + definition += ', 
'.join(arguments) + definition += ")" + definition += f" -> {method['return']}: ..." + return definition + + +# We have "duplicate" methods, which are overloaded +# maybe we should add @overload to these instead, but this is easier +written_methods = set() + +body = [] +for method in methods: + if isinstance(method['name'], list): + names = method['name'] + else: + names = [method['name']] + + # Artificially add 'connection' keyword argument + if 'kwargs' not in method: + method['kwargs'] = [] + method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection'}) + + for name in names: + if name in written_methods: + continue + body.append(create_definition(name, method)) + written_methods.add(name) + +# ---- End of generation code ---- + +with_newlines = [x + '\n' for x in body] +# Recreate the file content by concatenating all the pieces together + +new_content = start_section + with_newlines + end_section + +print(''.join(with_newlines)) + +# Write out the modified DUCKDB_STUBS_FILE file +with open(DUCKDB_STUBS_FILE, 'w') as source_file: + source_file.write("".join(new_content)) From 1bbda5b844b04218339d64bb958d8182b84b7b30 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 11 Apr 2024 00:10:26 +0200 Subject: [PATCH 089/201] format --- tools/pythonpkg/src/pyconnection.cpp | 191 ++++++++++++++++++++------- 1 file changed, 142 insertions(+), 49 deletions(-) diff --git a/tools/pythonpkg/src/pyconnection.cpp b/tools/pythonpkg/src/pyconnection.cpp index 3e3eab8127c4..937cb6f3edb4 100644 --- a/tools/pythonpkg/src/pyconnection.cpp +++ b/tools/pythonpkg/src/pyconnection.cpp @@ -127,72 +127,165 @@ py::object ArrowTableFromDataframe(const py::object &df) { static void InitializeConnectionMethods(py::class_> &m) { m.def("cursor", &DuckDBPyConnection::Cursor, "Create a duplicate of the current connection"); - m.def("register_filesystem", &DuckDBPyConnection::RegisterFilesystem, "Register a fsspec compliant filesystem", py::arg("filesystem")); - m.def("unregister_filesystem", &DuckDBPyConnection::UnregisterFilesystem, "Unregister a filesystem", py::arg("name")); - m.def("list_filesystems", &DuckDBPyConnection::ListFilesystems, "List registered filesystems, including builtin ones"); - m.def("filesystem_is_registered", &DuckDBPyConnection::FileSystemIsRegistered, "Check if a filesystem with the provided name is currently registered", py::arg("name")); - m.def("create_function", &DuckDBPyConnection::RegisterScalarUDF, "Create a DuckDB function out of the passing in Python function so it can be used in queries", py::arg("name"), py::arg("function"), py::arg("parameters") = py::none(), py::arg("return_type") = py::none(), py::kw_only(), py::arg("type") = PythonUDFType::NATIVE, py::arg("null_handling") = FunctionNullHandling::DEFAULT_NULL_HANDLING, py::arg("exception_handling") = PythonExceptionHandling::FORWARD_ERROR, py::arg("side_effects") = false); - m.def("remove_function", &DuckDBPyConnection::UnregisterUDF, "Remove a previously created function", py::arg("name")); - m.def("sqltype", &DuckDBPyConnection::Type, "Create a type object by parsing the 'type_str' string", py::arg("type_str")); - m.def("dtype", &DuckDBPyConnection::Type, "Create a type object by parsing the 'type_str' string", py::arg("type_str")); - m.def("type", &DuckDBPyConnection::Type, "Create a type object by parsing the 'type_str' string", py::arg("type_str")); - m.def("array_type", &DuckDBPyConnection::ArrayType, "Create an array type object of 'type'", py::arg("type").none(false), py::arg("size")); - m.def("list_type", 
&DuckDBPyConnection::ListType, "Create a list type object of 'type'", py::arg("type").none(false)); - m.def("union_type", &DuckDBPyConnection::UnionType, "Create a union type object from 'members'", py::arg("members").none(false)); - m.def("string_type", &DuckDBPyConnection::StringType, "Create a string type with an optional collation", py::arg("collation") = ""); - m.def("enum_type", &DuckDBPyConnection::EnumType, "Create an enum type of underlying 'type', consisting of the list of 'values'", py::arg("name"), py::arg("type"), py::arg("values")); - m.def("decimal_type", &DuckDBPyConnection::DecimalType, "Create a decimal type with 'width' and 'scale'", py::arg("width"), py::arg("scale")); - m.def("struct_type", &DuckDBPyConnection::StructType, "Create a struct type object from 'fields'", py::arg("fields")); + m.def("register_filesystem", &DuckDBPyConnection::RegisterFilesystem, "Register a fsspec compliant filesystem", + py::arg("filesystem")); + m.def("unregister_filesystem", &DuckDBPyConnection::UnregisterFilesystem, "Unregister a filesystem", + py::arg("name")); + m.def("list_filesystems", &DuckDBPyConnection::ListFilesystems, + "List registered filesystems, including builtin ones"); + m.def("filesystem_is_registered", &DuckDBPyConnection::FileSystemIsRegistered, + "Check if a filesystem with the provided name is currently registered", py::arg("name")); + m.def("create_function", &DuckDBPyConnection::RegisterScalarUDF, + "Create a DuckDB function out of the passing in Python function so it can be used in queries", + py::arg("name"), py::arg("function"), py::arg("parameters") = py::none(), py::arg("return_type") = py::none(), + py::kw_only(), py::arg("type") = PythonUDFType::NATIVE, + py::arg("null_handling") = FunctionNullHandling::DEFAULT_NULL_HANDLING, + py::arg("exception_handling") = PythonExceptionHandling::FORWARD_ERROR, py::arg("side_effects") = false); + m.def("remove_function", &DuckDBPyConnection::UnregisterUDF, "Remove a previously created function", + py::arg("name")); + m.def("sqltype", &DuckDBPyConnection::Type, "Create a type object by parsing the 'type_str' string", + py::arg("type_str")); + m.def("dtype", &DuckDBPyConnection::Type, "Create a type object by parsing the 'type_str' string", + py::arg("type_str")); + m.def("type", &DuckDBPyConnection::Type, "Create a type object by parsing the 'type_str' string", + py::arg("type_str")); + m.def("array_type", &DuckDBPyConnection::ArrayType, "Create an array type object of 'type'", + py::arg("type").none(false), py::arg("size")); + m.def("list_type", &DuckDBPyConnection::ListType, "Create a list type object of 'type'", + py::arg("type").none(false)); + m.def("union_type", &DuckDBPyConnection::UnionType, "Create a union type object from 'members'", + py::arg("members").none(false)); + m.def("string_type", &DuckDBPyConnection::StringType, "Create a string type with an optional collation", + py::arg("collation") = ""); + m.def("enum_type", &DuckDBPyConnection::EnumType, + "Create an enum type of underlying 'type', consisting of the list of 'values'", py::arg("name"), + py::arg("type"), py::arg("values")); + m.def("decimal_type", &DuckDBPyConnection::DecimalType, "Create a decimal type with 'width' and 'scale'", + py::arg("width"), py::arg("scale")); + m.def("struct_type", &DuckDBPyConnection::StructType, "Create a struct type object from 'fields'", + py::arg("fields")); m.def("row_type", &DuckDBPyConnection::StructType, "Create a struct type object from 'fields'", py::arg("fields")); - m.def("map_type", 
&DuckDBPyConnection::MapType, "Create a map type object from 'key_type' and 'value_type'", py::arg("key").none(false), py::arg("value").none(false)); + m.def("map_type", &DuckDBPyConnection::MapType, "Create a map type object from 'key_type' and 'value_type'", + py::arg("key").none(false), py::arg("value").none(false)); m.def("duplicate", &DuckDBPyConnection::Cursor, "Create a duplicate of the current connection"); - m.def("execute", &DuckDBPyConnection::Execute, "Execute the given SQL query, optionally using prepared statements with parameters set", py::arg("query"), py::arg("parameters") = py::none(), py::arg("multiple_parameter_sets") = false); - m.def("executemany", &DuckDBPyConnection::ExecuteMany, "Execute the given prepared statement multiple times using the list of parameter sets in parameters", py::arg("query"), py::arg("parameters") = py::none()); + m.def("execute", &DuckDBPyConnection::Execute, + "Execute the given SQL query, optionally using prepared statements with parameters set", py::arg("query"), + py::arg("parameters") = py::none(), py::arg("multiple_parameter_sets") = false); + m.def("executemany", &DuckDBPyConnection::ExecuteMany, + "Execute the given prepared statement multiple times using the list of parameter sets in parameters", + py::arg("query"), py::arg("parameters") = py::none()); m.def("close", &DuckDBPyConnection::Close, "Close the connection"); m.def("interrupt", &DuckDBPyConnection::Interrupt, "Interrupt pending operations"); m.def("fetchone", &DuckDBPyConnection::FetchOne, "Fetch a single row from a result following execute"); - m.def("fetchmany", &DuckDBPyConnection::FetchMany, "Fetch the next set of rows from a result following execute", py::arg("size") = 1); + m.def("fetchmany", &DuckDBPyConnection::FetchMany, "Fetch the next set of rows from a result following execute", + py::arg("size") = 1); m.def("fetchall", &DuckDBPyConnection::FetchAll, "Fetch all rows from a result following execute"); m.def("fetchnumpy", &DuckDBPyConnection::FetchNumpy, "Fetch a result as list of NumPy arrays following execute"); - m.def("fetchdf", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), py::arg("date_as_object") = false); - m.def("fetch_df", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), py::arg("date_as_object") = false); - m.def("df", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), py::arg("date_as_object") = false); - m.def("fetch_df_chunk", &DuckDBPyConnection::FetchDFChunk, "Fetch a chunk of the result as DataFrame following execute()", py::arg("vectors_per_chunk") = 1, py::kw_only(), py::arg("date_as_object") = false); - m.def("pl", &DuckDBPyConnection::FetchPolars, "Fetch a result as Polars DataFrame following execute()", py::arg("rows_per_batch") = 1000000); - m.def("fetch_arrow_table", &DuckDBPyConnection::FetchArrow, "Fetch a result as Arrow table following execute()", py::arg("rows_per_batch") = 1000000); - m.def("arrow", &DuckDBPyConnection::FetchArrow, "Fetch a result as Arrow table following execute()", py::arg("rows_per_batch") = 1000000); - m.def("fetch_record_batch", &DuckDBPyConnection::FetchRecordBatchReader, "Fetch an Arrow RecordBatchReader following execute()", py::arg("rows_per_batch") = 1000000); + m.def("fetchdf", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), + py::arg("date_as_object") = false); + m.def("fetch_df", &DuckDBPyConnection::FetchDF, "Fetch a 
result as DataFrame following execute()", py::kw_only(), + py::arg("date_as_object") = false); + m.def("df", &DuckDBPyConnection::FetchDF, "Fetch a result as DataFrame following execute()", py::kw_only(), + py::arg("date_as_object") = false); + m.def("fetch_df_chunk", &DuckDBPyConnection::FetchDFChunk, + "Fetch a chunk of the result as DataFrame following execute()", py::arg("vectors_per_chunk") = 1, + py::kw_only(), py::arg("date_as_object") = false); + m.def("pl", &DuckDBPyConnection::FetchPolars, "Fetch a result as Polars DataFrame following execute()", + py::arg("rows_per_batch") = 1000000); + m.def("fetch_arrow_table", &DuckDBPyConnection::FetchArrow, "Fetch a result as Arrow table following execute()", + py::arg("rows_per_batch") = 1000000); + m.def("arrow", &DuckDBPyConnection::FetchArrow, "Fetch a result as Arrow table following execute()", + py::arg("rows_per_batch") = 1000000); + m.def("fetch_record_batch", &DuckDBPyConnection::FetchRecordBatchReader, + "Fetch an Arrow RecordBatchReader following execute()", py::arg("rows_per_batch") = 1000000); m.def("torch", &DuckDBPyConnection::FetchPyTorch, "Fetch a result as dict of PyTorch Tensors following execute()"); m.def("tf", &DuckDBPyConnection::FetchTF, "Fetch a result as dict of TensorFlow Tensors following execute()"); m.def("begin", &DuckDBPyConnection::Begin, "Start a new transaction"); m.def("commit", &DuckDBPyConnection::Commit, "Commit changes performed within a transaction"); m.def("rollback", &DuckDBPyConnection::Rollback, "Roll back changes performed within a transaction"); - m.def("append", &DuckDBPyConnection::Append, "Append the passed DataFrame to the named table", py::arg("table_name"), py::arg("df"), py::kw_only(), py::arg("by_name") = false); - m.def("register", &DuckDBPyConnection::RegisterPythonObject, "Register the passed Python Object value for querying with a view", py::arg("view_name"), py::arg("python_object")); + m.def("append", &DuckDBPyConnection::Append, "Append the passed DataFrame to the named table", + py::arg("table_name"), py::arg("df"), py::kw_only(), py::arg("by_name") = false); + m.def("register", &DuckDBPyConnection::RegisterPythonObject, + "Register the passed Python Object value for querying with a view", py::arg("view_name"), + py::arg("python_object")); m.def("unregister", &DuckDBPyConnection::UnregisterPythonObject, "Unregister the view name", py::arg("view_name")); m.def("table", &DuckDBPyConnection::Table, "Create a relation object for the named table", py::arg("table_name")); m.def("view", &DuckDBPyConnection::View, "Create a relation object for the named view", py::arg("view_name")); m.def("values", &DuckDBPyConnection::Values, "Create a relation object from the passed values", py::arg("values")); - m.def("table_function", &DuckDBPyConnection::TableFunction, "Create a relation object from the named table function with given parameters", py::arg("name"), py::arg("parameters") = py::none()); - m.def("read_json", &DuckDBPyConnection::ReadJSON, "Create a relation object from the JSON file in 'name'", py::arg("name"), py::kw_only(), py::arg("columns") = py::none(), py::arg("sample_size") = py::none(), py::arg("maximum_depth") = py::none(), py::arg("records") = py::none(), py::arg("format") = py::none()); - m.def("extract_statements", &DuckDBPyConnection::ExtractStatements, "Parse the query string and extract the Statement object(s) produced", py::arg("query")); - m.def("sql", &DuckDBPyConnection::RunQuery, "Run a SQL query. 
If it is a SELECT statement, create a relation object from the given SQL query, otherwise run the query as-is.", py::arg("query"), py::kw_only(), py::arg("alias") = "", py::arg("params") = py::none()); - m.def("query", &DuckDBPyConnection::RunQuery, "Run a SQL query. If it is a SELECT statement, create a relation object from the given SQL query, otherwise run the query as-is.", py::arg("query"), py::kw_only(), py::arg("alias") = "", py::arg("params") = py::none()); - m.def("from_query", &DuckDBPyConnection::RunQuery, "Run a SQL query. If it is a SELECT statement, create a relation object from the given SQL query, otherwise run the query as-is.", py::arg("query"), py::kw_only(), py::arg("alias") = "", py::arg("params") = py::none()); - m.def("read_csv", &DuckDBPyConnection::ReadCSV, "Create a relation object from the CSV file in 'name'", py::arg("path_or_buffer"), py::kw_only(), py::arg("header") = py::none(), py::arg("compression") = py::none(), py::arg("sep") = py::none(), py::arg("delimiter") = py::none(), py::arg("dtype") = py::none(), py::arg("na_values") = py::none(), py::arg("skiprows") = py::none(), py::arg("quotechar") = py::none(), py::arg("escapechar") = py::none(), py::arg("encoding") = py::none(), py::arg("parallel") = py::none(), py::arg("date_format") = py::none(), py::arg("timestamp_format") = py::none(), py::arg("sample_size") = py::none(), py::arg("all_varchar") = py::none(), py::arg("normalize_names") = py::none(), py::arg("filename") = py::none(), py::arg("null_padding") = py::none(), py::arg("names") = py::none()); - m.def("from_csv_auto", &DuckDBPyConnection::ReadCSV, "Create a relation object from the CSV file in 'name'", py::arg("path_or_buffer"), py::kw_only(), py::arg("header") = py::none(), py::arg("compression") = py::none(), py::arg("sep") = py::none(), py::arg("delimiter") = py::none(), py::arg("dtype") = py::none(), py::arg("na_values") = py::none(), py::arg("skiprows") = py::none(), py::arg("quotechar") = py::none(), py::arg("escapechar") = py::none(), py::arg("encoding") = py::none(), py::arg("parallel") = py::none(), py::arg("date_format") = py::none(), py::arg("timestamp_format") = py::none(), py::arg("sample_size") = py::none(), py::arg("all_varchar") = py::none(), py::arg("normalize_names") = py::none(), py::arg("filename") = py::none(), py::arg("null_padding") = py::none(), py::arg("names") = py::none()); + m.def("table_function", &DuckDBPyConnection::TableFunction, + "Create a relation object from the named table function with given parameters", py::arg("name"), + py::arg("parameters") = py::none()); + m.def("read_json", &DuckDBPyConnection::ReadJSON, "Create a relation object from the JSON file in 'name'", + py::arg("name"), py::kw_only(), py::arg("columns") = py::none(), py::arg("sample_size") = py::none(), + py::arg("maximum_depth") = py::none(), py::arg("records") = py::none(), py::arg("format") = py::none()); + m.def("extract_statements", &DuckDBPyConnection::ExtractStatements, + "Parse the query string and extract the Statement object(s) produced", py::arg("query")); + m.def("sql", &DuckDBPyConnection::RunQuery, + "Run a SQL query. If it is a SELECT statement, create a relation object from the given SQL query, otherwise " + "run the query as-is.", + py::arg("query"), py::kw_only(), py::arg("alias") = "", py::arg("params") = py::none()); + m.def("query", &DuckDBPyConnection::RunQuery, + "Run a SQL query. 
If it is a SELECT statement, create a relation object from the given SQL query, otherwise " + "run the query as-is.", + py::arg("query"), py::kw_only(), py::arg("alias") = "", py::arg("params") = py::none()); + m.def("from_query", &DuckDBPyConnection::RunQuery, + "Run a SQL query. If it is a SELECT statement, create a relation object from the given SQL query, otherwise " + "run the query as-is.", + py::arg("query"), py::kw_only(), py::arg("alias") = "", py::arg("params") = py::none()); + m.def("read_csv", &DuckDBPyConnection::ReadCSV, "Create a relation object from the CSV file in 'name'", + py::arg("path_or_buffer"), py::kw_only(), py::arg("header") = py::none(), py::arg("compression") = py::none(), + py::arg("sep") = py::none(), py::arg("delimiter") = py::none(), py::arg("dtype") = py::none(), + py::arg("na_values") = py::none(), py::arg("skiprows") = py::none(), py::arg("quotechar") = py::none(), + py::arg("escapechar") = py::none(), py::arg("encoding") = py::none(), py::arg("parallel") = py::none(), + py::arg("date_format") = py::none(), py::arg("timestamp_format") = py::none(), + py::arg("sample_size") = py::none(), py::arg("all_varchar") = py::none(), + py::arg("normalize_names") = py::none(), py::arg("filename") = py::none(), + py::arg("null_padding") = py::none(), py::arg("names") = py::none()); + m.def("from_csv_auto", &DuckDBPyConnection::ReadCSV, "Create a relation object from the CSV file in 'name'", + py::arg("path_or_buffer"), py::kw_only(), py::arg("header") = py::none(), py::arg("compression") = py::none(), + py::arg("sep") = py::none(), py::arg("delimiter") = py::none(), py::arg("dtype") = py::none(), + py::arg("na_values") = py::none(), py::arg("skiprows") = py::none(), py::arg("quotechar") = py::none(), + py::arg("escapechar") = py::none(), py::arg("encoding") = py::none(), py::arg("parallel") = py::none(), + py::arg("date_format") = py::none(), py::arg("timestamp_format") = py::none(), + py::arg("sample_size") = py::none(), py::arg("all_varchar") = py::none(), + py::arg("normalize_names") = py::none(), py::arg("filename") = py::none(), + py::arg("null_padding") = py::none(), py::arg("names") = py::none()); m.def("from_df", &DuckDBPyConnection::FromDF, "Create a relation object from the DataFrame in df", py::arg("df")); - m.def("from_arrow", &DuckDBPyConnection::FromArrow, "Create a relation object from an Arrow object", py::arg("arrow_object")); - m.def("from_parquet", &DuckDBPyConnection::FromParquet, "Create a relation object from the Parquet files in file_glob", py::arg("file_glob"), py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, py::arg("compression") = py::none()); - m.def("read_parquet", &DuckDBPyConnection::FromParquet, "Create a relation object from the Parquet files in file_glob", py::arg("file_glob"), py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, py::arg("compression") = py::none()); - m.def("from_parquet", &DuckDBPyConnection::FromParquets, "Create a relation object from the Parquet files in file_globs", py::arg("file_globs"), py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, py::arg("compression") = py::none()); - m.def("read_parquet", 
&DuckDBPyConnection::FromParquets, "Create a relation object from the Parquet files in file_globs", py::arg("file_globs"), py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, py::arg("compression") = py::none()); - m.def("from_substrait", &DuckDBPyConnection::FromSubstrait, "Create a query object from protobuf plan", py::arg("proto")); - m.def("get_substrait", &DuckDBPyConnection::GetSubstrait, "Serialize a query to protobuf", py::arg("query"), py::kw_only(), py::arg("enable_optimizer") = true); - m.def("get_substrait_json", &DuckDBPyConnection::GetSubstraitJSON, "Serialize a query to protobuf on the JSON format", py::arg("query"), py::kw_only(), py::arg("enable_optimizer") = true); - m.def("from_substrait_json", &DuckDBPyConnection::FromSubstraitJSON, "Create a query object from a JSON protobuf plan", py::arg("json")); - m.def("get_table_names", &DuckDBPyConnection::GetTableNames, "Extract the required table names from a query", py::arg("query")); - m.def("install_extension", &DuckDBPyConnection::InstallExtension, "Install an extension by name", py::arg("extension"), py::kw_only(), py::arg("force_install") = false); + m.def("from_arrow", &DuckDBPyConnection::FromArrow, "Create a relation object from an Arrow object", + py::arg("arrow_object")); + m.def("from_parquet", &DuckDBPyConnection::FromParquet, + "Create a relation object from the Parquet files in file_glob", py::arg("file_glob"), + py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, + py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, + py::arg("compression") = py::none()); + m.def("read_parquet", &DuckDBPyConnection::FromParquet, + "Create a relation object from the Parquet files in file_glob", py::arg("file_glob"), + py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, + py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, + py::arg("compression") = py::none()); + m.def("from_parquet", &DuckDBPyConnection::FromParquets, + "Create a relation object from the Parquet files in file_globs", py::arg("file_globs"), + py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, + py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, + py::arg("compression") = py::none()); + m.def("read_parquet", &DuckDBPyConnection::FromParquets, + "Create a relation object from the Parquet files in file_globs", py::arg("file_globs"), + py::arg("binary_as_string") = false, py::kw_only(), py::arg("file_row_number") = false, + py::arg("filename") = false, py::arg("hive_partitioning") = false, py::arg("union_by_name") = false, + py::arg("compression") = py::none()); + m.def("from_substrait", &DuckDBPyConnection::FromSubstrait, "Create a query object from protobuf plan", + py::arg("proto")); + m.def("get_substrait", &DuckDBPyConnection::GetSubstrait, "Serialize a query to protobuf", py::arg("query"), + py::kw_only(), py::arg("enable_optimizer") = true); + m.def("get_substrait_json", &DuckDBPyConnection::GetSubstraitJSON, + "Serialize a query to protobuf on the JSON format", py::arg("query"), py::kw_only(), + py::arg("enable_optimizer") = true); + m.def("from_substrait_json", &DuckDBPyConnection::FromSubstraitJSON, + "Create a query object from a JSON protobuf plan", py::arg("json")); 
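	// Aside (illustrative, not part of this commit): "from_parquet" and
	// "read_parquet" are each bound twice above, once for a single glob (str)
	// and once for a list of globs, and pybind11 merges bindings that share a
	// name into one overload set, tried in registration order. The same
	// pattern in isolation:
	//
	//   m.def("read_parquet", &DuckDBPyConnection::FromParquet, py::arg("file_glob"));
	//   m.def("read_parquet", &DuckDBPyConnection::FromParquets, py::arg("file_globs"));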
+ m.def("get_table_names", &DuckDBPyConnection::GetTableNames, "Extract the required table names from a query", + py::arg("query")); + m.def("install_extension", &DuckDBPyConnection::InstallExtension, "Install an extension by name", + py::arg("extension"), py::kw_only(), py::arg("force_install") = false); m.def("load_extension", &DuckDBPyConnection::LoadExtension, "Load an installed extension", py::arg("extension")); } // END_OF_CONNECTION_METHODS From 6137227e900889c16261bd8b630de2dd026c86af Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 11 Apr 2024 10:01:05 +0200 Subject: [PATCH 090/201] the one script to rule them all --- .../scripts/generate_connection_code.py | 10 + .../scripts/generate_connection_methods.py | 189 ++++++------ .../scripts/generate_connection_stubs.py | 167 +++++------ .../generate_connection_wrapper_methods.py | 278 +++++++++--------- .../generate_connection_wrapper_stubs.py | 212 ++++++------- 5 files changed, 440 insertions(+), 416 deletions(-) create mode 100644 tools/pythonpkg/scripts/generate_connection_code.py diff --git a/tools/pythonpkg/scripts/generate_connection_code.py b/tools/pythonpkg/scripts/generate_connection_code.py new file mode 100644 index 000000000000..3737f83ad319 --- /dev/null +++ b/tools/pythonpkg/scripts/generate_connection_code.py @@ -0,0 +1,10 @@ +import generate_connection_methods +import generate_connection_stubs +import generate_connection_wrapper_methods +import generate_connection_wrapper_stubs + +if __name__ == '__main__': + generate_connection_methods.generate() + generate_connection_stubs.generate() + generate_connection_wrapper_methods.generate() + generate_connection_wrapper_stubs.generate() diff --git a/tools/pythonpkg/scripts/generate_connection_methods.py b/tools/pythonpkg/scripts/generate_connection_methods.py index efcf13304f0d..10aacd1418e3 100644 --- a/tools/pythonpkg/scripts/generate_connection_methods.py +++ b/tools/pythonpkg/scripts/generate_connection_methods.py @@ -11,97 +11,102 @@ ) END_MARKER = "} // END_OF_CONNECTION_METHODS" -# Read the PYCONNECTION_SOURCE file -with open(PYCONNECTION_SOURCE, 'r') as source_file: - source_code = source_file.readlines() - -# Locate the InitializeConnectionMethods function in it -start_index = -1 -end_index = -1 -for i, line in enumerate(source_code): - if line.startswith(INITIALIZE_METHOD): - start_index = i - elif line.startswith(END_MARKER): - end_index = i - -if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") - -start_section = source_code[: start_index + 1] -end_section = source_code[end_index:] -# ---- Generate the definition code from the json ---- - -# Read the JSON file -with open(JSON_PATH, 'r') as json_file: - connection_methods = json.load(json_file) - -body = [] - -DEFAULT_ARGUMENT_MAP = { - 'True': 'true', - 'False': 'false', - 'None': 'py::none()', - 'PythonUDFType.NATIVE': 'PythonUDFType::NATIVE', - 'PythonExceptionHandling.DEFAULT': 'PythonExceptionHandling::FORWARD_ERROR', - 'FunctionNullHandling.DEFAULT': 'FunctionNullHandling::DEFAULT_NULL_HANDLING', -} - - -def map_default(val): - if val in DEFAULT_ARGUMENT_MAP: - return DEFAULT_ARGUMENT_MAP[val] - return val - - -def create_arguments(arguments) -> list: - result = [] - for arg in arguments: - argument = f"py::arg(\"{arg['name']}\")" - if 'allow_none' in arg: - value = str(arg['allow_none']).lower() - argument += f".none({value})" - # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) - argument += f" = 
{default}" - result.append(argument) - return result - - -def create_definition(name, method) -> str: - definition = f"m.def(\"{name}\"" - definition += ", " - definition += f"""&DuckDBPyConnection::{method['function']}""" - definition += ", " - definition += f"\"{method['docs']}\"" - if 'args' in method: + +def generate(): + # Read the PYCONNECTION_SOURCE file + with open(PYCONNECTION_SOURCE, 'r') as source_file: + source_code = source_file.readlines() + + start_index = -1 + end_index = -1 + for i, line in enumerate(source_code): + if line.startswith(INITIALIZE_METHOD): + if start_index != -1: + raise ValueError("Encountered the INITIALIZE_METHOD a second time, quitting!") + start_index = i + elif line.startswith(END_MARKER): + if end_index != -1: + raise ValueError("Encountered the END_MARKER a second time, quitting!") + end_index = i + + if start_index == -1 or end_index == -1: + raise ValueError("Couldn't find start or end marker in source file") + + start_section = source_code[: start_index + 1] + end_section = source_code[end_index:] + # ---- Generate the definition code from the json ---- + + # Read the JSON file + with open(JSON_PATH, 'r') as json_file: + connection_methods = json.load(json_file) + + DEFAULT_ARGUMENT_MAP = { + 'True': 'true', + 'False': 'false', + 'None': 'py::none()', + 'PythonUDFType.NATIVE': 'PythonUDFType::NATIVE', + 'PythonExceptionHandling.DEFAULT': 'PythonExceptionHandling::FORWARD_ERROR', + 'FunctionNullHandling.DEFAULT': 'FunctionNullHandling::DEFAULT_NULL_HANDLING', + } + + def map_default(val): + if val in DEFAULT_ARGUMENT_MAP: + return DEFAULT_ARGUMENT_MAP[val] + return val + + def create_arguments(arguments) -> list: + result = [] + for arg in arguments: + argument = f"py::arg(\"{arg['name']}\")" + if 'allow_none' in arg: + value = str(arg['allow_none']).lower() + argument += f".none({value})" + # Add the default argument if present + if 'default' in arg: + default = map_default(arg['default']) + argument += f" = {default}" + result.append(argument) + return result + + def create_definition(name, method) -> str: + definition = f"m.def(\"{name}\"" definition += ", " - arguments = create_arguments(method['args']) - definition += ', '.join(arguments) - if 'kwargs' in method: + definition += f"""&DuckDBPyConnection::{method['function']}""" definition += ", " - definition += "py::kw_only(), " - arguments = create_arguments(method['kwargs']) - definition += ', '.join(arguments) - definition += ");" - return definition - - -for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] - else: - names = [method['name']] - for name in names: - body.append(create_definition(name, method)) - -# ---- End of generation code ---- - -with_newlines = ['\t' + x + '\n' for x in body] -# Recreate the file content by concatenating all the pieces together - -new_content = start_section + with_newlines + end_section - -# Write out the modified PYCONNECTION_SOURCE file -with open(PYCONNECTION_SOURCE, 'w') as source_file: - source_file.write("".join(new_content)) + definition += f"\"{method['docs']}\"" + if 'args' in method: + definition += ", " + arguments = create_arguments(method['args']) + definition += ', '.join(arguments) + if 'kwargs' in method: + definition += ", " + definition += "py::kw_only(), " + arguments = create_arguments(method['kwargs']) + definition += ', '.join(arguments) + definition += ");" + return definition + + body = [] + for method in connection_methods: + if isinstance(method['name'], list): + names = 
method['name'] + else: + names = [method['name']] + for name in names: + body.append(create_definition(name, method)) + + # ---- End of generation code ---- + + with_newlines = ['\t' + x + '\n' for x in body] + # Recreate the file content by concatenating all the pieces together + + new_content = start_section + with_newlines + end_section + + # Write out the modified PYCONNECTION_SOURCE file + with open(PYCONNECTION_SOURCE, 'w') as source_file: + source_file.write("".join(new_content)) + + +if __name__ == '__main__': + raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") + # generate() diff --git a/tools/pythonpkg/scripts/generate_connection_stubs.py b/tools/pythonpkg/scripts/generate_connection_stubs.py index cacddd3b0ced..fc8c5e9d456a 100644 --- a/tools/pythonpkg/scripts/generate_connection_stubs.py +++ b/tools/pythonpkg/scripts/generate_connection_stubs.py @@ -9,85 +9,88 @@ START_MARKER = " # START OF CONNECTION METHODS" END_MARKER = " # END OF CONNECTION METHODS" -# Read the DUCKDB_STUBS_FILE file -with open(DUCKDB_STUBS_FILE, 'r') as source_file: - source_code = source_file.readlines() - -# Locate the InitializeConnectionMethods function in it -start_index = -1 -end_index = -1 -for i, line in enumerate(source_code): - if line.startswith(START_MARKER): - # TODO: handle the case where the start marker appears multiple times - start_index = i - elif line.startswith(END_MARKER): - # TODO: ditto ^ - end_index = i - -if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") - -start_section = source_code[: start_index + 1] -end_section = source_code[end_index:] -# ---- Generate the definition code from the json ---- - -# Read the JSON file -with open(JSON_PATH, 'r') as json_file: - connection_methods = json.load(json_file) - -body = [] - - -def create_arguments(arguments) -> list: - result = [] - for arg in arguments: - argument = f"{arg['name']}: {arg['type']}" - # Add the default argument if present - if 'default' in arg: - default = arg['default'] - argument += f" = {default}" - result.append(argument) - return result - - -def create_definition(name, method) -> str: - print(method) - definition = f"def {name}(self" - if 'args' in method: - definition += ", " - arguments = create_arguments(method['args']) - definition += ', '.join(arguments) - if 'kwargs' in method: - definition += ", **kwargs" - definition += ")" - definition += f" -> {method['return']}: ..." 
- return definition - - -# We have "duplicate" methods, which are overloaded -# maybe we should add @overload to these instead, but this is easier -written_methods = set() - -for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] - else: - names = [method['name']] - for name in names: - if name in written_methods: - continue - body.append(create_definition(name, method)) - written_methods.add(name) - -# ---- End of generation code ---- - -with_newlines = [' ' + x + '\n' for x in body] -# Recreate the file content by concatenating all the pieces together - -new_content = start_section + with_newlines + end_section - -print(with_newlines) - -# Write out the modified DUCKDB_STUBS_FILE file -with open(DUCKDB_STUBS_FILE, 'w') as source_file: - source_file.write("".join(new_content)) + +def generate(): + # Read the DUCKDB_STUBS_FILE file + with open(DUCKDB_STUBS_FILE, 'r') as source_file: + source_code = source_file.readlines() + + start_index = -1 + end_index = -1 + for i, line in enumerate(source_code): + if line.startswith(START_MARKER): + if start_index != -1: + raise ValueError("Encountered the START_MARKER a second time, quitting!") + start_index = i + elif line.startswith(END_MARKER): + if end_index != -1: + raise ValueError("Encountered the END_MARKER a second time, quitting!") + end_index = i + + if start_index == -1 or end_index == -1: + raise ValueError("Couldn't find start or end marker in source file") + + start_section = source_code[: start_index + 1] + end_section = source_code[end_index:] + # ---- Generate the definition code from the json ---- + + # Read the JSON file + with open(JSON_PATH, 'r') as json_file: + connection_methods = json.load(json_file) + + body = [] + + def create_arguments(arguments) -> list: + result = [] + for arg in arguments: + argument = f"{arg['name']}: {arg['type']}" + # Add the default argument if present + if 'default' in arg: + default = arg['default'] + argument += f" = {default}" + result.append(argument) + return result + + def create_definition(name, method) -> str: + print(method) + definition = f"def {name}(self" + if 'args' in method: + definition += ", " + arguments = create_arguments(method['args']) + definition += ', '.join(arguments) + if 'kwargs' in method: + definition += ", **kwargs" + definition += ")" + definition += f" -> {method['return']}: ..." 
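Besides wrapping everything in generate(), these rewrites replace the old "TODO: handle the case where the start marker appears multiple times" with a hard error. Distilled into a standalone sketch (function and variable names here are illustrative, not part of the scripts), the marker-splice pattern all four generators now share is:

    def splice(lines, start_marker, end_marker, generated):
        # Locate each marker exactly once; a duplicate marker is an error.
        start_index = end_index = -1
        for i, line in enumerate(lines):
            if line.startswith(start_marker):
                if start_index != -1:
                    raise ValueError("Encountered the start marker a second time")
                start_index = i
            elif line.startswith(end_marker):
                if end_index != -1:
                    raise ValueError("Encountered the end marker a second time")
                end_index = i
        if start_index == -1 or end_index == -1:
            raise ValueError("Couldn't find start or end marker")
        # Keep everything outside the markers; regenerate only the middle.
        return lines[: start_index + 1] + generated + lines[end_index:]
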
+ return definition + + # We have "duplicate" methods, which are overloaded + # maybe we should add @overload to these instead, but this is easier + written_methods = set() + + for method in connection_methods: + if isinstance(method['name'], list): + names = method['name'] + else: + names = [method['name']] + for name in names: + if name in written_methods: + continue + body.append(create_definition(name, method)) + written_methods.add(name) + + # ---- End of generation code ---- + + with_newlines = [' ' + x + '\n' for x in body] + # Recreate the file content by concatenating all the pieces together + + new_content = start_section + with_newlines + end_section + + # Write out the modified DUCKDB_STUBS_FILE file + with open(DUCKDB_STUBS_FILE, 'w') as source_file: + source_file.write("".join(new_content)) + + +if __name__ == '__main__': + raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") + # generate() diff --git a/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py b/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py index e716765f6816..9d1d1a0fefdc 100644 --- a/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py +++ b/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py @@ -10,141 +10,143 @@ START_MARKER = "# START OF CONNECTION WRAPPER" END_MARKER = "# END OF CONNECTION WRAPPER" -# Read the DUCKDB_INIT_FILE file -with open(DUCKDB_INIT_FILE, 'r') as source_file: - source_code = source_file.readlines() - -start_index = -1 -end_index = -1 -for i, line in enumerate(source_code): - if line.startswith(START_MARKER): - # TODO: handle the case where the start marker appears multiple times - start_index = i - elif line.startswith(END_MARKER): - # TODO: ditto ^ - end_index = i - -if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") - -start_section = source_code[: start_index + 1] -end_section = source_code[end_index:] -# ---- Generate the definition code from the json ---- - -methods = [] - -# Read the JSON file -with open(JSON_PATH, 'r') as json_file: - connection_methods = json.load(json_file) - -with open(WRAPPER_JSON_PATH, 'r') as json_file: - wrapper_methods = json.load(json_file) - -methods.extend(connection_methods) -methods.extend(wrapper_methods) - -# On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke -# that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) -READONLY_PROPERTY_NAMES = ['description', 'rowcount'] - -# These methods are not directly DuckDBPyConnection methods, -# they first call 'from_df' and then call a method on the created DuckDBPyRelation -SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] - - -def generate_arguments(name, method) -> str: - arguments = [] - if name in SPECIAL_METHOD_NAMES: - # We add 'df' to these methods because they operate on a DataFrame - arguments.append('df') - - if 'args' in method: - for arg in method['args']: - res = arg['name'] - if 'default' in arg: - res += f" = {arg['default']}" - arguments.append(res) - arguments.append('**kwargs') - return ', '.join(arguments) - - -def generate_parameters(name, method) -> str: - if name in READONLY_PROPERTY_NAMES: - return '' - arguments = [] - if 'args' in method: - for arg in method['args']: - arguments.append(f"{arg['name']}") - arguments.append('**kwargs') - result = ', 
'.join(arguments) - return '(' + result + ')' - - -def generate_function_call(name) -> str: - function_call = '' - if name in SPECIAL_METHOD_NAMES: - function_call += 'from_df(df).' - - REMAPPED_FUNCTIONS = {'alias': 'set_alias', 'query_df': 'query'} - if name in REMAPPED_FUNCTIONS: - function_name = REMAPPED_FUNCTIONS[name] - else: - function_name = name - function_call += function_name - return function_call - - -def create_definition(name, method) -> str: - print(method) - arguments = generate_arguments(name, method) - parameters = generate_parameters(name, method) - function_call = generate_function_call(name) - - func = f""" -def {name}({arguments}): - if 'connection' in kwargs: - conn = kwargs.pop('connection') - else: - conn = duckdb.connect(":default:") - return conn.{function_call}{parameters} -_exported_symbols.append('{name}') -""" - return func - - -# We have "duplicate" methods, which are overloaded -written_methods = set() - -body = [] -for method in methods: - if isinstance(method['name'], list): - names = method['name'] - else: - names = [method['name']] - - # Artificially add 'connection' keyword argument - if 'kwargs' not in method: - method['kwargs'] = [] - method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection'}) - - for name in names: - if name in written_methods: - continue - if name in ['arrow', 'df']: - # These methods are ambiguous and are handled in C++ code instead - continue - body.append(create_definition(name, method)) - written_methods.add(name) - -# ---- End of generation code ---- - -with_newlines = body -# Recreate the file content by concatenating all the pieces together - -new_content = start_section + with_newlines + end_section - -print(''.join(with_newlines)) - -# Write out the modified DUCKDB_INIT_FILE file -with open(DUCKDB_INIT_FILE, 'w') as source_file: - source_file.write("".join(new_content)) + +def generate(): + # Read the DUCKDB_INIT_FILE file + with open(DUCKDB_INIT_FILE, 'r') as source_file: + source_code = source_file.readlines() + + start_index = -1 + end_index = -1 + for i, line in enumerate(source_code): + if line.startswith(START_MARKER): + if start_index != -1: + raise ValueError("Encountered the START_MARKER a second time, quitting!") + start_index = i + elif line.startswith(END_MARKER): + if end_index != -1: + raise ValueError("Encountered the END_MARKER a second time, quitting!") + end_index = i + + if start_index == -1 or end_index == -1: + raise ValueError("Couldn't find start or end marker in source file") + + start_section = source_code[: start_index + 1] + end_section = source_code[end_index:] + # ---- Generate the definition code from the json ---- + + methods = [] + + # Read the JSON file + with open(JSON_PATH, 'r') as json_file: + connection_methods = json.load(json_file) + + with open(WRAPPER_JSON_PATH, 'r') as json_file: + wrapper_methods = json.load(json_file) + + methods.extend(connection_methods) + methods.extend(wrapper_methods) + + # On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke + # that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) + READONLY_PROPERTY_NAMES = ['description', 'rowcount'] + + # These methods are not directly DuckDBPyConnection methods, + # they first call 'from_df' and then call a method on the created DuckDBPyRelation + SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] + + def generate_arguments(name, 
method) -> str: + arguments = [] + if name in SPECIAL_METHOD_NAMES: + # We add 'df' to these methods because they operate on a DataFrame + arguments.append('df') + + if 'args' in method: + for arg in method['args']: + res = arg['name'] + if 'default' in arg: + res += f" = {arg['default']}" + arguments.append(res) + arguments.append('**kwargs') + return ', '.join(arguments) + + def generate_parameters(name, method) -> str: + if name in READONLY_PROPERTY_NAMES: + return '' + arguments = [] + if 'args' in method: + for arg in method['args']: + arguments.append(f"{arg['name']}") + arguments.append('**kwargs') + result = ', '.join(arguments) + return '(' + result + ')' + + def generate_function_call(name) -> str: + function_call = '' + if name in SPECIAL_METHOD_NAMES: + function_call += 'from_df(df).' + + REMAPPED_FUNCTIONS = {'alias': 'set_alias', 'query_df': 'query'} + if name in REMAPPED_FUNCTIONS: + function_name = REMAPPED_FUNCTIONS[name] + else: + function_name = name + function_call += function_name + return function_call + + def create_definition(name, method) -> str: + print(method) + arguments = generate_arguments(name, method) + parameters = generate_parameters(name, method) + function_call = generate_function_call(name) + + func = f""" + def {name}({arguments}): + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") + return conn.{function_call}{parameters} + _exported_symbols.append('{name}') + """ + return func + + # We have "duplicate" methods, which are overloaded + written_methods = set() + + body = [] + for method in methods: + if isinstance(method['name'], list): + names = method['name'] + else: + names = [method['name']] + + # Artificially add 'connection' keyword argument + if 'kwargs' not in method: + method['kwargs'] = [] + method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection'}) + + for name in names: + if name in written_methods: + continue + if name in ['arrow', 'df']: + # These methods are ambiguous and are handled in C++ code instead + continue + body.append(create_definition(name, method)) + written_methods.add(name) + + # ---- End of generation code ---- + + with_newlines = body + # Recreate the file content by concatenating all the pieces together + + new_content = start_section + with_newlines + end_section + + # Write out the modified DUCKDB_INIT_FILE file + with open(DUCKDB_INIT_FILE, 'w') as source_file: + source_file.write("".join(new_content)) + + +if __name__ == '__main__': + raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") + # generate() diff --git a/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py b/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py index f20b967318f3..ce6d4a58d869 100644 --- a/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py +++ b/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py @@ -10,107 +10,111 @@ START_MARKER = "# START OF CONNECTION WRAPPER" END_MARKER = "# END OF CONNECTION WRAPPER" -# Read the DUCKDB_STUBS_FILE file -with open(DUCKDB_STUBS_FILE, 'r') as source_file: - source_code = source_file.readlines() - -start_index = -1 -end_index = -1 -for i, line in enumerate(source_code): - if line.startswith(START_MARKER): - # TODO: handle the case where the start marker appears multiple times - start_index = i - elif line.startswith(END_MARKER): - # TODO: ditto ^ - end_index = i - -if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or 
end marker in source file") - -start_section = source_code[: start_index + 1] -end_section = source_code[end_index:] -# ---- Generate the definition code from the json ---- - -methods = [] - -# Read the JSON file -with open(JSON_PATH, 'r') as json_file: - connection_methods = json.load(json_file) - -with open(WRAPPER_JSON_PATH, 'r') as json_file: - wrapper_methods = json.load(json_file) - -methods.extend(connection_methods) -methods.extend(wrapper_methods) - -# On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke -# that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) -READONLY_PROPERTY_NAMES = ['description', 'rowcount'] - -# These methods are not directly DuckDBPyConnection methods, -# they first call 'from_df' and then call a method on the created DuckDBPyRelation -SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] - - -def create_arguments(arguments) -> list: - result = [] - for arg in arguments: - argument = f"{arg['name']}: {arg['type']}" - # Add the default argument if present - if 'default' in arg: - default = arg['default'] - argument += f" = {default}" - result.append(argument) - return result - - -def create_definition(name, method) -> str: - print(method) - definition = f"def {name}(" - arguments = [] - if name in SPECIAL_METHOD_NAMES: - arguments.append('df: pandas.DataFrame') - if 'args' in method: - arguments.extend(create_arguments(method['args'])) - if 'kwargs' in method: - arguments.append("**kwargs") - definition += ', '.join(arguments) - definition += ")" - definition += f" -> {method['return']}: ..." - return definition - - -# We have "duplicate" methods, which are overloaded -# maybe we should add @overload to these instead, but this is easier -written_methods = set() - -body = [] -for method in methods: - if isinstance(method['name'], list): - names = method['name'] - else: - names = [method['name']] - - # Artificially add 'connection' keyword argument - if 'kwargs' not in method: - method['kwargs'] = [] - method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection'}) - - for name in names: - if name in written_methods: - continue - body.append(create_definition(name, method)) - written_methods.add(name) - -# ---- End of generation code ---- - -with_newlines = [x + '\n' for x in body] -# Recreate the file content by concatenating all the pieces together - -new_content = start_section + with_newlines + end_section - -print(''.join(with_newlines)) - -# Write out the modified DUCKDB_STUBS_FILE file -with open(DUCKDB_STUBS_FILE, 'w') as source_file: - source_file.write("".join(new_content)) + +def generate(): + # Read the DUCKDB_STUBS_FILE file + with open(DUCKDB_STUBS_FILE, 'r') as source_file: + source_code = source_file.readlines() + + start_index = -1 + end_index = -1 + for i, line in enumerate(source_code): + if line.startswith(START_MARKER): + if start_index != -1: + raise ValueError("Encountered the START_MARKER a second time, quitting!") + start_index = i + elif line.startswith(END_MARKER): + if end_index != -1: + raise ValueError("Encountered the END_MARKER a second time, quitting!") + end_index = i + + if start_index == -1 or end_index == -1: + raise ValueError("Couldn't find start or end marker in source file") + + start_section = source_code[: start_index + 1] + end_section = source_code[end_index:] + # ---- Generate the definition code from the json ---- + + methods = [] + + # 
Read the JSON file + with open(JSON_PATH, 'r') as json_file: + connection_methods = json.load(json_file) + + with open(WRAPPER_JSON_PATH, 'r') as json_file: + wrapper_methods = json.load(json_file) + + methods.extend(connection_methods) + methods.extend(wrapper_methods) + + # On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke + # that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) + READONLY_PROPERTY_NAMES = ['description', 'rowcount'] + + # These methods are not directly DuckDBPyConnection methods, + # they first call 'from_df' and then call a method on the created DuckDBPyRelation + SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] + + def create_arguments(arguments) -> list: + result = [] + for arg in arguments: + argument = f"{arg['name']}: {arg['type']}" + # Add the default argument if present + if 'default' in arg: + default = arg['default'] + argument += f" = {default}" + result.append(argument) + return result + + def create_definition(name, method) -> str: + print(method) + definition = f"def {name}(" + arguments = [] + if name in SPECIAL_METHOD_NAMES: + arguments.append('df: pandas.DataFrame') + if 'args' in method: + arguments.extend(create_arguments(method['args'])) + if 'kwargs' in method: + arguments.append("**kwargs") + definition += ', '.join(arguments) + definition += ")" + definition += f" -> {method['return']}: ..." + return definition + + # We have "duplicate" methods, which are overloaded + # maybe we should add @overload to these instead, but this is easier + written_methods = set() + + body = [] + for method in methods: + if isinstance(method['name'], list): + names = method['name'] + else: + names = [method['name']] + + # Artificially add 'connection' keyword argument + if 'kwargs' not in method: + method['kwargs'] = [] + method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection'}) + + for name in names: + if name in written_methods: + continue + body.append(create_definition(name, method)) + written_methods.add(name) + + # ---- End of generation code ---- + + with_newlines = [x + '\n' for x in body] + # Recreate the file content by concatenating all the pieces together + + new_content = start_section + with_newlines + end_section + + # Write out the modified DUCKDB_STUBS_FILE file + with open(DUCKDB_STUBS_FILE, 'w') as source_file: + source_file.write("".join(new_content)) + + +if __name__ == '__main__': + raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") + # generate() From 64c8921d7e3cc0c0396983ef09accb35bde72cb4 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 11 Apr 2024 10:05:41 +0200 Subject: [PATCH 091/201] cleanup --- tools/pythonpkg/duckdb/__init__.py | 150 +++++++++--------- .../scripts/generate_connection_stubs.py | 1 - .../generate_connection_wrapper_methods.py | 17 +- .../generate_connection_wrapper_stubs.py | 1 - 4 files changed, 83 insertions(+), 86 deletions(-) diff --git a/tools/pythonpkg/duckdb/__init__.py b/tools/pythonpkg/duckdb/__init__.py index 5d2d56e2dd4e..c6683a712782 100644 --- a/tools/pythonpkg/duckdb/__init__.py +++ b/tools/pythonpkg/duckdb/__init__.py @@ -57,7 +57,7 @@ def cursor(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.cursor(**kwargs) @@ -65,7 +65,7 @@ def cursor(**kwargs): def 
register_filesystem(filesystem, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.register_filesystem(filesystem, **kwargs) @@ -73,7 +73,7 @@ def register_filesystem(filesystem, **kwargs): def unregister_filesystem(name, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.unregister_filesystem(name, **kwargs) @@ -81,7 +81,7 @@ def unregister_filesystem(name, **kwargs): def list_filesystems(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.list_filesystems(**kwargs) @@ -89,7 +89,7 @@ def list_filesystems(**kwargs): def filesystem_is_registered(name, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.filesystem_is_registered(name, **kwargs) @@ -97,7 +97,7 @@ def filesystem_is_registered(name, **kwargs): def create_function(name, function, parameters = None, return_type = None, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.create_function(name, function, parameters, return_type, **kwargs) @@ -105,7 +105,7 @@ def create_function(name, function, parameters = None, return_type = None, **kwa def remove_function(name, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.remove_function(name, **kwargs) @@ -113,7 +113,7 @@ def remove_function(name, **kwargs): def sqltype(type_str, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.sqltype(type_str, **kwargs) @@ -121,7 +121,7 @@ def sqltype(type_str, **kwargs): def dtype(type_str, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.dtype(type_str, **kwargs) @@ -129,7 +129,7 @@ def dtype(type_str, **kwargs): def type(type_str, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.type(type_str, **kwargs) @@ -137,7 +137,7 @@ def type(type_str, **kwargs): def array_type(type, size, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.array_type(type, size, **kwargs) @@ -145,7 +145,7 @@ def array_type(type, size, **kwargs): def list_type(type, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.list_type(type, **kwargs) @@ -153,7 +153,7 @@ def list_type(type, **kwargs): def union_type(members, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.union_type(members, **kwargs) @@ -161,7 +161,7 @@ def union_type(members, **kwargs): def string_type(collation = "", **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = 
duckdb.connect(":default:") return conn.string_type(collation, **kwargs) @@ -169,7 +169,7 @@ def string_type(collation = "", **kwargs): def enum_type(name, type, values, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.enum_type(name, type, values, **kwargs) @@ -177,7 +177,7 @@ def enum_type(name, type, values, **kwargs): def decimal_type(width, scale, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.decimal_type(width, scale, **kwargs) @@ -185,7 +185,7 @@ def decimal_type(width, scale, **kwargs): def struct_type(fields, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.struct_type(fields, **kwargs) @@ -193,7 +193,7 @@ def struct_type(fields, **kwargs): def row_type(fields, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.row_type(fields, **kwargs) @@ -201,7 +201,7 @@ def row_type(fields, **kwargs): def map_type(key, value, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.map_type(key, value, **kwargs) @@ -209,7 +209,7 @@ def map_type(key, value, **kwargs): def duplicate(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.duplicate(**kwargs) @@ -217,7 +217,7 @@ def duplicate(**kwargs): def execute(query, parameters = None, multiple_parameter_sets = False, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.execute(query, parameters, multiple_parameter_sets, **kwargs) @@ -225,7 +225,7 @@ def execute(query, parameters = None, multiple_parameter_sets = False, **kwargs) def executemany(query, parameters = None, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.executemany(query, parameters, **kwargs) @@ -233,7 +233,7 @@ def executemany(query, parameters = None, **kwargs): def close(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.close(**kwargs) @@ -241,7 +241,7 @@ def close(**kwargs): def interrupt(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.interrupt(**kwargs) @@ -249,7 +249,7 @@ def interrupt(**kwargs): def fetchone(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.fetchone(**kwargs) @@ -257,7 +257,7 @@ def fetchone(**kwargs): def fetchmany(size = 1, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.fetchmany(size, **kwargs) @@ -265,7 +265,7 @@ def fetchmany(size = 1, **kwargs): def fetchall(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: 
conn = duckdb.connect(":default:") return conn.fetchall(**kwargs) @@ -273,7 +273,7 @@ def fetchall(**kwargs): def fetchnumpy(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.fetchnumpy(**kwargs) @@ -281,7 +281,7 @@ def fetchnumpy(**kwargs): def fetchdf(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.fetchdf(**kwargs) @@ -289,7 +289,7 @@ def fetchdf(**kwargs): def fetch_df(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.fetch_df(**kwargs) @@ -297,7 +297,7 @@ def fetch_df(**kwargs): def fetch_df_chunk(vectors_per_chunk = 1, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.fetch_df_chunk(vectors_per_chunk, **kwargs) @@ -305,7 +305,7 @@ def fetch_df_chunk(vectors_per_chunk = 1, **kwargs): def pl(rows_per_batch = 1000000, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.pl(rows_per_batch, **kwargs) @@ -313,7 +313,7 @@ def pl(rows_per_batch = 1000000, **kwargs): def fetch_arrow_table(rows_per_batch = 1000000, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.fetch_arrow_table(rows_per_batch, **kwargs) @@ -321,7 +321,7 @@ def fetch_arrow_table(rows_per_batch = 1000000, **kwargs): def fetch_record_batch(rows_per_batch = 1000000, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.fetch_record_batch(rows_per_batch, **kwargs) @@ -329,7 +329,7 @@ def fetch_record_batch(rows_per_batch = 1000000, **kwargs): def torch(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.torch(**kwargs) @@ -337,7 +337,7 @@ def torch(**kwargs): def tf(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.tf(**kwargs) @@ -345,7 +345,7 @@ def tf(**kwargs): def begin(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.begin(**kwargs) @@ -353,7 +353,7 @@ def begin(**kwargs): def commit(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.commit(**kwargs) @@ -361,7 +361,7 @@ def commit(**kwargs): def rollback(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.rollback(**kwargs) @@ -369,7 +369,7 @@ def rollback(**kwargs): def append(table_name, df, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.append(table_name, df, **kwargs) @@ -377,7 +377,7 @@ def append(table_name, df, **kwargs): def register(view_name, python_object, **kwargs): if 
'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.register(view_name, python_object, **kwargs) @@ -385,7 +385,7 @@ def register(view_name, python_object, **kwargs): def unregister(view_name, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.unregister(view_name, **kwargs) @@ -393,7 +393,7 @@ def unregister(view_name, **kwargs): def table(table_name, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.table(table_name, **kwargs) @@ -401,7 +401,7 @@ def table(table_name, **kwargs): def view(view_name, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.view(view_name, **kwargs) @@ -409,7 +409,7 @@ def view(view_name, **kwargs): def values(values, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.values(values, **kwargs) @@ -417,7 +417,7 @@ def values(values, **kwargs): def table_function(name, parameters = None, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.table_function(name, parameters, **kwargs) @@ -425,7 +425,7 @@ def table_function(name, parameters = None, **kwargs): def read_json(name, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.read_json(name, **kwargs) @@ -433,7 +433,7 @@ def read_json(name, **kwargs): def extract_statements(query, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.extract_statements(query, **kwargs) @@ -441,7 +441,7 @@ def extract_statements(query, **kwargs): def sql(query, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.sql(query, **kwargs) @@ -449,7 +449,7 @@ def sql(query, **kwargs): def query(query, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.query(query, **kwargs) @@ -457,7 +457,7 @@ def query(query, **kwargs): def from_query(query, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_query(query, **kwargs) @@ -465,7 +465,7 @@ def from_query(query, **kwargs): def read_csv(path_or_buffer, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.read_csv(path_or_buffer, **kwargs) @@ -473,7 +473,7 @@ def read_csv(path_or_buffer, **kwargs): def from_csv_auto(path_or_buffer, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_csv_auto(path_or_buffer, **kwargs) @@ -481,7 +481,7 @@ def from_csv_auto(path_or_buffer, **kwargs): def from_df(df, **kwargs): if 'connection' in kwargs: - 
conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_df(df, **kwargs) @@ -489,7 +489,7 @@ def from_df(df, **kwargs): def from_arrow(arrow_object, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_arrow(arrow_object, **kwargs) @@ -497,7 +497,7 @@ def from_arrow(arrow_object, **kwargs): def from_parquet(file_glob, binary_as_string = False, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_parquet(file_glob, binary_as_string, **kwargs) @@ -505,7 +505,7 @@ def from_parquet(file_glob, binary_as_string = False, **kwargs): def read_parquet(file_glob, binary_as_string = False, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.read_parquet(file_glob, binary_as_string, **kwargs) @@ -513,7 +513,7 @@ def read_parquet(file_glob, binary_as_string = False, **kwargs): def from_substrait(proto, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_substrait(proto, **kwargs) @@ -521,7 +521,7 @@ def from_substrait(proto, **kwargs): def get_substrait(query, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.get_substrait(query, **kwargs) @@ -529,7 +529,7 @@ def get_substrait(query, **kwargs): def get_substrait_json(query, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.get_substrait_json(query, **kwargs) @@ -537,7 +537,7 @@ def get_substrait_json(query, **kwargs): def from_substrait_json(json, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_substrait_json(json, **kwargs) @@ -545,7 +545,7 @@ def from_substrait_json(json, **kwargs): def get_table_names(query, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.get_table_names(query, **kwargs) @@ -553,7 +553,7 @@ def get_table_names(query, **kwargs): def install_extension(extension, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.install_extension(extension, **kwargs) @@ -561,7 +561,7 @@ def install_extension(extension, **kwargs): def load_extension(extension, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.load_extension(extension, **kwargs) @@ -569,7 +569,7 @@ def load_extension(extension, **kwargs): def project(df, project_expr, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_df(df).project(project_expr, **kwargs) @@ -577,7 +577,7 @@ def project(df, project_expr, **kwargs): def distinct(df, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = 
kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_df(df).distinct(**kwargs) @@ -585,7 +585,7 @@ def distinct(df, **kwargs): def write_csv(df, *args, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_df(df).write_csv(*args, **kwargs) @@ -593,7 +593,7 @@ def write_csv(df, *args, **kwargs): def aggregate(df, aggr_expr, group_expr = "", **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_df(df).aggregate(aggr_expr, group_expr, **kwargs) @@ -601,7 +601,7 @@ def aggregate(df, aggr_expr, group_expr = "", **kwargs): def alias(df, alias, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_df(df).set_alias(alias, **kwargs) @@ -609,7 +609,7 @@ def alias(df, alias, **kwargs): def filter(df, filter_expr, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_df(df).filter(filter_expr, **kwargs) @@ -617,7 +617,7 @@ def filter(df, filter_expr, **kwargs): def limit(df, n, offset = 0, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_df(df).limit(n, offset, **kwargs) @@ -625,7 +625,7 @@ def limit(df, n, offset = 0, **kwargs): def order(df, order_expr, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_df(df).order(order_expr, **kwargs) @@ -633,7 +633,7 @@ def order(df, order_expr, **kwargs): def query_df(df, virtual_table_name, sql_query, **kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.from_df(df).query(virtual_table_name, sql_query, **kwargs) @@ -641,7 +641,7 @@ def query_df(df, virtual_table_name, sql_query, **kwargs): def description(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.description @@ -649,7 +649,7 @@ def description(**kwargs): def rowcount(**kwargs): if 'connection' in kwargs: - conn = kwargs.pop('connection') + conn = kwargs.pop('connection') else: conn = duckdb.connect(":default:") return conn.rowcount diff --git a/tools/pythonpkg/scripts/generate_connection_stubs.py b/tools/pythonpkg/scripts/generate_connection_stubs.py index fc8c5e9d456a..7343ec3ec889 100644 --- a/tools/pythonpkg/scripts/generate_connection_stubs.py +++ b/tools/pythonpkg/scripts/generate_connection_stubs.py @@ -52,7 +52,6 @@ def create_arguments(arguments) -> list: return result def create_definition(name, method) -> str: - print(method) definition = f"def {name}(self" if 'args' in method: definition += ", " diff --git a/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py b/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py index 9d1d1a0fefdc..ac6345212c80 100644 --- a/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py +++ b/tools/pythonpkg/scripts/generate_connection_wrapper_methods.py @@ -95,20 +95,19 @@ def generate_function_call(name) -> str: return function_call def 
create_definition(name, method) -> str: - print(method) arguments = generate_arguments(name, method) parameters = generate_parameters(name, method) function_call = generate_function_call(name) func = f""" - def {name}({arguments}): - if 'connection' in kwargs: - conn = kwargs.pop('connection') - else: - conn = duckdb.connect(":default:") - return conn.{function_call}{parameters} - _exported_symbols.append('{name}') - """ +def {name}({arguments}): + if 'connection' in kwargs: + conn = kwargs.pop('connection') + else: + conn = duckdb.connect(":default:") + return conn.{function_call}{parameters} +_exported_symbols.append('{name}') +""" return func # We have "duplicate" methods, which are overloaded diff --git a/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py b/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py index ce6d4a58d869..7665a91906fd 100644 --- a/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py +++ b/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py @@ -67,7 +67,6 @@ def create_arguments(arguments) -> list: return result def create_definition(name, method) -> str: - print(method) definition = f"def {name}(" arguments = [] if name in SPECIAL_METHOD_NAMES: From cbf1b651d1cb8f966a4e608870afa4cf2ba83bf4 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 11 Apr 2024 10:08:58 +0200 Subject: [PATCH 092/201] add connection code generation script to 'generate-files' --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index a6fd0dcd3ea0..903363ad5039 100644 --- a/Makefile +++ b/Makefile @@ -443,6 +443,7 @@ generate-files: python3 scripts/generate_functions.py python3 scripts/generate_serialization.py python3 scripts/generate_enum_util.py + python3 tools/pythonpkg/scripts/generate_connection_code.py ./scripts/generate_micro_extended.sh bundle-library: release From cdec49e5fd1cc3f9889c29f2e05e537f5082da56 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 11 Apr 2024 14:16:10 +0200 Subject: [PATCH 093/201] reformat after generation --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index 903363ad5039..eb70c8bd5fda 100644 --- a/Makefile +++ b/Makefile @@ -445,6 +445,8 @@ generate-files: python3 scripts/generate_enum_util.py python3 tools/pythonpkg/scripts/generate_connection_code.py ./scripts/generate_micro_extended.sh +# Run the formatter again after (re)generating the files + $(MAKE) format-main bundle-library: release cd build/release && \ From f66fed072a8b964331d26d37b8ddc552936aadff Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 11 Apr 2024 14:24:22 +0200 Subject: [PATCH 094/201] the shared_ptr internals live in .ipp files so we can separate enable_shared_from_this, shared_ptr and weak_ptr --- scripts/package_build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/package_build.py b/scripts/package_build.py index ac737e88c141..3499f02c3747 100644 --- a/scripts/package_build.py +++ b/scripts/package_build.py @@ -343,7 +343,7 @@ def generate_unity_builds(source_list, nsplits, linenumbers): unity_build = True # re-order the files in the unity build so that they follow the same order as the CMake scores = {} - filenames = [x[0] for x in re.findall('([a-zA-Z0-9_]+[.](cpp|cc|c|cxx))', text)] + filenames = [x[0] for x in re.findall('([a-zA-Z0-9_]+[.](cpp|ipp|cc|c|cxx))', text)] score = 0 for filename in filenames: scores[filename] = score From d6bcc62d239e892e3fba990a4ccb3e29d83cba03 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 11 Apr 2024 14:25:58 +0200 Subject: 
[PATCH 095/201] changed in the wrong place --- scripts/amalgamation.py | 2 +- scripts/package_build.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/amalgamation.py b/scripts/amalgamation.py index e26c71049f31..325cc19f1521 100644 --- a/scripts/amalgamation.py +++ b/scripts/amalgamation.py @@ -376,7 +376,7 @@ def list_include_files_recursive(dname, file_list): fpath = os.path.join(dname, fname) if os.path.isdir(fpath): list_include_files_recursive(fpath, file_list) - elif fname.endswith(('.hpp', '.h', '.hh', '.tcc', '.inc')): + elif fname.endswith(('.hpp', '.ipp', '.h', '.hh', '.tcc', '.inc')): file_list.append(fpath) diff --git a/scripts/package_build.py b/scripts/package_build.py index 3499f02c3747..ac737e88c141 100644 --- a/scripts/package_build.py +++ b/scripts/package_build.py @@ -343,7 +343,7 @@ def generate_unity_builds(source_list, nsplits, linenumbers): unity_build = True # re-order the files in the unity build so that they follow the same order as the CMake scores = {} - filenames = [x[0] for x in re.findall('([a-zA-Z0-9_]+[.](cpp|ipp|cc|c|cxx))', text)] + filenames = [x[0] for x in re.findall('([a-zA-Z0-9_]+[.](cpp|cc|c|cxx))', text)] score = 0 for filename in filenames: scores[filename] = score From e733fb7bf6218432356108349e1fc8c6cef7c725 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 11 Apr 2024 17:33:11 +0200 Subject: [PATCH 096/201] duckdb::shared_ptr in duckdb_java --- tools/jdbc/src/jni/duckdb_java.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/jdbc/src/jni/duckdb_java.cpp b/tools/jdbc/src/jni/duckdb_java.cpp index e019526c65be..f7fed87b26d9 100644 --- a/tools/jdbc/src/jni/duckdb_java.cpp +++ b/tools/jdbc/src/jni/duckdb_java.cpp @@ -1,6 +1,7 @@ #include "functions.hpp" #include "duckdb.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/common/shared_ptr.hpp" #include "duckdb/main/client_data.hpp" #include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/main/appender.hpp" @@ -348,10 +349,10 @@ static Value create_value_from_bigdecimal(JNIEnv *env, jobject decimal) { * DuckDB is released as well. 
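* Holding the shared_ptr<DatabaseInstance> next to the Connection is what
* provides that guarantee: the database can only be torn down once the last
* ConnectionHolder has been destroyed. (This patch spells the member as the
* fully qualified duckdb::shared_ptr, presumably to keep it unambiguous now
* that DuckDB ships its own pointer wrapper alongside std::shared_ptr.)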
*/ struct ConnectionHolder { - const shared_ptr db; + const duckdb::shared_ptr db; const duckdb::unique_ptr connection; - ConnectionHolder(shared_ptr _db) : db(_db), connection(make_uniq(*_db)) { + ConnectionHolder(duckdb::shared_ptr _db) : db(_db), connection(make_uniq(*_db)) { } }; From 31ca8e46e056588ace09654ae574d2f39f012f8b Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 11 Apr 2024 21:46:29 +0200 Subject: [PATCH 097/201] clang tidy fixes --- .clang-tidy | 2 +- .../duckdb/common/enable_shared_from_this.ipp | 14 +++--- src/include/duckdb/common/helper.hpp | 2 +- src/include/duckdb/common/shared_ptr.hpp | 14 +++--- src/include/duckdb/common/shared_ptr.ipp | 46 +++++++++---------- src/include/duckdb/common/weak_ptr.ipp | 14 +++--- 6 files changed, 46 insertions(+), 46 deletions(-) diff --git a/.clang-tidy b/.clang-tidy index ee2b06b745d7..c1b42c62ec42 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -46,7 +46,7 @@ CheckOptions: - key: readability-identifier-naming.VariableCase value: lower_case - key: modernize-use-emplace.SmartPointers - value: '::std::shared_ptr;::duckdb::unique_ptr;::std::auto_ptr;::std::weak_ptr' + value: '::duckdb::shared_ptr;::duckdb::unique_ptr;::std::auto_ptr;::duckdb::weak_ptr' - key: cppcoreguidelines-rvalue-reference-param-not-moved.IgnoreUnnamedParams value: true diff --git a/src/include/duckdb/common/enable_shared_from_this.ipp b/src/include/duckdb/common/enable_shared_from_this.ipp index d68d20033aa7..85cdd2205411 100644 --- a/src/include/duckdb/common/enable_shared_from_this.ipp +++ b/src/include/duckdb/common/enable_shared_from_this.ipp @@ -1,18 +1,18 @@ namespace duckdb { template -class enable_shared_from_this { +class enable_shared_from_this { // NOLINT: invalid case style public: template friend class shared_ptr; private: - mutable weak_ptr __weak_this_; + mutable weak_ptr __weak_this_; // NOLINT: __weak_this_ is reserved protected: constexpr enable_shared_from_this() noexcept { } - enable_shared_from_this(enable_shared_from_this const &) noexcept { + enable_shared_from_this(enable_shared_from_this const &) noexcept { // NOLINT: not marked as explicit } enable_shared_from_this &operator=(enable_shared_from_this const &) noexcept { return *this; @@ -21,19 +21,19 @@ protected: } public: - shared_ptr shared_from_this() { + shared_ptr shared_from_this() { // NOLINT: invalid case style return shared_ptr(__weak_this_); } - shared_ptr shared_from_this() const { + shared_ptr shared_from_this() const { // NOLINT: invalid case style return shared_ptr(__weak_this_); } #if _LIBCPP_STD_VER >= 17 - weak_ptr weak_from_this() noexcept { + weak_ptr weak_from_this() noexcept { // NOLINT: invalid case style return __weak_this_; } - weak_ptr weak_from_this() const noexcept { + weak_ptr weak_from_this() const noexcept { // NOLINT: invalid case style return __weak_this_; } #endif // _LIBCPP_STD_VER >= 17 diff --git a/src/include/duckdb/common/helper.hpp b/src/include/duckdb/common/helper.hpp index 2f12dca961f0..ed3e6e1f1be3 100644 --- a/src/include/duckdb/common/helper.hpp +++ b/src/include/duckdb/common/helper.hpp @@ -118,7 +118,7 @@ unique_ptr unique_ptr_cast(unique_ptr src) { // NOLINT: mimic std style } template -shared_ptr shared_ptr_cast(shared_ptr src) { +shared_ptr shared_ptr_cast(shared_ptr src) { // NOLINT: mimic std style return shared_ptr(std::static_pointer_cast(src.internal)); } diff --git a/src/include/duckdb/common/shared_ptr.hpp b/src/include/duckdb/common/shared_ptr.hpp index 1e6e1b2523b0..f5ca7d762c3a 100644 --- a/src/include/duckdb/common/shared_ptr.hpp 
+++ b/src/include/duckdb/common/shared_ptr.hpp @@ -22,17 +22,17 @@ namespace duckdb { // originally named '__compatible_with' #if _LIBCPP_STD_VER >= 17 -template +template struct __bounded_convertible_to_unbounded : false_type {}; -template -struct __bounded_convertible_to_unbounded<_Up[_Np], _Tp> : is_same, _Up[]> {}; +template +struct __bounded_convertible_to_unbounded<_Up[_Np], T> : is_same, _Up[]> {}; -template -struct compatible_with_t : _Or, __bounded_convertible_to_unbounded<_Yp, _Tp>> {}; +template +struct compatible_with_t : _Or, __bounded_convertible_to_unbounded> {}; #else -template -struct compatible_with_t : std::is_convertible<_Yp *, _Tp *> {}; +template +struct compatible_with_t : std::is_convertible {}; // NOLINT: invalid case style #endif // _LIBCPP_STD_VER >= 17 } // namespace duckdb diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index c36f28622f02..840f101588d4 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -7,7 +7,7 @@ template class enable_shared_from_this; template -class shared_ptr { +class shared_ptr { // NOLINT: invalid case style public: using original = std::shared_ptr; using element_type = typename original::element_type; @@ -49,33 +49,33 @@ public: explicit shared_ptr(U *ptr) : internal(ptr) { __enable_weak_this(internal.get(), internal.get()); } - // From raw pointer of type T with custom Deleter - template - shared_ptr(T *ptr, Deleter deleter) : internal(ptr, deleter) { + // From raw pointer of type T with custom DELETER + template + shared_ptr(T *ptr, DELETER deleter) : internal(ptr, deleter) { __enable_weak_this(internal.get(), internal.get()); } - // Aliasing constructor: shares ownership information with __r but contains __p instead - // When the created shared_ptr goes out of scope, it will call the Deleter of __r, will not delete __p + // Aliasing constructor: shares ownership information with ref but contains ptr instead + // When the created shared_ptr goes out of scope, it will call the DELETER of ref, will not delete ptr template - shared_ptr(const shared_ptr &__r, T *__p) noexcept : internal(__r.internal, __p) { + shared_ptr(const shared_ptr &ref, T *ptr) noexcept : internal(ref.internal, ptr) { } #if _LIBCPP_STD_VER >= 20 template - shared_ptr(shared_ptr &&__r, T *__p) noexcept : internal(std::move(__r.internal), __p) { + shared_ptr(shared_ptr &&ref, T *ptr) noexcept : internal(std::move(ref.internal), ptr) { } #endif - // Copy constructor, share ownership with __r + // Copy constructor, share ownership with ref template ::value, int>::type = 0> - shared_ptr(const shared_ptr &__r) noexcept : internal(__r.internal) { + shared_ptr(const shared_ptr &ref) noexcept : internal(ref.internal) { // NOLINT: not marked as explicit } - shared_ptr(const shared_ptr &other) : internal(other.internal) { + shared_ptr(const shared_ptr &other) : internal(other.internal) { // NOLINT: not marked as explicit } - // Move constructor, share ownership with __r + // Move constructor, share ownership with ref template ::value, int>::type = 0> - shared_ptr(shared_ptr &&__r) noexcept : internal(std::move(__r.internal)) { + shared_ptr(shared_ptr &&ref) noexcept : internal(std::move(ref.internal)) { // NOLINT: not marked as explicit } - shared_ptr(shared_ptr &&other) : internal(std::move(other.internal)) { + shared_ptr(shared_ptr &&other) : internal(std::move(other.internal)) { // NOLINT: not marked as explicit } // Construct from std::shared_ptr @@ -131,8 +131,8 @@ 
public: typename std::enable_if::value && std::is_convertible::pointer, T *>::value, int>::type = 0> - shared_ptr &operator=(unique_ptr &&__r) { - shared_ptr(std::move(__r)).swap(*this); + shared_ptr &operator=(unique_ptr &&ref) { + shared_ptr(std::move(ref)).swap(*this); return *this; } @@ -141,7 +141,7 @@ public: [[clang::reinitializes]] #endif void - reset() { + reset() { // NOLINT: invalid case style internal.reset(); } #ifdef DUCKDB_CLANG_TIDY @@ -149,15 +149,15 @@ public: [[clang::reinitializes]] #endif template - void reset(U *ptr) { + void reset(U *ptr) { // NOLINT: invalid case style internal.reset(ptr); } #ifdef DUCKDB_CLANG_TIDY // This is necessary to tell clang-tidy that it reinitializes the variable after a move [[clang::reinitializes]] #endif - template - void reset(U *ptr, Deleter deleter) { + template + void reset(U *ptr, DELETER deleter) { // NOLINT: invalid case style internal.reset(ptr, deleter); } @@ -236,12 +236,12 @@ private: template *>::value, int>::type = 0> - void __enable_weak_this(const enable_shared_from_this *__e, _OrigPtr *__ptr) noexcept { + void __enable_weak_this(const enable_shared_from_this *object, _OrigPtr *ptr) noexcept { typedef typename std::remove_cv::type NonConstU; - if (__e && __e->__weak_this_.expired()) { + if (object && object->__weak_this_.expired()) { // __weak_this__ is the mutable variable returned by 'shared_from_this' // it is initialized here - __e->__weak_this_ = shared_ptr(*this, const_cast(static_cast(__ptr))); + object->__weak_this_ = shared_ptr(*this, const_cast(static_cast(ptr))); } } diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index aae7c95dabd6..fff31e251e04 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -1,7 +1,7 @@ namespace duckdb { template -class weak_ptr { +class weak_ptr { // NOLINT: invalid case style public: using original = std::weak_ptr; using element_type = typename original::element_type; @@ -23,13 +23,13 @@ public: typename std::enable_if::value, int>::type = 0) noexcept : internal(ptr.internal) { } - weak_ptr(weak_ptr const &other) noexcept : internal(other.internal) { + weak_ptr(weak_ptr const &other) noexcept : internal(other.internal) { // NOLINT: not marked as explicit } template weak_ptr(weak_ptr const &ptr, typename std::enable_if::value, int>::type = 0) noexcept : internal(ptr.internal) { } - weak_ptr(weak_ptr &&ptr) noexcept : internal(std::move(ptr.internal)) { + weak_ptr(weak_ptr &&ptr) noexcept : internal(std::move(ptr.internal)) { // NOLINT: not marked as explicit } template weak_ptr(weak_ptr &&ptr, typename std::enable_if::value, int>::type = 0) noexcept @@ -51,20 +51,20 @@ public: } // Modifiers - void reset() { + void reset() { // NOLINT: invalid case style internal.reset(); } // Observers - long use_count() const { + long use_count() const { // NOLINT: invalid case style return internal.use_count(); } - bool expired() const { + bool expired() const { // NOLINT: invalid case style return internal.expired(); } - shared_ptr lock() const { + shared_ptr lock() const { // NOLINT: invalid case style return shared_ptr(internal.lock()); } From a9ef0e20190383425e591cf6feaa715999fdba20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Fri, 12 Apr 2024 07:55:58 +0200 Subject: [PATCH 098/201] forgotten merge conflict --- src/include/duckdb/storage/string_uncompressed.hpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git 
a/src/include/duckdb/storage/string_uncompressed.hpp b/src/include/duckdb/storage/string_uncompressed.hpp index fca222511961..a57b851eb86d 100644 --- a/src/include/duckdb/storage/string_uncompressed.hpp +++ b/src/include/duckdb/storage/string_uncompressed.hpp @@ -162,12 +162,8 @@ struct UncompressedStringStorage { memcpy(dict_pos, source_data[source_idx].GetData(), string_length); // place the dictionary offset into the set of vectors -<<<<<<< HEAD - result_data[target_idx] = NumericCast(*dictionary_size); -======= D_ASSERT(*dictionary_size <= int32_t(Storage::BLOCK_SIZE)); - result_data[target_idx] = *dictionary_size; ->>>>>>> 40e2ff4837e79f2cc75e0d805595158a5409e680 + result_data[target_idx] = NumericCast(*dictionary_size); } D_ASSERT(RemainingSpace(segment, handle) <= Storage::BLOCK_SIZE); #ifdef DEBUG From 976f43ddb34272541e3384243bb04d21886acccf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Fri, 12 Apr 2024 09:41:21 +0200 Subject: [PATCH 099/201] csv reader use optional idx --- .../csv_scanner/scanner/string_value_scanner.cpp | 6 +++--- .../table_function/global_csv_state.cpp | 4 ++-- .../operator/csv_scanner/util/csv_error.cpp | 14 +++++++------- .../core_functions/aggregate/sum_helpers.hpp | 2 +- .../execution/operator/csv_scanner/csv_error.hpp | 12 ++++++------ src/main/settings/settings.cpp | 1 - 6 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp index ad42b6dafaa6..da0f70fea17d 100644 --- a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp +++ b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp @@ -777,8 +777,8 @@ void StringValueScanner::Flush(DataChunk &insert_chunk) { auto csv_error = CSVError::CastError( state_machine->options, csv_file_scan->names[col_idx], error_message, col_idx, borked_line, lines_per_batch, - result.line_positions_per_row[line_error].begin.GetGlobalPosition(result.result_size, first_nl), -1, - result_vector.GetType().id()); + result.line_positions_per_row[line_error].begin.GetGlobalPosition(result.result_size, first_nl), + optional_idx::Invalid(), result_vector.GetType().id()); error_handler->Error(csv_error); } @@ -802,7 +802,7 @@ void StringValueScanner::Flush(DataChunk &insert_chunk) { state_machine->options, csv_file_scan->names[col_idx], error_message, col_idx, borked_line, lines_per_batch, result.line_positions_per_row[line_error].begin.GetGlobalPosition(result.result_size, first_nl), - -1, result_vector.GetType().id()); + optional_idx::Invalid(), result_vector.GetType().id()); error_handler->Error(csv_error); } } diff --git a/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp b/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp index a56a9e687804..9cfc08d42084 100644 --- a/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp +++ b/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp @@ -289,12 +289,12 @@ void CSVGlobalState::FillRejectsTable() { // 4. Byte Position of the row error errors_appender.Append(error.row_byte_position + 1); // 5. 
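
This patch replaces the -1 sentinel for byte positions with optional_idx, so "no position" becomes an explicit state instead of a magic value that callers must remember to check. A simplified sketch of what such a wrapper provides (DuckDB's actual optional_idx is richer; only Invalid(), IsValid() and GetIndex() are shown here because those are the calls used in the hunks above):

    #include <cstdint>
    #include <stdexcept>

    using idx_t = uint64_t;

    // Simplified sketch: a tagged idx_t instead of a -1 sentinel.
    class optional_idx {
        static constexpr idx_t INVALID = static_cast<idx_t>(-1);
        idx_t index = INVALID;

    public:
        optional_idx() = default;
        optional_idx(idx_t idx) : index(idx) { // NOLINT: allow implicit conversion
        }

        static optional_idx Invalid() {
            return optional_idx();
        }
        bool IsValid() const {
            return index != INVALID;
        }
        idx_t GetIndex() const {
            if (!IsValid()) {
                throw std::logic_error("attempted to get the index of an invalid optional_idx");
            }
            return index;
        }
    };
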
Byte Position where error occurred - if (error.byte_position == -1) { + if (!error.byte_position.IsValid()) { // This means this error comes from a flush, and we don't support this yet, so we give it // a null errors_appender.Append(Value()); } else { - errors_appender.Append(error.byte_position + 1); + errors_appender.Append(error.byte_position.GetIndex() + 1); } // 6. Column Index if (error.type == CSVErrorType::MAXIMUM_LINE_SIZE) { diff --git a/src/execution/operator/csv_scanner/util/csv_error.cpp b/src/execution/operator/csv_scanner/util/csv_error.cpp index d22738f21415..208c23371183 100644 --- a/src/execution/operator/csv_scanner/util/csv_error.cpp +++ b/src/execution/operator/csv_scanner/util/csv_error.cpp @@ -87,7 +87,7 @@ CSVError::CSVError(string error_message_p, CSVErrorType type_p, LinesPerBoundary } CSVError::CSVError(string error_message_p, CSVErrorType type_p, idx_t column_idx_p, string csv_row_p, - LinesPerBoundary error_info_p, idx_t row_byte_position, int64_t byte_position_p, + LinesPerBoundary error_info_p, idx_t row_byte_position, optional_idx byte_position_p, const CSVReaderOptions &reader_options, const string &fixes) : error_message(std::move(error_message_p)), type(type_p), column_idx(column_idx_p), csv_row(std::move(csv_row_p)), error_info(error_info_p), row_byte_position(row_byte_position), byte_position(byte_position_p) { @@ -129,7 +129,7 @@ void CSVError::RemoveNewLine(string &error) { CSVError CSVError::CastError(const CSVReaderOptions &options, string &column_name, string &cast_error, idx_t column_idx, string &csv_row, LinesPerBoundary error_info, idx_t row_byte_position, - int64_t byte_position, LogicalTypeId type) { + optional_idx byte_position, LogicalTypeId type) { std::ostringstream error; // Which column error << "Error when converting column \"" << column_name << "\". "; @@ -192,7 +192,7 @@ CSVError CSVError::NullPaddingFail(const CSVReaderOptions &options, LinesPerBoun CSVError CSVError::UnterminatedQuotesError(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - int64_t byte_position) { + optional_idx byte_position) { std::ostringstream error; error << "Value with unterminated quote found." 
<< '\n'; std::ostringstream how_to_fix_it; @@ -203,7 +203,7 @@ CSVError CSVError::UnterminatedQuotesError(const CSVReaderOptions &options, idx_ CSVError CSVError::IncorrectColumnAmountError(const CSVReaderOptions &options, idx_t actual_columns, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - int64_t byte_position) { + optional_idx byte_position) { std::ostringstream error; // We don't have a fix for this std::ostringstream how_to_fix_it; @@ -218,15 +218,15 @@ CSVError CSVError::IncorrectColumnAmountError(const CSVReaderOptions &options, i error << "Expected Number of Columns: " << options.dialect_options.num_cols << " Found: " << actual_columns + 1; if (actual_columns >= options.dialect_options.num_cols) { return CSVError(error.str(), CSVErrorType::TOO_MANY_COLUMNS, actual_columns, csv_row, error_info, - row_byte_position, byte_position - 1, options, how_to_fix_it.str()); + row_byte_position, byte_position.GetIndex() - 1, options, how_to_fix_it.str()); } else { return CSVError(error.str(), CSVErrorType::TOO_FEW_COLUMNS, actual_columns, csv_row, error_info, - row_byte_position, byte_position - 1, options, how_to_fix_it.str()); + row_byte_position, byte_position.GetIndex() - 1, options, how_to_fix_it.str()); } } CSVError CSVError::InvalidUTF8(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, - string &csv_row, idx_t row_byte_position, int64_t byte_position) { + string &csv_row, idx_t row_byte_position, optional_idx byte_position) { std::ostringstream error; // How many columns were expected and how many were found error << "Invalid unicode (byte sequence mismatch) detected." << '\n'; diff --git a/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp b/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp index 554ef2390faa..562f61ade356 100644 --- a/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp +++ b/src/include/duckdb/core_functions/aggregate/sum_helpers.hpp @@ -73,7 +73,7 @@ struct HugeintAdd { template static void AddConstant(STATE &state, T input, idx_t count) { - AddNumber(state, Hugeint::Multiply(input, count)); + AddNumber(state, Hugeint::Multiply(input, UnsafeNumericCast(count))); } }; diff --git a/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp b/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp index 340c42cd1c40..5f3ef031178a 100644 --- a/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp +++ b/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp @@ -52,7 +52,7 @@ class CSVError { public: CSVError() {}; CSVError(string error_message, CSVErrorType type, idx_t column_idx, string csv_row, LinesPerBoundary error_info, - idx_t row_byte_position, int64_t byte_position, const CSVReaderOptions &reader_options, + idx_t row_byte_position, optional_idx byte_position, const CSVReaderOptions &reader_options, const string &fixes); CSVError(string error_message, CSVErrorType type, LinesPerBoundary error_info); //! Produces error messages for column name -> type mismatch. @@ -60,7 +60,7 @@ class CSVError { //! Produces error messages for casting errors static CSVError CastError(const CSVReaderOptions &options, string &column_name, string &cast_error, idx_t column_idx, string &csv_row, LinesPerBoundary error_info, idx_t row_byte_position, - int64_t byte_position, LogicalTypeId type); + optional_idx byte_position, LogicalTypeId type); //! 
Produces error for when the line size exceeds the maximum line size option static CSVError LineSizeError(const CSVReaderOptions &options, idx_t actual_size, LinesPerBoundary error_info, string &csv_row, idx_t byte_position); @@ -69,15 +69,15 @@ class CSVError { //! Produces error messages for unterminated quoted values static CSVError UnterminatedQuotesError(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - int64_t byte_position); + optional_idx byte_position); //! Produces error messages for null_padding option is set and we have quoted new values in parallel static CSVError NullPaddingFail(const CSVReaderOptions &options, LinesPerBoundary error_info); //! Produces error for incorrect (e.g., smaller and lower than the predefined) number of columns in a CSV Line static CSVError IncorrectColumnAmountError(const CSVReaderOptions &state_machine, idx_t actual_columns, LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - int64_t byte_position); + optional_idx byte_position); static CSVError InvalidUTF8(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, - string &csv_row, idx_t row_byte_position, int64_t byte_position); + string &csv_row, idx_t row_byte_position, optional_idx byte_position); idx_t GetBoundaryIndex() { return error_info.boundary_idx; @@ -104,7 +104,7 @@ class CSVError { //! Byte position of where the row starts idx_t row_byte_position; //! Byte Position where error occurred. - int64_t byte_position; + optional_idx byte_position; }; class CSVErrorHandler { diff --git a/src/main/settings/settings.cpp b/src/main/settings/settings.cpp index aecc7e767330..6209c0c1543d 100644 --- a/src/main/settings/settings.cpp +++ b/src/main/settings/settings.cpp @@ -999,7 +999,6 @@ void PartitionedWriteFlushThreshold::SetLocal(ClientContext &context, const Valu ClientConfig::GetConfig(context).partitioned_write_flush_threshold = input.GetValue(); } - Value PartitionedWriteFlushThreshold::GetSetting(const ClientContext &context) { return Value::BIGINT(NumericCast(ClientConfig::GetConfig(context).partitioned_write_flush_threshold)); } From 2973658f019928cbef971e25b768811877c84840 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Fri, 12 Apr 2024 10:09:48 +0200 Subject: [PATCH 100/201] compression code yay --- .../duckdb/planner/operator/logical_top_n.hpp | 6 +++--- .../storage/compression/alp/algorithm/alp.hpp | 5 +++-- .../storage/compression/alp/alp_compress.hpp | 4 ++-- .../duckdb/storage/compression/alp/alp_fetch.hpp | 2 +- .../duckdb/storage/compression/alp/alp_utils.hpp | 4 ++-- .../storage/compression/alprd/alprd_compress.hpp | 4 ++-- .../storage/compression/alprd/alprd_fetch.hpp | 2 +- .../storage/compression/chimp/chimp_fetch.hpp | 2 +- .../storage/compression/patas/patas_fetch.hpp | 2 +- src/optimizer/topn_optimizer.cpp | 6 +++--- src/storage/compression/fsst.cpp | 16 ++++++++++------ src/storage/metadata/metadata_manager.cpp | 4 ++-- src/storage/metadata/metadata_writer.cpp | 2 +- 13 files changed, 32 insertions(+), 27 deletions(-) diff --git a/src/include/duckdb/planner/operator/logical_top_n.hpp b/src/include/duckdb/planner/operator/logical_top_n.hpp index 745ca6292ca8..cc19ea6b7ff6 100644 --- a/src/include/duckdb/planner/operator/logical_top_n.hpp +++ b/src/include/duckdb/planner/operator/logical_top_n.hpp @@ -19,15 +19,15 @@ class LogicalTopN : public LogicalOperator { static constexpr const LogicalOperatorType TYPE = 
LogicalOperatorType::LOGICAL_TOP_N; public: - LogicalTopN(vector orders, int64_t limit, int64_t offset) + LogicalTopN(vector orders, idx_t limit, idx_t offset) : LogicalOperator(LogicalOperatorType::LOGICAL_TOP_N), orders(std::move(orders)), limit(limit), offset(offset) { } vector orders; //! The maximum amount of elements to emit - int64_t limit; + idx_t limit; //! The offset from the start to begin emitting elements - int64_t offset; + idx_t offset; public: vector GetColumnBindings() override { diff --git a/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp b/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp index d71189ded9a3..d216f23e58af 100644 --- a/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp +++ b/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp @@ -254,7 +254,8 @@ struct AlpCompression { static void FindBestFactorAndExponent(const T *input_vector, idx_t n_values, State &state) { //! We sample equidistant values within a vector; to do this we skip a fixed number of values vector vector_sample; - uint32_t idx_increments = MaxValue(1, (int32_t)std::ceil((double)n_values / AlpConstants::SAMPLES_PER_VECTOR)); + auto idx_increments = MaxValue( + 1, UnsafeNumericCast(std::ceil((double)n_values / AlpConstants::SAMPLES_PER_VECTOR))); for (idx_t i = 0; i < n_values; i += idx_increments) { vector_sample.push_back(input_vector[i]); } @@ -360,7 +361,7 @@ struct AlpCompression { } state.bit_width = bit_width; // in bits state.bp_size = bp_size; // in bytes - state.frame_of_reference = min_value; + state.frame_of_reference = static_cast(min_value); // understood this can be negative } /* diff --git a/src/include/duckdb/storage/compression/alp/alp_compress.hpp b/src/include/duckdb/storage/compression/alp/alp_compress.hpp index fc7d88694df3..e4021d939190 100644 --- a/src/include/duckdb/storage/compression/alp/alp_compress.hpp +++ b/src/include/duckdb/storage/compression/alp/alp_compress.hpp @@ -183,10 +183,10 @@ struct AlpCompressionState : public CompressionState { // Verify that the metadata_ptr is not smaller than the space used by the data D_ASSERT(dataptr + metadata_offset <= metadata_ptr); - idx_t bytes_used_by_metadata = dataptr + Storage::BLOCK_SIZE - metadata_ptr; + auto bytes_used_by_metadata = UnsafeNumericCast(dataptr + Storage::BLOCK_SIZE - metadata_ptr); // Initially the total segment size is the size of the block - idx_t total_segment_size = Storage::BLOCK_SIZE; + auto total_segment_size = Storage::BLOCK_SIZE; //! 
We compact the block if the space used is less than a threshold const auto used_space_percentage = diff --git a/src/include/duckdb/storage/compression/alp/alp_fetch.hpp b/src/include/duckdb/storage/compression/alp/alp_fetch.hpp index 81d3d2ac0989..54a9964d4ad7 100644 --- a/src/include/duckdb/storage/compression/alp/alp_fetch.hpp +++ b/src/include/duckdb/storage/compression/alp/alp_fetch.hpp @@ -28,7 +28,7 @@ void AlpFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, using EXACT_TYPE = typename FloatingToExact::TYPE; AlpScanState scan_state(segment); - scan_state.Skip(segment, row_id); + scan_state.Skip(segment, UnsafeNumericCast(row_id)); auto result_data = FlatVector::GetData(result); result_data[result_idx] = (EXACT_TYPE)0; diff --git a/src/include/duckdb/storage/compression/alp/alp_utils.hpp b/src/include/duckdb/storage/compression/alp/alp_utils.hpp index 1b77e219e489..b5e49a6f3027 100644 --- a/src/include/duckdb/storage/compression/alp/alp_utils.hpp +++ b/src/include/duckdb/storage/compression/alp/alp_utils.hpp @@ -40,8 +40,8 @@ class AlpUtils { auto n_lookup_values = NumericCast(MinValue(current_vector_n_values, (idx_t)AlpConstants::ALP_VECTOR_SIZE)); //! We sample equidistant values within a vector; to do this we jump a fixed number of values - uint32_t n_sampled_increments = - MaxValue(1, (int32_t)std::ceil((double)n_lookup_values / AlpConstants::SAMPLES_PER_VECTOR)); + uint32_t n_sampled_increments = MaxValue( + 1, UnsafeNumericCast(std::ceil((double)n_lookup_values / AlpConstants::SAMPLES_PER_VECTOR))); uint32_t n_sampled_values = std::ceil((double)n_lookup_values / n_sampled_increments); D_ASSERT(n_sampled_values < AlpConstants::ALP_VECTOR_SIZE); diff --git a/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp b/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp index 5ff07dd7bf9a..3f2a8aca329d 100644 --- a/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp +++ b/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp @@ -185,10 +185,10 @@ struct AlpRDCompressionState : public CompressionState { // Verify that the metadata_ptr is not smaller than the space used by the data D_ASSERT(dataptr + metadata_offset <= metadata_ptr); - idx_t bytes_used_by_metadata = dataptr + Storage::BLOCK_SIZE - metadata_ptr; + auto bytes_used_by_metadata = UnsafeNumericCast(dataptr + Storage::BLOCK_SIZE - metadata_ptr); // Initially the total segment size is the size of the block - idx_t total_segment_size = Storage::BLOCK_SIZE; + auto total_segment_size = Storage::BLOCK_SIZE; //! 
We compact the block if the space used is less than a threshold const auto used_space_percentage = diff --git a/src/include/duckdb/storage/compression/alprd/alprd_fetch.hpp b/src/include/duckdb/storage/compression/alprd/alprd_fetch.hpp index 0128b8db683e..35923019877c 100644 --- a/src/include/duckdb/storage/compression/alprd/alprd_fetch.hpp +++ b/src/include/duckdb/storage/compression/alprd/alprd_fetch.hpp @@ -27,7 +27,7 @@ template void AlpRDFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { using EXACT_TYPE = typename FloatingToExact::TYPE; AlpRDScanState scan_state(segment); - scan_state.Skip(segment, row_id); + scan_state.Skip(segment, UnsafeNumericCast(row_id)); auto result_data = FlatVector::GetData(result); result_data[result_idx] = (EXACT_TYPE)0; diff --git a/src/include/duckdb/storage/compression/chimp/chimp_fetch.hpp b/src/include/duckdb/storage/compression/chimp/chimp_fetch.hpp index c8e4edb9b36a..5d6be51e64cb 100644 --- a/src/include/duckdb/storage/compression/chimp/chimp_fetch.hpp +++ b/src/include/duckdb/storage/compression/chimp/chimp_fetch.hpp @@ -29,7 +29,7 @@ void ChimpFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id using INTERNAL_TYPE = typename ChimpType::TYPE; ChimpScanState scan_state(segment); - scan_state.Skip(segment, row_id); + scan_state.Skip(segment, UnsafeNumericCast(row_id)); auto result_data = FlatVector::GetData(result); if (scan_state.GroupFinished() && scan_state.total_value_count < scan_state.segment_count) { diff --git a/src/include/duckdb/storage/compression/patas/patas_fetch.hpp b/src/include/duckdb/storage/compression/patas/patas_fetch.hpp index fd416cfc6aa9..8e20ae67fe4e 100644 --- a/src/include/duckdb/storage/compression/patas/patas_fetch.hpp +++ b/src/include/duckdb/storage/compression/patas/patas_fetch.hpp @@ -29,7 +29,7 @@ void PatasFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id using EXACT_TYPE = typename FloatingToExact::TYPE; PatasScanState scan_state(segment); - scan_state.Skip(segment, row_id); + scan_state.Skip(segment, UnsafeNumericCast(row_id)); auto result_data = FlatVector::GetData(result); result_data[result_idx] = (EXACT_TYPE)0; diff --git a/src/optimizer/topn_optimizer.cpp b/src/optimizer/topn_optimizer.cpp index 26f6ca99506e..5e83f0a05f79 100644 --- a/src/optimizer/topn_optimizer.cpp +++ b/src/optimizer/topn_optimizer.cpp @@ -29,10 +29,10 @@ unique_ptr TopN::Optimize(unique_ptr op) { if (CanOptimize(*op)) { auto &limit = op->Cast(); auto &order_by = (op->children[0])->Cast(); - auto limit_val = int64_t(limit.limit_val.GetConstantValue()); - int64_t offset_val = 0; + auto limit_val = NumericCast(limit.limit_val.GetConstantValue()); + idx_t offset_val = 0; if (limit.offset_val.Type() == LimitNodeType::CONSTANT_VALUE) { - offset_val = NumericCast(limit.offset_val.GetConstantValue()); + offset_val = NumericCast(limit.offset_val.GetConstantValue()); } auto topn = make_uniq(std::move(order_by.orders), limit_val, offset_val); topn->AddChild(std::move(order_by.children[0])); diff --git a/src/storage/compression/fsst.cpp b/src/storage/compression/fsst.cpp index 6c9562366f0c..02474963f31d 100644 --- a/src/storage/compression/fsst.cpp +++ b/src/storage/compression/fsst.cpp @@ -178,7 +178,8 @@ idx_t FSSTStorage::StringFinalAnalyze(AnalyzeState &state_p) { compressed_dict_size += size; max_compressed_string_length = MaxValue(max_compressed_string_length, size); } - D_ASSERT(compressed_dict_size == (compressed_ptrs[res - 1] - 
compressed_ptrs[0]) + compressed_sizes[res - 1]); + D_ASSERT(compressed_dict_size == + (uint64_t)(compressed_ptrs[res - 1] - compressed_ptrs[0]) + compressed_sizes[res - 1]); auto minimum_width = BitpackingPrimitives::MinimumBitWidth(max_compressed_string_length); auto bitpacked_offsets_size = @@ -606,7 +607,8 @@ void FSSTStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &sta for (idx_t i = 0; i < scan_count; i++) { uint32_t string_length = bitunpack_buffer[i + offsets.scan_offset]; result_data[i] = UncompressedStringStorage::FetchStringFromDict( - segment, dict, result, baseptr, delta_decode_buffer[i + offsets.unused_delta_decoded_values], + segment, dict, result, baseptr, + UnsafeNumericCast(delta_decode_buffer[i + offsets.unused_delta_decoded_values]), string_length); FSSTVector::SetCount(result, scan_count); } @@ -615,7 +617,8 @@ void FSSTStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &sta for (idx_t i = 0; i < scan_count; i++) { uint32_t str_len = bitunpack_buffer[i + offsets.scan_offset]; auto str_ptr = FSSTStorage::FetchStringPointer( - dict, baseptr, delta_decode_buffer[i + offsets.unused_delta_decoded_values]); + dict, baseptr, + UnsafeNumericCast(delta_decode_buffer[i + offsets.unused_delta_decoded_values])); if (str_len > 0) { result_data[i + result_offset] = @@ -627,7 +630,7 @@ void FSSTStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &sta } scan_state.StoreLastDelta(delta_decode_buffer[scan_count + offsets.unused_delta_decoded_values - 1], - start + scan_count - 1); + UnsafeNumericCast(start + scan_count - 1)); } void FSSTStorage::StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { @@ -655,7 +658,7 @@ void FSSTStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state if (have_symbol_table) { // We basically just do a scan of 1 which is kinda expensive as we need to repeatedly delta decode until we // reach the row we want, we could consider a more clever caching trick if this is slow - auto offsets = CalculateBpDeltaOffsets(-1, row_id, 1); + auto offsets = CalculateBpDeltaOffsets(-1, UnsafeNumericCast(row_id), 1); auto bitunpack_buffer = unique_ptr(new uint32_t[offsets.total_bitunpack_count]); BitUnpackRange(base_data, data_ptr_cast(bitunpack_buffer.get()), offsets.total_bitunpack_count, @@ -667,7 +670,8 @@ void FSSTStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state uint32_t string_length = bitunpack_buffer[offsets.scan_offset]; string_t compressed_string = UncompressedStringStorage::FetchStringFromDict( - segment, dict, result, base_ptr, delta_decode_buffer[offsets.unused_delta_decoded_values], string_length); + segment, dict, result, base_ptr, + UnsafeNumericCast(delta_decode_buffer[offsets.unused_delta_decoded_values]), string_length); result_data[result_idx] = FSSTPrimitives::DecompressValue((void *)&decoder, result, compressed_string.GetData(), compressed_string.GetSize()); diff --git a/src/storage/metadata/metadata_manager.cpp b/src/storage/metadata/metadata_manager.cpp index 8d25e037e1b5..2fba4fc1bbc7 100644 --- a/src/storage/metadata/metadata_manager.cpp +++ b/src/storage/metadata/metadata_manager.cpp @@ -33,7 +33,7 @@ MetadataHandle MetadataManager::AllocateHandle() { // select the first free metadata block we can find MetadataPointer pointer; - pointer.block_index = free_block; + pointer.block_index = UnsafeNumericCast(free_block); auto &block = blocks[free_block]; if (block.block->BlockId() < MAXIMUM_BLOCK) { // this block is a 
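
The conversions in this patch lean on two cast helpers with different contracts: the checked variant is expected to reject values that do not fit the target type, while UnsafeNumericCast is expected to verify only in debug builds and compile down to a plain static_cast in release. A simplified sketch of the checked flavour, under that assumption (the name CheckedNumericCast is hypothetical; DuckDB's own helpers cover more cases):

    #include <stdexcept>
    #include <type_traits>
    #include <utility>

    // Simplified: integer-to-integer cast that throws when the value is out of range.
    template <class TO, class FROM>
    TO CheckedNumericCast(FROM value) {
        static_assert(std::is_integral<FROM>::value && std::is_integral<TO>::value,
                      "this sketch handles integers only");
        if (!std::in_range<TO>(value)) { // C++20: sign-correct range check
            throw std::out_of_range("numeric cast: value out of range for target type");
        }
        return static_cast<TO>(value);
    }

For example, CheckedNumericCast<uint64_t>(row_id) mirrors the row_t to idx_t conversions applied throughout the hunks above.
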
disk-backed block, yet we are planning to write to it @@ -134,7 +134,7 @@ MetadataPointer MetadataManager::FromDiskPointer(MetaBlockPointer pointer) { pointer.block_pointer); } // LCOV_EXCL_STOP MetadataPointer result; - result.block_index = block_id; + result.block_index = UnsafeNumericCast(block_id); result.index = UnsafeNumericCast(index); return result; } diff --git a/src/storage/metadata/metadata_writer.cpp b/src/storage/metadata/metadata_writer.cpp index e47708c21f6a..cc95b9086a81 100644 --- a/src/storage/metadata/metadata_writer.cpp +++ b/src/storage/metadata/metadata_writer.cpp @@ -48,7 +48,7 @@ void MetadataWriter::NextBlock() { current_pointer = block.pointer; offset = sizeof(idx_t); capacity = MetadataManager::METADATA_BLOCK_SIZE; - Store(-1, BasePtr()); + Store(static_cast(-1), BasePtr()); if (written_pointers) { written_pointers->push_back(manager.GetDiskPointer(current_pointer)); } From 86d65a93e510d860ce4be6fbb656b36b6babdee0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Fri, 12 Apr 2024 10:14:42 +0200 Subject: [PATCH 101/201] column data --- src/storage/table/column_data.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/storage/table/column_data.cpp b/src/storage/table/column_data.cpp index e48495049d7e..1ccc4b8a19f9 100644 --- a/src/storage/table/column_data.cpp +++ b/src/storage/table/column_data.cpp @@ -105,7 +105,7 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai if (state.scan_options && state.scan_options->force_fetch_row) { for (idx_t i = 0; i < scan_count; i++) { ColumnFetchState fetch_state; - state.current->FetchRow(fetch_state, state.row_index + i, result, result_offset + i); + state.current->FetchRow(fetch_state, UnsafeNumericCast(state.row_index + i), result, result_offset + i); } } else { state.current->Scan(state, scan_count, result, result_offset, @@ -325,24 +325,24 @@ void ColumnData::RevertAppend(row_t start_row) { return; } // find the segment index that the current row belongs to - idx_t segment_index = data.GetSegmentIndex(l, start_row); - auto segment = data.GetSegmentByIndex(l, segment_index); + idx_t segment_index = data.GetSegmentIndex(l, UnsafeNumericCast(start_row)); + auto segment = data.GetSegmentByIndex(l, UnsafeNumericCast(segment_index)); auto &transient = *segment; D_ASSERT(transient.segment_type == ColumnSegmentType::TRANSIENT); // remove any segments AFTER this segment: they should be deleted entirely data.EraseSegments(l, segment_index); - this->count = start_row - this->start; + this->count = UnsafeNumericCast(start_row) - this->start; segment->next = nullptr; - transient.RevertAppend(start_row); + transient.RevertAppend(UnsafeNumericCast(start_row)); } idx_t ColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { D_ASSERT(row_id >= 0); D_ASSERT(idx_t(row_id) >= start); // perform the fetch within the segment - state.row_index = start + ((row_id - start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE); + state.row_index = start + ((UnsafeNumericCast(row_id) - start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE); state.current = data.GetSegment(state.row_index); state.internal_index = state.current->start; return ScanVector(state, result, STANDARD_VECTOR_SIZE, false); @@ -350,14 +350,14 @@ idx_t ColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { void ColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - auto segment = 
data.GetSegment(row_id); + auto segment = data.GetSegment(UnsafeNumericCast(row_id)); // now perform the fetch within the segment segment->FetchRow(state, row_id, result, result_idx); // merge any updates made to this row lock_guard update_guard(update_lock); if (updates) { - updates->FetchRow(transaction, row_id, result, result_idx); + updates->FetchRow(transaction, UnsafeNumericCast(row_id), result, result_idx); } } @@ -422,7 +422,7 @@ void ColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, if (state.scan_options && state.scan_options->force_fetch_row) { for (idx_t i = 0; i < count; i++) { ColumnFetchState fetch_state; - segment.FetchRow(fetch_state, state.row_index + i, scan_vector, i); + segment.FetchRow(fetch_state, UnsafeNumericCast(state.row_index + i), scan_vector, i); } } else { segment.Scan(state, count, scan_vector, 0, true); From 175c19dcf187816d8527632cffcee5926c5d3a93 Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 12 Apr 2024 10:30:01 +0200 Subject: [PATCH 102/201] make_refcounted -> make_shared_ptr --- extension/httpfs/httpfs.cpp | 4 +- extension/httpfs/s3fs.cpp | 2 +- extension/json/json_functions/copy_json.cpp | 2 +- extension/json/json_functions/read_json.cpp | 8 +-- .../json/json_functions/read_json_objects.cpp | 6 +- extension/parquet/column_reader.cpp | 10 +-- .../include/templated_column_reader.hpp | 2 +- extension/parquet/parquet_crypto.cpp | 2 +- extension/parquet/parquet_extension.cpp | 4 +- extension/parquet/parquet_reader.cpp | 2 +- .../catalog_entry/duck_table_entry.cpp | 10 +-- src/common/allocator.cpp | 2 +- src/common/arrow/arrow_wrapper.cpp | 2 +- src/common/extra_type_info.cpp | 8 +-- src/common/http_state.cpp | 4 +- src/common/re2_regex.cpp | 2 +- src/common/types.cpp | 28 ++++---- .../types/column/column_data_collection.cpp | 14 ++-- .../column/column_data_collection_segment.cpp | 2 +- .../types/column/partitioned_column_data.cpp | 4 +- .../types/row/partitioned_tuple_data.cpp | 4 +- .../types/row/tuple_data_collection.cpp | 4 +- src/common/types/value.cpp | 36 +++++------ src/common/types/vector_cache.cpp | 6 +- src/execution/aggregate_hashtable.cpp | 2 +- src/execution/index/art/art.cpp | 2 +- .../operator/aggregate/aggregate_object.cpp | 2 +- .../aggregate/physical_hash_aggregate.cpp | 4 +- .../physical_ungrouped_aggregate.cpp | 2 +- .../operator/aggregate/physical_window.cpp | 2 +- .../csv_scanner/buffer_manager/csv_buffer.cpp | 4 +- .../buffer_manager/csv_buffer_manager.cpp | 2 +- .../scanner/string_value_scanner.cpp | 10 +-- .../csv_scanner/sniffer/csv_sniffer.cpp | 4 +- .../table_function/csv_file_scanner.cpp | 16 ++--- .../table_function/global_csv_state.cpp | 12 ++-- .../helper/physical_buffered_collector.cpp | 2 +- .../operator/join/physical_asof_join.cpp | 2 +- .../operator/join/physical_hash_join.cpp | 4 +- .../operator/join/physical_range_join.cpp | 2 +- .../operator/order/physical_order.cpp | 2 +- .../physical_batch_copy_to_file.cpp | 2 +- .../persistent/physical_copy_to_file.cpp | 2 +- .../schema/physical_create_art_index.cpp | 2 +- .../operator/set/physical_recursive_cte.cpp | 2 +- src/execution/physical_plan/plan_cte.cpp | 2 +- .../physical_plan/plan_recursive_cte.cpp | 2 +- src/function/table/copy_csv.cpp | 2 +- src/function/table/read_csv.cpp | 2 +- src/function/table/sniff_csv.cpp | 2 +- src/include/duckdb/common/helper.hpp | 4 +- .../duckdb/common/multi_file_reader.hpp | 2 +- src/include/duckdb/common/types.hpp | 2 +- .../duckdb/common/types/selection_vector.hpp | 2 +- src/include/duckdb/planner/binder.hpp | 2 
+- src/include/duckdb/storage/object_cache.hpp | 2 +- .../duckdb/storage/serialization/types.json | 2 +- src/main/capi/table_function-c.cpp | 2 +- src/main/client_context.cpp | 2 +- src/main/client_data.cpp | 4 +- src/main/connection.cpp | 20 +++--- src/main/database.cpp | 4 +- src/main/db_instance_cache.cpp | 2 +- src/main/relation.cpp | 56 ++++++++-------- src/main/relation/read_csv_relation.cpp | 2 +- src/main/relation/table_relation.cpp | 4 +- src/parallel/executor.cpp | 16 ++--- src/parallel/meta_pipeline.cpp | 4 +- src/planner/bind_context.cpp | 4 +- src/planner/binder.cpp | 2 +- src/planner/bound_parameter_map.cpp | 2 +- src/planner/planner.cpp | 2 +- src/storage/buffer/block_manager.cpp | 2 +- src/storage/checkpoint_manager.cpp | 2 +- src/storage/data_table.cpp | 4 +- src/storage/local_storage.cpp | 10 +-- src/storage/serialization/serialize_types.cpp | 2 +- src/storage/standard_buffer_manager.cpp | 4 +- src/storage/statistics/column_statistics.cpp | 6 +- src/storage/table/row_group.cpp | 2 +- src/storage/table/row_group_collection.cpp | 8 +-- src/storage/table/row_version_manager.cpp | 2 +- src/storage/wal_replay.cpp | 2 +- test/api/test_object_cache.cpp | 2 +- tools/odbc/include/duckdb_odbc.hpp | 2 +- tools/pythonpkg/src/pyconnection.cpp | 14 ++-- .../src/pyconnection/type_creation.cpp | 16 ++--- tools/pythonpkg/src/pyexpression.cpp | 26 ++++---- tools/pythonpkg/src/pyrelation.cpp | 4 +- tools/pythonpkg/src/typing/pytype.cpp | 24 +++---- tools/pythonpkg/src/typing/typing.cpp | 64 +++++++++---------- 91 files changed, 299 insertions(+), 299 deletions(-) diff --git a/extension/httpfs/httpfs.cpp b/extension/httpfs/httpfs.cpp index 476409ab8820..8c0bd34262b2 100644 --- a/extension/httpfs/httpfs.cpp +++ b/extension/httpfs/httpfs.cpp @@ -557,7 +557,7 @@ static optional_ptr TryGetMetadataCache(optional_ptrregistered_state.find("http_cache"); if (lookup == client_context->registered_state.end()) { - auto cache = make_refcounted(true, true); + auto cache = make_shared_ptr(true, true); client_context->registered_state["http_cache"] = cache; return cache.get(); } else { @@ -572,7 +572,7 @@ void HTTPFileHandle::Initialize(optional_ptr opener) { auto &hfs = file_system.Cast(); state = HTTPState::TryGetState(opener); if (!state) { - state = make_refcounted(); + state = make_shared_ptr(); } auto current_cache = TryGetMetadataCache(opener, hfs); diff --git a/extension/httpfs/s3fs.cpp b/extension/httpfs/s3fs.cpp index 888b6cb2115d..b06120896c00 100644 --- a/extension/httpfs/s3fs.cpp +++ b/extension/httpfs/s3fs.cpp @@ -568,7 +568,7 @@ shared_ptr S3FileHandle::GetBuffer(uint16_t write_buffer_idx) { auto buffer_handle = s3fs.Allocate(part_size, config_params.max_upload_threads); auto new_write_buffer = - make_refcounted(write_buffer_idx * part_size, part_size, std::move(buffer_handle)); + make_shared_ptr(write_buffer_idx * part_size, part_size, std::move(buffer_handle)); { unique_lock lck(write_buffers_lock); auto lookup_result = write_buffers.find(write_buffer_idx); diff --git a/extension/json/json_functions/copy_json.cpp b/extension/json/json_functions/copy_json.cpp index 5d97c01e3b84..4744e5a78c82 100644 --- a/extension/json/json_functions/copy_json.cpp +++ b/extension/json/json_functions/copy_json.cpp @@ -185,7 +185,7 @@ CopyFunction JSONFunctions::GetJSONCopyFunction() { function.plan = CopyToJSONPlan; function.copy_from_bind = CopyFromJSONBind; - function.copy_from_function = JSONFunctions::GetReadJSONTableFunction(make_refcounted( + function.copy_from_function = 
JSONFunctions::GetReadJSONTableFunction(make_shared_ptr( JSONScanType::READ_JSON, JSONFormat::NEWLINE_DELIMITED, JSONRecordType::RECORDS, false)); return function; diff --git a/extension/json/json_functions/read_json.cpp b/extension/json/json_functions/read_json.cpp index a7624e287240..e207060a448b 100644 --- a/extension/json/json_functions/read_json.cpp +++ b/extension/json/json_functions/read_json.cpp @@ -382,25 +382,25 @@ TableFunctionSet CreateJSONFunctionInfo(string name, shared_ptr in } TableFunctionSet JSONFunctions::GetReadJSONFunction() { - auto info = make_refcounted(JSONScanType::READ_JSON, JSONFormat::AUTO_DETECT, + auto info = make_shared_ptr(JSONScanType::READ_JSON, JSONFormat::AUTO_DETECT, JSONRecordType::AUTO_DETECT, true); return CreateJSONFunctionInfo("read_json", std::move(info)); } TableFunctionSet JSONFunctions::GetReadNDJSONFunction() { - auto info = make_refcounted(JSONScanType::READ_JSON, JSONFormat::NEWLINE_DELIMITED, + auto info = make_shared_ptr(JSONScanType::READ_JSON, JSONFormat::NEWLINE_DELIMITED, JSONRecordType::AUTO_DETECT, true); return CreateJSONFunctionInfo("read_ndjson", std::move(info)); } TableFunctionSet JSONFunctions::GetReadJSONAutoFunction() { - auto info = make_refcounted(JSONScanType::READ_JSON, JSONFormat::AUTO_DETECT, + auto info = make_shared_ptr(JSONScanType::READ_JSON, JSONFormat::AUTO_DETECT, JSONRecordType::AUTO_DETECT, true); return CreateJSONFunctionInfo("read_json_auto", std::move(info)); } TableFunctionSet JSONFunctions::GetReadNDJSONAutoFunction() { - auto info = make_refcounted(JSONScanType::READ_JSON, JSONFormat::NEWLINE_DELIMITED, + auto info = make_shared_ptr(JSONScanType::READ_JSON, JSONFormat::NEWLINE_DELIMITED, JSONRecordType::AUTO_DETECT, true); return CreateJSONFunctionInfo("read_ndjson_auto", std::move(info)); } diff --git a/extension/json/json_functions/read_json_objects.cpp b/extension/json/json_functions/read_json_objects.cpp index c8e798111071..630c2e694689 100644 --- a/extension/json/json_functions/read_json_objects.cpp +++ b/extension/json/json_functions/read_json_objects.cpp @@ -62,7 +62,7 @@ TableFunction GetReadJSONObjectsTableFunction(bool list_parameter, shared_ptr(JSONScanType::READ_JSON_OBJECTS, JSONFormat::ARRAY, JSONRecordType::RECORDS); + make_shared_ptr(JSONScanType::READ_JSON_OBJECTS, JSONFormat::ARRAY, JSONRecordType::RECORDS); function_set.AddFunction(GetReadJSONObjectsTableFunction(false, function_info)); function_set.AddFunction(GetReadJSONObjectsTableFunction(true, function_info)); return function_set; @@ -70,7 +70,7 @@ TableFunctionSet JSONFunctions::GetReadJSONObjectsFunction() { TableFunctionSet JSONFunctions::GetReadNDJSONObjectsFunction() { TableFunctionSet function_set("read_ndjson_objects"); - auto function_info = make_refcounted(JSONScanType::READ_JSON_OBJECTS, JSONFormat::NEWLINE_DELIMITED, + auto function_info = make_shared_ptr(JSONScanType::READ_JSON_OBJECTS, JSONFormat::NEWLINE_DELIMITED, JSONRecordType::RECORDS); function_set.AddFunction(GetReadJSONObjectsTableFunction(false, function_info)); function_set.AddFunction(GetReadJSONObjectsTableFunction(true, function_info)); @@ -79,7 +79,7 @@ TableFunctionSet JSONFunctions::GetReadNDJSONObjectsFunction() { TableFunctionSet JSONFunctions::GetReadJSONObjectsAutoFunction() { TableFunctionSet function_set("read_json_objects_auto"); - auto function_info = make_refcounted(JSONScanType::READ_JSON_OBJECTS, JSONFormat::AUTO_DETECT, + auto function_info = make_shared_ptr(JSONScanType::READ_JSON_OBJECTS, JSONFormat::AUTO_DETECT, JSONRecordType::RECORDS); 
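
The make_refcounted to make_shared_ptr rename applied throughout this patch keeps call sites visually close to std::make_shared. The factory itself is a thin perfect-forwarding wrapper; a minimal sketch of the shape, with std::shared_ptr standing in for duckdb's own wrapper type:

    #include <memory>
    #include <utility>

    namespace duckdb {

    template <typename T>
    using shared_ptr = std::shared_ptr<T>; // stand-in for duckdb::shared_ptr

    // Construct T in place and return it as a shared_ptr, forwarding all arguments.
    template <typename T, typename... ARGS>
    shared_ptr<T> make_shared_ptr(ARGS &&... args) {
        return shared_ptr<T>(std::make_shared<T>(std::forward<ARGS>(args)...));
    }

    } // namespace duckdb

A call such as auto s = duckdb::make_shared_ptr<std::string>("hello"); then behaves exactly like std::make_shared, which is the point of the rename.
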
function_set.AddFunction(GetReadJSONObjectsTableFunction(false, function_info)); function_set.AddFunction(GetReadJSONObjectsTableFunction(true, function_info)); diff --git a/extension/parquet/column_reader.cpp b/extension/parquet/column_reader.cpp index 680623c67f4b..e02feca15208 100644 --- a/extension/parquet/column_reader.cpp +++ b/extension/parquet/column_reader.cpp @@ -304,7 +304,7 @@ void ColumnReader::PreparePageV2(PageHeader &page_hdr) { void ColumnReader::AllocateBlock(idx_t size) { if (!block) { - block = make_refcounted(GetAllocator(), size); + block = make_shared_ptr(GetAllocator(), size); } else { block->resize(GetAllocator(), size); } @@ -516,7 +516,7 @@ idx_t ColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr result); } else if (dbp_decoder) { // TODO keep this in the state - auto read_buf = make_refcounted(); + auto read_buf = make_shared_ptr(); switch (schema.type) { case duckdb_parquet::format::Type::INT32: @@ -537,7 +537,7 @@ idx_t ColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr } else if (rle_decoder) { // RLE encoding for boolean D_ASSERT(type.id() == LogicalTypeId::BOOLEAN); - auto read_buf = make_refcounted(); + auto read_buf = make_shared_ptr(); read_buf->resize(reader.allocator, sizeof(bool) * (read_now - null_count)); rle_decoder->GetBatch(read_buf->ptr, read_now - null_count); PlainTemplated>(read_buf, define_out, read_now, filter, @@ -546,7 +546,7 @@ idx_t ColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr // DELTA_BYTE_ARRAY or DELTA_LENGTH_BYTE_ARRAY DeltaByteArray(define_out, read_now, filter, result_offset, result); } else if (bss_decoder) { - auto read_buf = make_refcounted(); + auto read_buf = make_shared_ptr(); switch (schema.type) { case duckdb_parquet::format::Type::FLOAT: @@ -662,7 +662,7 @@ void StringColumnReader::Dictionary(shared_ptr data, idx_t num static shared_ptr ReadDbpData(Allocator &allocator, ResizeableBuffer &buffer, idx_t &value_count) { auto decoder = make_uniq(buffer.ptr, buffer.len); value_count = decoder->TotalValues(); - auto result = make_refcounted(); + auto result = make_shared_ptr(); result->resize(allocator, sizeof(uint32_t) * value_count); decoder->GetBatch(result->ptr, value_count); decoder->Finalize(); diff --git a/extension/parquet/include/templated_column_reader.hpp b/extension/parquet/include/templated_column_reader.hpp index c8ffb761eb5f..009cd6aeeb86 100644 --- a/extension/parquet/include/templated_column_reader.hpp +++ b/extension/parquet/include/templated_column_reader.hpp @@ -44,7 +44,7 @@ class TemplatedColumnReader : public ColumnReader { public: void AllocateDict(idx_t size) { if (!dict) { - dict = make_refcounted(GetAllocator(), size); + dict = make_shared_ptr(GetAllocator(), size); } else { dict->resize(GetAllocator(), size); } diff --git a/extension/parquet/parquet_crypto.cpp b/extension/parquet/parquet_crypto.cpp index 13d96ab76135..d6bb7f1b2200 100644 --- a/extension/parquet/parquet_crypto.cpp +++ b/extension/parquet/parquet_crypto.cpp @@ -14,7 +14,7 @@ namespace duckdb { ParquetKeys &ParquetKeys::Get(ClientContext &context) { auto &cache = ObjectCache::GetObjectCache(context); if (!cache.Get(ParquetKeys::ObjectType())) { - cache.Put(ParquetKeys::ObjectType(), make_refcounted()); + cache.Put(ParquetKeys::ObjectType(), make_shared_ptr()); } return *cache.Get(ParquetKeys::ObjectType()); } diff --git a/extension/parquet/parquet_extension.cpp b/extension/parquet/parquet_extension.cpp index 5e898669c016..4cde1e9ec399 100644 --- 
a/extension/parquet/parquet_extension.cpp +++ b/extension/parquet/parquet_extension.cpp @@ -544,7 +544,7 @@ class ParquetScanFunction { result->initial_reader = result->readers[0]; } else { result->initial_reader = - make_refcounted(context, bind_data.files[0], bind_data.parquet_options); + make_shared_ptr(context, bind_data.files[0], bind_data.parquet_options); result->readers[0] = result->initial_reader; } result->file_states[0] = ParquetFileState::OPEN; @@ -750,7 +750,7 @@ class ParquetScanFunction { shared_ptr reader; try { - reader = make_refcounted(context, file, pq_options); + reader = make_shared_ptr(context, file, pq_options); InitializeParquetReader(*reader, bind_data, parallel_state.column_ids, parallel_state.filters, context); } catch (...) { diff --git a/extension/parquet/parquet_reader.cpp b/extension/parquet/parquet_reader.cpp index ecd51fc3c968..896bc11b390c 100644 --- a/extension/parquet/parquet_reader.cpp +++ b/extension/parquet/parquet_reader.cpp @@ -113,7 +113,7 @@ LoadMetadata(Allocator &allocator, FileHandle &file_handle, metadata->read(file_proto.get()); } - return make_refcounted(std::move(metadata), current_time); + return make_shared_ptr(std::move(metadata), current_time); } LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, bool binary_as_string) { diff --git a/src/catalog/catalog_entry/duck_table_entry.cpp b/src/catalog/catalog_entry/duck_table_entry.cpp index 53ca0cbba0c3..d6a56d0451cd 100644 --- a/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/catalog/catalog_entry/duck_table_entry.cpp @@ -84,7 +84,7 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou storage_columns.push_back(col_def.Copy()); } storage = - make_refcounted(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), + make_shared_ptr(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), schema.name, name, std::move(storage_columns), std::move(info.data)); // create the unique indexes for the UNIQUE and PRIMARY KEY and FOREIGN KEY constraints @@ -345,7 +345,7 @@ unique_ptr DuckTableEntry::AddColumn(ClientContext &context, AddCo auto binder = Binder::CreateBinder(context); auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); auto new_storage = - make_refcounted(context, *storage, info.new_column, *bound_create_info->bound_defaults.back()); + make_shared_ptr(context, *storage, info.new_column, *bound_create_info->bound_defaults.back()); return make_uniq(catalog, schema, *bound_create_info, new_storage); } @@ -481,7 +481,7 @@ unique_ptr DuckTableEntry::RemoveColumn(ClientContext &context, Re return make_uniq(catalog, schema, *bound_create_info, storage); } auto new_storage = - make_refcounted(context, *storage, columns.LogicalToPhysical(LogicalIndex(removed_index)).index); + make_shared_ptr(context, *storage, columns.LogicalToPhysical(LogicalIndex(removed_index)).index); return make_uniq(catalog, schema, *bound_create_info, new_storage); } @@ -549,7 +549,7 @@ unique_ptr DuckTableEntry::SetNotNull(ClientContext &context, SetN } // Return with new storage info. Note that we need the bound column index here. 
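
The ParquetKeys::Get hunk above is the usual ObjectCache get-or-create sequence: probe the cache, insert on miss, then return the cached entry. Sketched standalone below with a simplified, unsynchronized cache; the entry type ParquetKeysLike is hypothetical:

    #include <memory>
    #include <string>
    #include <unordered_map>

    struct ObjectCacheEntry {
        virtual ~ObjectCacheEntry() = default;
    };

    // Simplified object cache keyed by a type name.
    class ObjectCache {
        std::unordered_map<std::string, std::shared_ptr<ObjectCacheEntry>> entries;

    public:
        template <class T>
        std::shared_ptr<T> Get(const std::string &key) {
            auto it = entries.find(key);
            return it == entries.end() ? nullptr : std::static_pointer_cast<T>(it->second);
        }
        void Put(std::string key, std::shared_ptr<ObjectCacheEntry> entry) {
            entries[std::move(key)] = std::move(entry);
        }
    };

    struct ParquetKeysLike : ObjectCacheEntry {}; // hypothetical entry type

    ParquetKeysLike &GetOrCreate(ObjectCache &cache) {
        if (!cache.Get<ParquetKeysLike>("parquet_keys")) {
            cache.Put("parquet_keys", std::make_shared<ParquetKeysLike>());
        }
        return *cache.Get<ParquetKeysLike>("parquet_keys");
    }
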
- auto new_storage = make_refcounted( + auto new_storage = make_shared_ptr( context, *storage, make_uniq(columns.LogicalToPhysical(LogicalIndex(not_null_idx)))); return make_uniq(catalog, schema, *bound_create_info, new_storage); } @@ -660,7 +660,7 @@ unique_ptr DuckTableEntry::ChangeColumnType(ClientContext &context } auto new_storage = - make_refcounted(context, *storage, columns.LogicalToPhysical(LogicalIndex(change_idx)).index, + make_shared_ptr(context, *storage, columns.LogicalToPhysical(LogicalIndex(change_idx)).index, info.target_type, std::move(storage_oids), *bound_expression); auto result = make_uniq(catalog, schema, *bound_create_info, new_storage); return std::move(result); diff --git a/src/common/allocator.cpp b/src/common/allocator.cpp index 835aba386ff2..c587aaf33442 100644 --- a/src/common/allocator.cpp +++ b/src/common/allocator.cpp @@ -195,7 +195,7 @@ data_ptr_t Allocator::DefaultReallocate(PrivateAllocatorData *private_data, data } shared_ptr &Allocator::DefaultAllocatorReference() { - static shared_ptr DEFAULT_ALLOCATOR = make_refcounted(); + static shared_ptr DEFAULT_ALLOCATOR = make_shared_ptr(); return DEFAULT_ALLOCATOR; } diff --git a/src/common/arrow/arrow_wrapper.cpp b/src/common/arrow/arrow_wrapper.cpp index 1bcc48ce91b6..0f0613bca1c7 100644 --- a/src/common/arrow/arrow_wrapper.cpp +++ b/src/common/arrow/arrow_wrapper.cpp @@ -50,7 +50,7 @@ void ArrowArrayStreamWrapper::GetSchema(ArrowSchemaWrapper &schema) { } shared_ptr ArrowArrayStreamWrapper::GetNextChunk() { - auto current_chunk = make_refcounted(); + auto current_chunk = make_shared_ptr(); if (arrow_array_stream.get_next(&arrow_array_stream, ¤t_chunk->arrow_array)) { // LCOV_EXCL_START throw InvalidInputException("arrow_scan: get_next failed(): %s", string(GetError())); } // LCOV_EXCL_STOP diff --git a/src/common/extra_type_info.cpp b/src/common/extra_type_info.cpp index c5c3ca1b6b8f..d6a632164916 100644 --- a/src/common/extra_type_info.cpp +++ b/src/common/extra_type_info.cpp @@ -190,7 +190,7 @@ struct EnumTypeInfoTemplated : public EnumTypeInfo { deserializer.ReadList(201, "values", [&](Deserializer::List &list, idx_t i) { strings[i] = StringVector::AddStringOrBlob(values_insert_order, list.ReadElement()); }); - return make_refcounted(values_insert_order, size); + return make_shared_ptr(values_insert_order, size); } const string_map_t &GetValues() const { @@ -227,13 +227,13 @@ LogicalType EnumTypeInfo::CreateType(Vector &ordered_data, idx_t size) { auto enum_internal_type = EnumTypeInfo::DictType(size); switch (enum_internal_type) { case PhysicalType::UINT8: - info = make_refcounted>(ordered_data, size); + info = make_shared_ptr>(ordered_data, size); break; case PhysicalType::UINT16: - info = make_refcounted>(ordered_data, size); + info = make_shared_ptr>(ordered_data, size); break; case PhysicalType::UINT32: - info = make_refcounted>(ordered_data, size); + info = make_shared_ptr>(ordered_data, size); break; default: throw InternalException("Invalid Physical Type for ENUMs"); diff --git a/src/common/http_state.cpp b/src/common/http_state.cpp index 880454b87a75..a2e91182f588 100644 --- a/src/common/http_state.cpp +++ b/src/common/http_state.cpp @@ -69,7 +69,7 @@ shared_ptr HTTPState::TryGetState(ClientContext &context, bool create return nullptr; } - auto http_state = make_refcounted(); + auto http_state = make_shared_ptr(); context.registered_state["http_state"] = http_state; return http_state; } @@ -87,7 +87,7 @@ shared_ptr &HTTPState::GetCachedFile(const string &path) { lock_guard 
lock(cached_files_mutex); auto &cache_entry_ref = cached_files[path]; if (!cache_entry_ref) { - cache_entry_ref = make_refcounted(); + cache_entry_ref = make_shared_ptr(); } return cache_entry_ref; } diff --git a/src/common/re2_regex.cpp b/src/common/re2_regex.cpp index 4b3e2fb8e87b..c82bf429a135 100644 --- a/src/common/re2_regex.cpp +++ b/src/common/re2_regex.cpp @@ -10,7 +10,7 @@ namespace duckdb_re2 { Regex::Regex(const std::string &pattern, RegexOptions options) { RE2::Options o; o.set_case_sensitive(options == RegexOptions::CASE_INSENSITIVE); - regex = duckdb::make_refcounted(StringPiece(pattern), o); + regex = duckdb::make_shared_ptr(StringPiece(pattern), o); } bool RegexSearchInternal(const char *input, Match &match, const Regex &r, RE2::Anchor anchor, size_t start, diff --git a/src/common/types.cpp b/src/common/types.cpp index f70d69d8102a..b9f5b8edbb40 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -1127,7 +1127,7 @@ bool ApproxEqual(double ldecimal, double rdecimal) { //===--------------------------------------------------------------------===// void LogicalType::SetAlias(string alias) { if (!type_info_) { - type_info_ = make_refcounted(ExtraTypeInfoType::GENERIC_TYPE_INFO, std::move(alias)); + type_info_ = make_shared_ptr(ExtraTypeInfoType::GENERIC_TYPE_INFO, std::move(alias)); } else { type_info_->alias = std::move(alias); } @@ -1176,7 +1176,7 @@ uint8_t DecimalType::MaxWidth() { LogicalType LogicalType::DECIMAL(uint8_t width, uint8_t scale) { D_ASSERT(width >= scale); - auto type_info = make_refcounted(width, scale); + auto type_info = make_shared_ptr(width, scale); return LogicalType(LogicalTypeId::DECIMAL, std::move(type_info)); } @@ -1198,7 +1198,7 @@ string StringType::GetCollation(const LogicalType &type) { } LogicalType LogicalType::VARCHAR_COLLATION(string collation) { // NOLINT - auto string_info = make_refcounted(std::move(collation)); + auto string_info = make_shared_ptr(std::move(collation)); return LogicalType(LogicalTypeId::VARCHAR, std::move(string_info)); } @@ -1213,7 +1213,7 @@ const LogicalType &ListType::GetChildType(const LogicalType &type) { } LogicalType LogicalType::LIST(const LogicalType &child) { - auto info = make_refcounted(child); + auto info = make_shared_ptr(child); return LogicalType(LogicalTypeId::LIST, std::move(info)); } @@ -1285,12 +1285,12 @@ bool StructType::IsUnnamed(const LogicalType &type) { } LogicalType LogicalType::STRUCT(child_list_t children) { - auto info = make_refcounted(std::move(children)); + auto info = make_shared_ptr(std::move(children)); return LogicalType(LogicalTypeId::STRUCT, std::move(info)); } LogicalType LogicalType::AGGREGATE_STATE(aggregate_state_t state_type) { // NOLINT - auto info = make_refcounted(std::move(state_type)); + auto info = make_shared_ptr(std::move(state_type)); return LogicalType(LogicalTypeId::AGGREGATE_STATE, std::move(info)); } @@ -1315,7 +1315,7 @@ LogicalType LogicalType::MAP(const LogicalType &child_p) { new_children[1].first = "value"; auto child = LogicalType::STRUCT(std::move(new_children)); - auto info = make_refcounted(child); + auto info = make_shared_ptr(child); return LogicalType(LogicalTypeId::MAP, std::move(info)); } @@ -1344,7 +1344,7 @@ LogicalType LogicalType::UNION(child_list_t members) { D_ASSERT(members.size() <= UnionType::MAX_UNION_MEMBERS); // union types always have a hidden "tag" field in front members.insert(members.begin(), {"", LogicalType::UTINYINT}); - auto info = make_refcounted(std::move(members)); + auto info = make_shared_ptr(std::move(members)); 
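
Each factory in this block allocates one ExtraTypeInfo object and moves it into the returned LogicalType. Because the payload sits behind a shared_ptr, copying a type afterwards copies a pointer, not the info object. A toy illustration of that sharing, with hypothetical stand-in names:

    #include <cassert>
    #include <memory>
    #include <string>

    // Toy stand-ins: a type id plus shared extra information.
    struct ExtraTypeInfo {
        std::string alias;
    };

    struct LogicalTypeLike {
        int id = 0;
        std::shared_ptr<ExtraTypeInfo> info;
    };

    int main() {
        LogicalTypeLike list_type;
        list_type.info = std::make_shared<ExtraTypeInfo>();

        // Copying the type copies only the pointer, not the payload.
        LogicalTypeLike copy = list_type;
        assert(copy.info.get() == list_type.info.get());
        assert(copy.info.use_count() == 2);
        return 0;
    }
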
return LogicalType(LogicalTypeId::UNION, std::move(info)); } @@ -1397,12 +1397,12 @@ const string &UserType::GetTypeName(const LogicalType &type) { } LogicalType LogicalType::USER(const string &user_type_name) { - auto info = make_refcounted(user_type_name); + auto info = make_shared_ptr(user_type_name); return LogicalType(LogicalTypeId::USER, std::move(info)); } LogicalType LogicalType::USER(string catalog, string schema, string name) { - auto info = make_refcounted(std::move(catalog), std::move(schema), std::move(name)); + auto info = make_shared_ptr(std::move(catalog), std::move(schema), std::move(name)); return LogicalType(LogicalTypeId::USER, std::move(info)); } @@ -1518,12 +1518,12 @@ LogicalType ArrayType::ConvertToList(const LogicalType &type) { LogicalType LogicalType::ARRAY(const LogicalType &child, idx_t size) { D_ASSERT(size > 0); D_ASSERT(size < ArrayType::MAX_ARRAY_SIZE); - auto info = make_refcounted(child, size); + auto info = make_shared_ptr(child, size); return LogicalType(LogicalTypeId::ARRAY, std::move(info)); } LogicalType LogicalType::ARRAY(const LogicalType &child) { - auto info = make_refcounted(child, 0); + auto info = make_shared_ptr(child, 0); return LogicalType(LogicalTypeId::ARRAY, std::move(info)); } @@ -1531,7 +1531,7 @@ LogicalType LogicalType::ARRAY(const LogicalType &child) { // Any Type //===--------------------------------------------------------------------===// LogicalType LogicalType::ANY_PARAMS(LogicalType target, idx_t cast_score) { // NOLINT - auto type_info = make_refcounted(std::move(target), cast_score); + auto type_info = make_shared_ptr(std::move(target), cast_score); return LogicalType(LogicalTypeId::ANY, std::move(type_info)); } @@ -1584,7 +1584,7 @@ LogicalType LogicalType::INTEGER_LITERAL(const Value &constant) { // NOLINT if (!constant.type().IsIntegral()) { throw InternalException("INTEGER_LITERAL can only be made from literals of integer types"); } - auto type_info = make_refcounted(constant); + auto type_info = make_shared_ptr(constant); return LogicalType(LogicalTypeId::INTEGER_LITERAL, std::move(type_info)); } diff --git a/src/common/types/column/column_data_collection.cpp b/src/common/types/column/column_data_collection.cpp index 0f931d7cf909..46aed955ad8b 100644 --- a/src/common/types/column/column_data_collection.cpp +++ b/src/common/types/column/column_data_collection.cpp @@ -51,17 +51,17 @@ ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p) { types.clear(); count = 0; this->finished_append = false; - allocator = make_refcounted(allocator_p); + allocator = make_shared_ptr(allocator_p); } ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p, vector types_p) { Initialize(std::move(types_p)); - allocator = make_refcounted(allocator_p); + allocator = make_shared_ptr(allocator_p); } ColumnDataCollection::ColumnDataCollection(BufferManager &buffer_manager, vector types_p) { Initialize(std::move(types_p)); - allocator = make_refcounted(buffer_manager); + allocator = make_shared_ptr(buffer_manager); } ColumnDataCollection::ColumnDataCollection(shared_ptr allocator_p, vector types_p) { @@ -71,7 +71,7 @@ ColumnDataCollection::ColumnDataCollection(shared_ptr alloc ColumnDataCollection::ColumnDataCollection(ClientContext &context, vector types_p, ColumnDataAllocatorType type) - : ColumnDataCollection(make_refcounted(context, type), std::move(types_p)) { + : ColumnDataCollection(make_shared_ptr(context, type), std::move(types_p)) { D_ASSERT(!types.empty()); } @@ -199,7 +199,7 @@ 
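
Several of the ColumnDataCollection constructors in the next hunks funnel into one that accepts a shared_ptr to the allocator, which is what lets multiple collections draw from a single allocator and keeps it alive as long as any of them needs it. A rough sketch of that shape, using hypothetical simplified types:

    #include <memory>
    #include <utility>
    #include <vector>

    struct AllocatorLike {
        std::vector<std::unique_ptr<char[]>> blocks; // owned allocations
    };

    struct CollectionLike {
        std::shared_ptr<AllocatorLike> allocator;
        explicit CollectionLike(std::shared_ptr<AllocatorLike> allocator_p)
            : allocator(std::move(allocator_p)) {
        }
    };

    int main() {
        auto shared_alloc = std::make_shared<AllocatorLike>();
        // Two collections share one allocator; it outlives whichever dies first.
        CollectionLike a(shared_alloc);
        CollectionLike b(shared_alloc);
        return 0;
    }
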
ColumnDataChunkIterationHelper::ColumnDataChunkIterationHelper(const ColumnDataC ColumnDataChunkIterationHelper::ColumnDataChunkIterator::ColumnDataChunkIterator( const ColumnDataCollection *collection_p, vector column_ids_p) - : collection(collection_p), scan_chunk(make_refcounted()), row_index(0) { + : collection(collection_p), scan_chunk(make_shared_ptr()), row_index(0) { if (!collection) { return; } @@ -246,7 +246,7 @@ ColumnDataRowIterationHelper::ColumnDataRowIterationHelper(const ColumnDataColle } ColumnDataRowIterationHelper::ColumnDataRowIterator::ColumnDataRowIterator(const ColumnDataCollection *collection_p) - : collection(collection_p), scan_chunk(make_refcounted()), current_row(*scan_chunk, 0, 0) { + : collection(collection_p), scan_chunk(make_shared_ptr()), current_row(*scan_chunk, 0, 0) { if (!collection) { return; } @@ -1041,7 +1041,7 @@ void ColumnDataCollection::Reset() { segments.clear(); // Refreshes the ColumnDataAllocator to prevent holding on to allocated data unnecessarily - allocator = make_refcounted(*allocator); + allocator = make_shared_ptr(*allocator); } struct ValueResultEquals { diff --git a/src/common/types/column/column_data_collection_segment.cpp b/src/common/types/column/column_data_collection_segment.cpp index 1f815d521974..918680c13911 100644 --- a/src/common/types/column/column_data_collection_segment.cpp +++ b/src/common/types/column/column_data_collection_segment.cpp @@ -7,7 +7,7 @@ namespace duckdb { ColumnDataCollectionSegment::ColumnDataCollectionSegment(shared_ptr allocator_p, vector types_p) : allocator(std::move(allocator_p)), types(std::move(types_p)), count(0), - heap(make_refcounted(allocator->GetAllocator())) { + heap(make_shared_ptr(allocator->GetAllocator())) { } idx_t ColumnDataCollectionSegment::GetDataSize(idx_t type_size) { diff --git a/src/common/types/column/partitioned_column_data.cpp b/src/common/types/column/partitioned_column_data.cpp index 7d47e129f26b..0ac1e065b24f 100644 --- a/src/common/types/column/partitioned_column_data.cpp +++ b/src/common/types/column/partitioned_column_data.cpp @@ -9,7 +9,7 @@ namespace duckdb { PartitionedColumnData::PartitionedColumnData(PartitionedColumnDataType type_p, ClientContext &context_p, vector types_p) : type(type_p), context(context_p), types(std::move(types_p)), - allocators(make_refcounted()) { + allocators(make_shared_ptr()) { } PartitionedColumnData::PartitionedColumnData(const PartitionedColumnData &other) @@ -165,7 +165,7 @@ vector> &PartitionedColumnData::GetPartitions() } void PartitionedColumnData::CreateAllocator() { - allocators->allocators.emplace_back(make_refcounted(BufferManager::GetBufferManager(context))); + allocators->allocators.emplace_back(make_shared_ptr(BufferManager::GetBufferManager(context))); allocators->allocators.back()->MakeShared(); } diff --git a/src/common/types/row/partitioned_tuple_data.cpp b/src/common/types/row/partitioned_tuple_data.cpp index cd67c32abb0d..b134a9b9784e 100644 --- a/src/common/types/row/partitioned_tuple_data.cpp +++ b/src/common/types/row/partitioned_tuple_data.cpp @@ -9,7 +9,7 @@ namespace duckdb { PartitionedTupleData::PartitionedTupleData(PartitionedTupleDataType type_p, BufferManager &buffer_manager_p, const TupleDataLayout &layout_p) : type(type_p), buffer_manager(buffer_manager_p), layout(layout_p.Copy()), count(0), data_size(0), - allocators(make_refcounted()) { + allocators(make_shared_ptr()) { } PartitionedTupleData::PartitionedTupleData(const PartitionedTupleData &other) @@ -434,7 +434,7 @@ void PartitionedTupleData::Print() 
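
Both Reset() implementations in the surrounding hunks drop their segments and then rebuild the allocator from the old one; replacing the shared handle releases the previous allocator, and the memory it held, as soon as the last outside reference disappears. A small sketch of that mechanism (BlockAllocator is a hypothetical stand-in):

    #include <cassert>
    #include <memory>

    struct BlockAllocator {
        // imagine pinned buffers living here
    };

    int main() {
        auto allocator = std::make_shared<BlockAllocator>();
        std::weak_ptr<BlockAllocator> old_handle = allocator;

        // "Refresh": build a fresh allocator from the old one's state,
        // then drop the old shared handle by reassigning.
        allocator = std::make_shared<BlockAllocator>(*allocator);
        assert(old_handle.expired()); // old allocator and its buffers are freed
        return 0;
    }
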
{ // LCOV_EXCL_STOP void PartitionedTupleData::CreateAllocator() { - allocators->allocators.emplace_back(make_refcounted(buffer_manager, layout)); + allocators->allocators.emplace_back(make_shared_ptr(buffer_manager, layout)); } } // namespace duckdb diff --git a/src/common/types/row/tuple_data_collection.cpp b/src/common/types/row/tuple_data_collection.cpp index 7ffcac79abce..86cbb144c804 100644 --- a/src/common/types/row/tuple_data_collection.cpp +++ b/src/common/types/row/tuple_data_collection.cpp @@ -12,7 +12,7 @@ namespace duckdb { using ValidityBytes = TupleDataLayout::ValidityBytes; TupleDataCollection::TupleDataCollection(BufferManager &buffer_manager, const TupleDataLayout &layout_p) - : layout(layout_p.Copy()), allocator(make_refcounted(buffer_manager, layout)) { + : layout(layout_p.Copy()), allocator(make_shared_ptr(buffer_manager, layout)) { Initialize(); } @@ -377,7 +377,7 @@ void TupleDataCollection::Reset() { segments.clear(); // Refreshes the TupleDataAllocator to prevent holding on to allocated data unnecessarily - allocator = make_refcounted(*allocator); + allocator = make_shared_ptr(*allocator); } void TupleDataCollection::InitializeChunk(DataChunk &chunk) const { diff --git a/src/common/types/value.cpp b/src/common/types/value.cpp index a2558d6cea42..9a10ea324d37 100644 --- a/src/common/types/value.cpp +++ b/src/common/types/value.cpp @@ -162,7 +162,7 @@ Value::Value(string val) : type_(LogicalType::VARCHAR), is_null(false) { if (!Value::StringIsValid(val.c_str(), val.size())) { throw ErrorManager::InvalidUnicodeError(val, "value construction"); } - value_info_ = make_refcounted(std::move(val)); + value_info_ = make_shared_ptr(std::move(val)); } Value::~Value() { @@ -668,7 +668,7 @@ Value Value::STRUCT(const LogicalType &type, vector struct_values) { for (size_t i = 0; i < struct_values.size(); i++) { struct_values[i] = struct_values[i].DefaultCastAs(child_types[i].second); } - result.value_info_ = make_refcounted(std::move(struct_values)); + result.value_info_ = make_shared_ptr(std::move(struct_values)); result.type_ = type; result.is_null = false; return result; @@ -711,7 +711,7 @@ Value Value::MAP(const LogicalType &key_type, const LogicalType &value_type, vec new_children.push_back(std::make_pair("value", std::move(values[i]))); values[i] = Value::STRUCT(std::move(new_children)); } - result.value_info_ = make_refcounted(std::move(values)); + result.value_info_ = make_shared_ptr(std::move(values)); return result; } @@ -735,7 +735,7 @@ Value Value::UNION(child_list_t members, uint8_t tag, Value value) } } union_values[tag + 1] = std::move(value); - result.value_info_ = make_refcounted(std::move(union_values)); + result.value_info_ = make_shared_ptr(std::move(union_values)); result.type_ = LogicalType::UNION(std::move(members)); return result; } @@ -752,7 +752,7 @@ Value Value::LIST(vector values) { #endif Value result; result.type_ = LogicalType::LIST(values[0].type()); - result.value_info_ = make_refcounted(std::move(values)); + result.value_info_ = make_shared_ptr(std::move(values)); result.is_null = false; return result; } @@ -770,7 +770,7 @@ Value Value::LIST(const LogicalType &child_type, vector values) { Value Value::EMPTYLIST(const LogicalType &child_type) { Value result; result.type_ = LogicalType::LIST(child_type); - result.value_info_ = make_refcounted(); + result.value_info_ = make_shared_ptr(); result.is_null = false; return result; } @@ -787,7 +787,7 @@ Value Value::ARRAY(vector values) { #endif Value result; result.type_ = 
LogicalType::ARRAY(values[0].type(), values.size()); - result.value_info_ = make_refcounted(std::move(values)); + result.value_info_ = make_shared_ptr(std::move(values)); result.is_null = false; return result; } @@ -805,7 +805,7 @@ Value Value::ARRAY(const LogicalType &child_type, vector values) { Value Value::EMPTYARRAY(const LogicalType &child_type, uint32_t size) { Value result; result.type_ = LogicalType::ARRAY(child_type, size); - result.value_info_ = make_refcounted(); + result.value_info_ = make_shared_ptr(); result.is_null = false; return result; } @@ -813,35 +813,35 @@ Value Value::EMPTYARRAY(const LogicalType &child_type, uint32_t size) { Value Value::BLOB(const_data_ptr_t data, idx_t len) { Value result(LogicalType::BLOB); result.is_null = false; - result.value_info_ = make_refcounted(string(const_char_ptr_cast(data), len)); + result.value_info_ = make_shared_ptr(string(const_char_ptr_cast(data), len)); return result; } Value Value::BLOB(const string &data) { Value result(LogicalType::BLOB); result.is_null = false; - result.value_info_ = make_refcounted(Blob::ToBlob(string_t(data))); + result.value_info_ = make_shared_ptr(Blob::ToBlob(string_t(data))); return result; } Value Value::AGGREGATE_STATE(const LogicalType &type, const_data_ptr_t data, idx_t len) { // NOLINT Value result(type); result.is_null = false; - result.value_info_ = make_refcounted(string(const_char_ptr_cast(data), len)); + result.value_info_ = make_shared_ptr(string(const_char_ptr_cast(data), len)); return result; } Value Value::BIT(const_data_ptr_t data, idx_t len) { Value result(LogicalType::BIT); result.is_null = false; - result.value_info_ = make_refcounted(string(const_char_ptr_cast(data), len)); + result.value_info_ = make_shared_ptr(string(const_char_ptr_cast(data), len)); return result; } Value Value::BIT(const string &data) { Value result(LogicalType::BIT); result.is_null = false; - result.value_info_ = make_refcounted(Bit::ToBit(string_t(data))); + result.value_info_ = make_shared_ptr(Bit::ToBit(string_t(data))); return result; } @@ -1936,27 +1936,27 @@ Value Value::Deserialize(Deserializer &deserializer) { case PhysicalType::VARCHAR: { auto str = deserializer.ReadProperty(102, "value"); if (type.id() == LogicalTypeId::BLOB) { - new_value.value_info_ = make_refcounted(Blob::ToBlob(str)); + new_value.value_info_ = make_shared_ptr(Blob::ToBlob(str)); } else { - new_value.value_info_ = make_refcounted(str); + new_value.value_info_ = make_shared_ptr(str); } } break; case PhysicalType::LIST: { deserializer.ReadObject(102, "value", [&](Deserializer &obj) { auto children = obj.ReadProperty>(100, "children"); - new_value.value_info_ = make_refcounted(children); + new_value.value_info_ = make_shared_ptr(children); }); } break; case PhysicalType::STRUCT: { deserializer.ReadObject(102, "value", [&](Deserializer &obj) { auto children = obj.ReadProperty>(100, "children"); - new_value.value_info_ = make_refcounted(children); + new_value.value_info_ = make_shared_ptr(children); }); } break; case PhysicalType::ARRAY: { deserializer.ReadObject(102, "value", [&](Deserializer &obj) { auto children = obj.ReadProperty>(100, "children"); - new_value.value_info_ = make_refcounted(children); + new_value.value_info_ = make_shared_ptr(children); }); } break; default: diff --git a/src/common/types/vector_cache.cpp b/src/common/types/vector_cache.cpp index 0c5075a54b45..56664319c5da 100644 --- a/src/common/types/vector_cache.cpp +++ b/src/common/types/vector_cache.cpp @@ -18,7 +18,7 @@ class VectorCacheBuffer : public 
VectorBuffer { auto &child_type = ListType::GetChildType(type); child_caches.push_back(make_buffer(allocator, child_type, capacity)); auto child_vector = make_uniq(child_type, false, false); - auxiliary = make_refcounted(std::move(child_vector)); + auxiliary = make_shared_ptr(std::move(child_vector)); break; } case PhysicalType::ARRAY: { @@ -26,7 +26,7 @@ class VectorCacheBuffer : public VectorBuffer { auto array_size = ArrayType::GetSize(type); child_caches.push_back(make_buffer(allocator, child_type, array_size * capacity)); auto child_vector = make_uniq(child_type, true, false, array_size * capacity); - auxiliary = make_refcounted(std::move(child_vector), array_size, capacity); + auxiliary = make_shared_ptr(std::move(child_vector), array_size, capacity); break; } case PhysicalType::STRUCT: { @@ -34,7 +34,7 @@ class VectorCacheBuffer : public VectorBuffer { for (auto &child_type : child_types) { child_caches.push_back(make_buffer(allocator, child_type.second, capacity)); } - auto struct_buffer = make_refcounted(type); + auto struct_buffer = make_shared_ptr(type); auxiliary = std::move(struct_buffer); break; } diff --git a/src/execution/aggregate_hashtable.cpp b/src/execution/aggregate_hashtable.cpp index 95f826a35596..08293a8025f6 100644 --- a/src/execution/aggregate_hashtable.cpp +++ b/src/execution/aggregate_hashtable.cpp @@ -40,7 +40,7 @@ GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, All vector aggregate_objects_p, idx_t initial_capacity, idx_t radix_bits) : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)), - radix_bits(radix_bits), count(0), capacity(0), aggregate_allocator(make_refcounted(allocator)) { + radix_bits(radix_bits), count(0), capacity(0), aggregate_allocator(make_shared_ptr(allocator)) { // Append hash column to the end and initialise the row layout group_types_p.emplace_back(LogicalType::HASH); diff --git a/src/execution/index/art/art.cpp b/src/execution/index/art/art.cpp index 7e3c7f3f38d3..b150f7f4398a 100644 --- a/src/execution/index/art/art.cpp +++ b/src/execution/index/art/art.cpp @@ -55,7 +55,7 @@ ART::ART(const string &name, const IndexConstraintType index_constraint_type, co make_uniq(sizeof(Node48), block_manager), make_uniq(sizeof(Node256), block_manager)}; allocators = - make_refcounted, ALLOCATOR_COUNT>>(std::move(allocator_array)); + make_shared_ptr, ALLOCATOR_COUNT>>(std::move(allocator_array)); } // deserialize lazily diff --git a/src/execution/operator/aggregate/aggregate_object.cpp b/src/execution/operator/aggregate/aggregate_object.cpp index 79a524ea76b8..af61fa17fb4d 100644 --- a/src/execution/operator/aggregate/aggregate_object.cpp +++ b/src/execution/operator/aggregate/aggregate_object.cpp @@ -9,7 +9,7 @@ AggregateObject::AggregateObject(AggregateFunction function, FunctionData *bind_ idx_t payload_size, AggregateType aggr_type, PhysicalType return_type, Expression *filter) : function(std::move(function)), - bind_data_wrapper(bind_data ? make_refcounted(bind_data->Copy()) : nullptr), + bind_data_wrapper(bind_data ? 
make_shared_ptr(bind_data->Copy()) : nullptr), child_count(child_count), payload_size(payload_size), aggr_type(aggr_type), return_type(return_type), filter(filter) { } diff --git a/src/execution/operator/aggregate/physical_hash_aggregate.cpp b/src/execution/operator/aggregate/physical_hash_aggregate.cpp index 420a9aecb918..b17fd7fcdca5 100644 --- a/src/execution/operator/aggregate/physical_hash_aggregate.cpp +++ b/src/execution/operator/aggregate/physical_hash_aggregate.cpp @@ -608,7 +608,7 @@ idx_t HashAggregateDistinctFinalizeEvent::CreateGlobalSources() { void HashAggregateDistinctFinalizeEvent::FinishEvent() { // Now that everything is added to the main ht, we can actually finalize - auto new_event = make_refcounted(context, pipeline.get(), op, gstate); + auto new_event = make_shared_ptr(context, pipeline.get(), op, gstate); this->InsertEvent(std::move(new_event)); } @@ -755,7 +755,7 @@ SinkFinalizeType PhysicalHashAggregate::FinalizeDistinct(Pipeline &pipeline, Eve radix_table->Finalize(context, radix_state); } } - auto new_event = make_refcounted(context, pipeline, *this, gstate); + auto new_event = make_shared_ptr(context, pipeline, *this, gstate); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; } diff --git a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp index 97008097513f..5dbbe6d993ea 100644 --- a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -586,7 +586,7 @@ SinkFinalizeType PhysicalUngroupedAggregate::FinalizeDistinct(Pipeline &pipeline auto &radix_state = *distinct_state.radix_states[table_idx]; radix_table_p->Finalize(context, radix_state); } - auto new_event = make_refcounted(context, *this, gstate, pipeline); + auto new_event = make_shared_ptr(context, *this, gstate, pipeline); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; } diff --git a/src/execution/operator/aggregate/physical_window.cpp b/src/execution/operator/aggregate/physical_window.cpp index b945615bab16..d4dbf4982a26 100644 --- a/src/execution/operator/aggregate/physical_window.cpp +++ b/src/execution/operator/aggregate/physical_window.cpp @@ -171,7 +171,7 @@ SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, Clie } // Schedule all the sorts for maximum thread utilisation - auto new_event = make_refcounted(*state.global_partition, pipeline); + auto new_event = make_shared_ptr(*state.global_partition, pipeline); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; diff --git a/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp b/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp index 79de691c2003..934025425b4a 100644 --- a/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp +++ b/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp @@ -41,7 +41,7 @@ shared_ptr CSVBuffer::Next(CSVFileHandle &file_handle, idx_t buffer_s file_handle.Seek(global_csv_start + actual_buffer_size); has_seaked = false; } - auto next_csv_buffer = make_refcounted( + auto next_csv_buffer = make_shared_ptr( file_handle, context, buffer_size, global_csv_start + actual_buffer_size, file_number_p, buffer_idx + 1); if (next_csv_buffer->GetBufferSize() == 0) { // We are done reading @@ -76,7 +76,7 @@ shared_ptr CSVBuffer::Pin(CSVFileHandle &file_handle, bool &has Reload(file_handle); has_seeked = true; } - return 
make_refcounted(buffer_manager.Pin(block), actual_buffer_size, requested_size, last_buffer, + return make_shared_ptr(buffer_manager.Pin(block), actual_buffer_size, requested_size, last_buffer, file_number, buffer_idx); } diff --git a/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp b/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp index 494c4fd04b39..4aba439f1b6d 100644 --- a/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp +++ b/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp @@ -28,7 +28,7 @@ void CSVBufferManager::UnpinBuffer(const idx_t cache_idx) { void CSVBufferManager::Initialize() { if (cached_buffers.empty()) { cached_buffers.emplace_back( - make_refcounted(context, buffer_size, *file_handle, global_csv_pos, file_idx)); + make_shared_ptr(context, buffer_size, *file_handle, global_csv_pos, file_idx)); last_buffer = cached_buffers.front(); } } diff --git a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp index 791dfe725090..d23bd0e7a056 100644 --- a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp +++ b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp @@ -651,14 +651,14 @@ StringValueScanner::StringValueScanner(const shared_ptr &buffe } unique_ptr StringValueScanner::GetCSVScanner(ClientContext &context, CSVReaderOptions &options) { - auto state_machine = make_refcounted(options, options.dialect_options.state_machine_options, + auto state_machine = make_shared_ptr(options, options.dialect_options.state_machine_options, CSVStateMachineCache::Get(context)); state_machine->dialect_options.num_cols = options.dialect_options.num_cols; state_machine->dialect_options.header = options.dialect_options.header; - auto buffer_manager = make_refcounted(context, options, options.file_path, 0); - auto scanner = make_uniq(buffer_manager, state_machine, make_refcounted()); - scanner->csv_file_scan = make_refcounted(context, options.file_path, options); + auto buffer_manager = make_shared_ptr(context, options, options.file_path, 0); + auto scanner = make_uniq(buffer_manager, state_machine, make_shared_ptr()); + scanner->csv_file_scan = make_shared_ptr(context, options.file_path, options); scanner->csv_file_scan->InitializeProjection(); return scanner; } @@ -1222,7 +1222,7 @@ void StringValueScanner::SetStart() { } scan_finder = - make_uniq(0, buffer_manager, state_machine, make_refcounted(true), + make_uniq(0, buffer_manager, state_machine, make_shared_ptr(true), csv_file_scan, false, iterator, 1); auto &tuples = scan_finder->ParseChunk(); line_found = true; diff --git a/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp b/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp index 15fa8db67f03..d98797842994 100644 --- a/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp +++ b/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp @@ -13,8 +13,8 @@ CSVSniffer::CSVSniffer(CSVReaderOptions &options_p, shared_ptr } // Initialize max columns found to either 0 or however many were set max_columns_found = set_columns.Size(); - error_handler = make_refcounted(options.ignore_errors.GetValue()); - detection_error_handler = make_refcounted(true); + error_handler = make_shared_ptr(options.ignore_errors.GetValue()); + detection_error_handler = make_shared_ptr(true); } bool SetColumns::IsSet() { diff --git 
a/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp b/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp index a2ec291de390..eeb563cf771c 100644 --- a/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp +++ b/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp @@ -10,7 +10,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, shared_ptr bu vector &file_schema) : file_path(options_p.file_path), file_idx(0), buffer_manager(std::move(buffer_manager_p)), state_machine(std::move(state_machine_p)), file_size(buffer_manager->file_handle->FileSize()), - error_handler(make_refcounted(options_p.ignore_errors.GetValue())), + error_handler(make_shared_ptr(options_p.ignore_errors.GetValue())), on_disk_file(buffer_manager->file_handle->OnDiskFile()), options(options_p) { if (bind_data.initial_reader.get()) { auto &union_reader = *bind_data.initial_reader; @@ -43,7 +43,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons const idx_t file_idx_p, const ReadCSVData &bind_data, const vector &column_ids, const vector &file_schema) : file_path(file_path_p), file_idx(file_idx_p), - error_handler(make_refcounted(options_p.ignore_errors.GetValue())), options(options_p) { + error_handler(make_shared_ptr(options_p.ignore_errors.GetValue())), options(options_p) { if (file_idx < bind_data.union_readers.size()) { // we are doing UNION BY NAME - fetch the options from the union reader for this file optional_ptr union_reader_ptr; @@ -73,7 +73,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons } // Initialize Buffer Manager - buffer_manager = make_refcounted(context, options, file_path, file_idx); + buffer_manager = make_shared_ptr(context, options, file_path, file_idx); // Initialize On Disk and Size of file on_disk_file = buffer_manager->file_handle->OnDiskFile(); file_size = buffer_manager->file_handle->FileSize(); @@ -89,7 +89,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons CSVSniffer sniffer(options, buffer_manager, state_machine_cache); sniffer.SniffCSV(); } - state_machine = make_refcounted( + state_machine = make_shared_ptr( state_machine_cache.Get(options.dialect_options.state_machine_options), options); MultiFileReader::InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, @@ -120,7 +120,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons names = bind_data.csv_names; types = bind_data.csv_types; - state_machine = make_refcounted( + state_machine = make_shared_ptr( state_machine_cache.Get(options.dialect_options.state_machine_options), options); MultiFileReader::InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, @@ -130,8 +130,8 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons CSVFileScan::CSVFileScan(ClientContext &context, const string &file_name, CSVReaderOptions &options_p) : file_path(file_name), file_idx(0), - error_handler(make_refcounted(options_p.ignore_errors.GetValue())), options(options_p) { - buffer_manager = make_refcounted(context, options, file_path, file_idx); + error_handler(make_shared_ptr(options_p.ignore_errors.GetValue())), options(options_p) { + buffer_manager = make_shared_ptr(context, options, file_path, file_idx); // Initialize On Disk and Size of file on_disk_file = buffer_manager->file_handle->OnDiskFile(); file_size = 
buffer_manager->file_handle->FileSize(); @@ -151,7 +151,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_name, CSVRea options.dialect_options.num_cols = options.sql_type_list.size(); } // Initialize State Machine - state_machine = make_refcounted( + state_machine = make_shared_ptr( state_machine_cache.Get(options.dialect_options.state_machine_options), options); } diff --git a/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp b/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp index d358c19aeae7..c1182fc859bf 100644 --- a/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp +++ b/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp @@ -14,7 +14,7 @@ CSVGlobalState::CSVGlobalState(ClientContext &context_p, const shared_ptrGetFilePath() == files[0]) { - auto state_machine = make_refcounted( + auto state_machine = make_shared_ptr( CSVStateMachineCache::Get(context).Get(options.dialect_options.state_machine_options), options); // If we already have a buffer manager, we don't need to reconstruct it to the first file file_scans.emplace_back(make_uniq(context, buffer_manager, state_machine, options, bind_data, @@ -36,7 +36,7 @@ CSVGlobalState::CSVGlobalState(ClientContext &context_p, const shared_ptrbuffer_manager->GetBuffer(0)->actual_size; current_boundary = CSVIterator(0, 0, 0, 0, buffer_size); } - current_buffer_in_use = make_refcounted(*file_scans.back()->buffer_manager, 0); + current_buffer_in_use = make_shared_ptr(*file_scans.back()->buffer_manager, 0); } double CSVGlobalState::GetProgress(const ReadCSVData &bind_data_p) const { @@ -67,7 +67,7 @@ unique_ptr CSVGlobalState::Next(optional_ptr parallel_lock(main_mutex); - file_scans.emplace_back(make_refcounted(context, bind_data.files[cur_idx], bind_data.options, + file_scans.emplace_back(make_shared_ptr(context, bind_data.files[cur_idx], bind_data.options, cur_idx, bind_data, column_ids, file_schema)); current_file = file_scans.back(); } @@ -88,7 +88,7 @@ unique_ptr CSVGlobalState::Next(optional_ptrbuffer_idx != current_boundary.GetBufferIdx()) { current_buffer_in_use = - make_refcounted(*file_scans.back()->buffer_manager, current_boundary.GetBufferIdx()); + make_shared_ptr(*file_scans.back()->buffer_manager, current_boundary.GetBufferIdx()); } // We first create the scanner for the current boundary auto ¤t_file = *file_scans.back(); @@ -111,13 +111,13 @@ unique_ptr CSVGlobalState::Next(optional_ptr(context, bind_data.files[current_file_idx], + file_scans.emplace_back(make_shared_ptr(context, bind_data.files[current_file_idx], bind_data.options, current_file_idx, bind_data, column_ids, file_schema)); // And re-start the boundary-iterator auto buffer_size = file_scans.back()->buffer_manager->GetBuffer(0)->actual_size; current_boundary = CSVIterator(current_file_idx, 0, 0, 0, buffer_size); - current_buffer_in_use = make_refcounted(*file_scans.back()->buffer_manager, 0); + current_buffer_in_use = make_shared_ptr(*file_scans.back()->buffer_manager, 0); } else { // If not we are done with this CSV Scanning finished = true; diff --git a/src/execution/operator/helper/physical_buffered_collector.cpp b/src/execution/operator/helper/physical_buffered_collector.cpp index 90708953a643..9f2ac70db09d 100644 --- a/src/execution/operator/helper/physical_buffered_collector.cpp +++ b/src/execution/operator/helper/physical_buffered_collector.cpp @@ -55,7 +55,7 @@ SinkCombineResultType PhysicalBufferedCollector::Combine(ExecutionContext &conte unique_ptr 
PhysicalBufferedCollector::GetGlobalSinkState(ClientContext &context) const { auto state = make_uniq(); state->context = context.shared_from_this(); - state->buffered_data = make_refcounted(state->context); + state->buffered_data = make_shared_ptr(state->context); return std::move(state); } diff --git a/src/execution/operator/join/physical_asof_join.cpp b/src/execution/operator/join/physical_asof_join.cpp index 05e45d7a455f..29bdef93bb5d 100644 --- a/src/execution/operator/join/physical_asof_join.cpp +++ b/src/execution/operator/join/physical_asof_join.cpp @@ -169,7 +169,7 @@ SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, Cl } // Schedule all the sorts for maximum thread utilisation - auto new_event = make_refcounted(gstate.rhs_sink, pipeline); + auto new_event = make_shared_ptr(gstate.rhs_sink, pipeline); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; diff --git a/src/execution/operator/join/physical_hash_join.cpp b/src/execution/operator/join/physical_hash_join.cpp index 09c336458300..f7333cd660fc 100644 --- a/src/execution/operator/join/physical_hash_join.cpp +++ b/src/execution/operator/join/physical_hash_join.cpp @@ -359,7 +359,7 @@ void HashJoinGlobalSinkState::ScheduleFinalize(Pipeline &pipeline, Event &event) return; } hash_table->InitializePointerTable(); - auto new_event = make_refcounted(pipeline, *this); + auto new_event = make_shared_ptr(pipeline, *this); event.InsertEvent(std::move(new_event)); } @@ -474,7 +474,7 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl // We have to repartition ht.SetRepartitionRadixBits(sink.local_hash_tables, sink.temporary_memory_state->GetReservation(), max_partition_size, max_partition_count); - auto new_event = make_refcounted(pipeline, sink, sink.local_hash_tables); + auto new_event = make_shared_ptr(pipeline, sink, sink.local_hash_tables); event.InsertEvent(std::move(new_event)); } else { // No repartitioning! 
diff --git a/src/execution/operator/join/physical_range_join.cpp b/src/execution/operator/join/physical_range_join.cpp
index d89360236a3b..996c0a4a7e4b 100644
--- a/src/execution/operator/join/physical_range_join.cpp
+++ b/src/execution/operator/join/physical_range_join.cpp
@@ -149,7 +149,7 @@ class RangeJoinMergeEvent : public BasePipelineEvent {
 void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) {
 	// Initialize global sort state for a round of merging
 	global_sort_state.InitializeMergeRound();
-	auto new_event = make_refcounted(*this, pipeline);
+	auto new_event = make_shared_ptr(*this, pipeline);
 	event.InsertEvent(std::move(new_event));
 }
 
diff --git a/src/execution/operator/order/physical_order.cpp b/src/execution/operator/order/physical_order.cpp
index 8687acfe5ee1..f12076b80687 100644
--- a/src/execution/operator/order/physical_order.cpp
+++ b/src/execution/operator/order/physical_order.cpp
@@ -187,7 +187,7 @@ SinkFinalizeType PhysicalOrder::Finalize(Pipeline &pipeline, Event &event, Clien
 void PhysicalOrder::ScheduleMergeTasks(Pipeline &pipeline, Event &event, OrderGlobalSinkState &state) {
 	// Initialize global sort state for a round of merging
 	state.global_sort_state.InitializeMergeRound();
-	auto new_event = make_refcounted(state, pipeline);
+	auto new_event = make_shared_ptr(state, pipeline);
 	event.InsertEvent(std::move(new_event));
 }
 
diff --git a/src/execution/operator/persistent/physical_batch_copy_to_file.cpp b/src/execution/operator/persistent/physical_batch_copy_to_file.cpp
index 831ea35b6e49..30699ee47c2f 100644
--- a/src/execution/operator/persistent/physical_batch_copy_to_file.cpp
+++ b/src/execution/operator/persistent/physical_batch_copy_to_file.cpp
@@ -308,7 +308,7 @@ SinkFinalizeType PhysicalBatchCopyToFile::Finalize(Pipeline &pipeline, Event &ev
 		FinalFlush(context, input.global_state);
 	} else {
 		// we have multiple tasks remaining - launch an event to execute the tasks in parallel
-		auto new_event = make_refcounted(*this, gstate, pipeline, context);
+		auto new_event = make_shared_ptr(*this, gstate, pipeline, context);
 		event.InsertEvent(std::move(new_event));
 	}
 	return SinkFinalizeType::READY;
diff --git a/src/execution/operator/persistent/physical_copy_to_file.cpp b/src/execution/operator/persistent/physical_copy_to_file.cpp
index be9ca4dbbd61..343551dd2c94 100644
--- a/src/execution/operator/persistent/physical_copy_to_file.cpp
+++ b/src/execution/operator/persistent/physical_copy_to_file.cpp
@@ -284,7 +284,7 @@ unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext
 	}
 
 	if (partition_output) {
-		state->partition_state = make_refcounted();
+		state->partition_state = make_shared_ptr();
 	}
 
 	return std::move(state);
diff --git a/src/execution/operator/schema/physical_create_art_index.cpp b/src/execution/operator/schema/physical_create_art_index.cpp
index fe88d7c3bfa4..94aa2b74b9e9 100644
--- a/src/execution/operator/schema/physical_create_art_index.cpp
+++ b/src/execution/operator/schema/physical_create_art_index.cpp
@@ -178,7 +178,7 @@ SinkFinalizeType PhysicalCreateARTIndex::Finalize(Pipeline &pipeline, Event &eve
 	auto &index = index_entry->Cast();
 	index.initial_index_size = state.global_index->GetInMemorySize();
 
-	index.info = make_refcounted(storage.info, index.name);
+	index.info = make_shared_ptr(storage.info, index.name);
 	for (auto &parsed_expr : info->parsed_expressions) {
 		index.parsed_expressions.push_back(parsed_expr->Copy());
 	}
diff --git a/src/execution/operator/set/physical_recursive_cte.cpp b/src/execution/operator/set/physical_recursive_cte.cpp
index 5210987325ab..406299ac2939 100644
--- a/src/execution/operator/set/physical_recursive_cte.cpp
+++ b/src/execution/operator/set/physical_recursive_cte.cpp
@@ -200,7 +200,7 @@ void PhysicalRecursiveCTE::BuildPipelines(Pipeline &current, MetaPipeline &meta_
 	initial_state_pipeline.Build(*children[0]);
 
 	// the RHS is the recursive pipeline
-	recursive_meta_pipeline = make_refcounted(executor, state, this);
+	recursive_meta_pipeline = make_shared_ptr(executor, state, this);
 	recursive_meta_pipeline->SetRecursiveCTE();
 	recursive_meta_pipeline->Build(*children[1]);
 
diff --git a/src/execution/physical_plan/plan_cte.cpp b/src/execution/physical_plan/plan_cte.cpp
index c11286a8cf3f..190cb9319bc7 100644
--- a/src/execution/physical_plan/plan_cte.cpp
+++ b/src/execution/physical_plan/plan_cte.cpp
@@ -12,7 +12,7 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalMaterializ
 	D_ASSERT(op.children.size() == 2);
 
 	// Create the working_table that the PhysicalCTE will use for evaluation.
-	auto working_table = make_refcounted(context, op.children[0]->types);
+	auto working_table = make_shared_ptr(context, op.children[0]->types);
 
 	// Add the ColumnDataCollection to the context of this PhysicalPlanGenerator
 	recursive_cte_tables[op.table_index] = working_table;
diff --git a/src/execution/physical_plan/plan_recursive_cte.cpp b/src/execution/physical_plan/plan_recursive_cte.cpp
index 5ddb767612b6..7933dd1660ba 100644
--- a/src/execution/physical_plan/plan_recursive_cte.cpp
+++ b/src/execution/physical_plan/plan_recursive_cte.cpp
@@ -12,7 +12,7 @@ unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalRecursiveC
 	D_ASSERT(op.children.size() == 2);
 
 	// Create the working_table that the PhysicalRecursiveCTE will use for evaluation.
-	auto working_table = make_refcounted(context, op.types);
+	auto working_table = make_shared_ptr(context, op.types);
 
 	// Add the ColumnDataCollection to the context of this PhysicalPlanGenerator
 	recursive_cte_tables[op.table_index] = working_table;
diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp
index e45a6a643120..0814d0f2d655 100644
--- a/src/function/table/copy_csv.cpp
+++ b/src/function/table/copy_csv.cpp
@@ -156,7 +156,7 @@ static unique_ptr ReadCSVBind(ClientContext &context, CopyInfo &in
 	}
 
 	if (options.auto_detect) {
-		auto buffer_manager = make_refcounted(context, options, bind_data->files[0], 0);
+		auto buffer_manager = make_shared_ptr(context, options, bind_data->files[0], 0);
 		CSVSniffer sniffer(options, buffer_manager, CSVStateMachineCache::Get(context),
 		                   {&expected_types, &expected_names});
 		sniffer.SniffCSV();
diff --git a/src/function/table/read_csv.cpp b/src/function/table/read_csv.cpp
index dc8ec7f3abe7..6c949913762b 100644
--- a/src/function/table/read_csv.cpp
+++ b/src/function/table/read_csv.cpp
@@ -88,7 +88,7 @@ static unique_ptr ReadCSVBind(ClientContext &context, TableFunctio
 	}
 	if (options.auto_detect && !options.file_options.union_by_name) {
 		options.file_path = result->files[0];
-		result->buffer_manager = make_refcounted(context, options, result->files[0], 0);
+		result->buffer_manager = make_shared_ptr(context, options, result->files[0], 0);
 		CSVSniffer sniffer(options, result->buffer_manager, CSVStateMachineCache::Get(context),
 		                   {&return_types, &names});
 		auto sniffer_result = sniffer.SniffCSV();
diff --git a/src/function/table/sniff_csv.cpp b/src/function/table/sniff_csv.cpp
index b7998902d93d..11dfb7797577 100644
--- a/src/function/table/sniff_csv.cpp
+++ b/src/function/table/sniff_csv.cpp
@@ -109,7 +109,7 @@ static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p,
 	auto sniffer_options = data.options;
 	sniffer_options.file_path = data.path;
 
-	auto buffer_manager = make_refcounted(context, sniffer_options, sniffer_options.file_path, 0);
+	auto buffer_manager = make_shared_ptr(context, sniffer_options, sniffer_options.file_path, 0);
 	if (sniffer_options.name_list.empty()) {
 		sniffer_options.name_list = data.names_csv;
 	}
diff --git a/src/include/duckdb/common/helper.hpp b/src/include/duckdb/common/helper.hpp
index ed3e6e1f1be3..789e9042febf 100644
--- a/src/include/duckdb/common/helper.hpp
+++ b/src/include/duckdb/common/helper.hpp
@@ -68,7 +68,7 @@ make_uniq(ARGS&&... args) // NOLINT: mimic std style
 
 template
 inline shared_ptr
-make_refcounted(ARGS&&... args) // NOLINT: mimic std style
+make_shared_ptr(ARGS&&... args) // NOLINT: mimic std style
 {
 	return shared_ptr(std::make_shared(std::forward(args)...));
 }
@@ -125,7 +125,7 @@ shared_ptr shared_ptr_cast(shared_ptr src) { // NOLINT: mimic std style
 struct SharedConstructor {
 	template
 	static shared_ptr Create(ARGS &&...args) {
-		return make_refcounted(std::forward(args)...);
+		return make_shared_ptr(std::forward(args)...);
 	}
 };
 
diff --git a/src/include/duckdb/common/multi_file_reader.hpp b/src/include/duckdb/common/multi_file_reader.hpp
index ec318dea76e0..d7c106d00a4b 100644
--- a/src/include/duckdb/common/multi_file_reader.hpp
+++ b/src/include/duckdb/common/multi_file_reader.hpp
@@ -151,7 +151,7 @@ struct MultiFileReader {
 			return BindUnionReader(context, return_types, names, result, options);
 		} else {
 			shared_ptr reader;
-			reader = make_refcounted(context, result.files[0], options);
+			reader = make_shared_ptr(context, result.files[0], options);
 			return_types = reader->return_types;
 			names = reader->names;
 			result.Initialize(std::move(reader));
diff --git a/src/include/duckdb/common/types.hpp b/src/include/duckdb/common/types.hpp
index e9e31b488ea5..6e38fa83b89d 100644
--- a/src/include/duckdb/common/types.hpp
+++ b/src/include/duckdb/common/types.hpp
@@ -35,7 +35,7 @@ using buffer_ptr = shared_ptr;
 
 template
 buffer_ptr make_buffer(ARGS &&...args) { // NOLINT: mimic std casing
-	return make_refcounted(std::forward(args)...);
+	return make_shared_ptr(std::forward(args)...);
 }
 
 struct list_entry_t { // NOLINT: mimic std casing
diff --git a/src/include/duckdb/common/types/selection_vector.hpp b/src/include/duckdb/common/types/selection_vector.hpp
index a0f0b185beae..969c785690d4 100644
--- a/src/include/duckdb/common/types/selection_vector.hpp
+++ b/src/include/duckdb/common/types/selection_vector.hpp
@@ -71,7 +71,7 @@ struct SelectionVector {
 		sel_vector = sel;
 	}
 	void Initialize(idx_t count = STANDARD_VECTOR_SIZE) {
-		selection_data = make_refcounted(count);
+		selection_data = make_shared_ptr(count);
 		sel_vector = selection_data->owned_data.get();
 	}
 	void Initialize(buffer_ptr data) {
diff --git a/src/include/duckdb/planner/binder.hpp b/src/include/duckdb/planner/binder.hpp
index fa1734b1f360..c92559fab142 100644
--- a/src/include/duckdb/planner/binder.hpp
+++ b/src/include/duckdb/planner/binder.hpp
@@ -376,7 +376,7 @@ class Binder : public enable_shared_from_this {
 	unique_ptr BindSummarize(ShowRef &ref);
 
 public:
-	// This should really be a private constructor, but make_refcounted does not allow it...
+	// This should really be a private constructor, but make_shared_ptr does not allow it...
 	// If you are thinking about calling this, you should probably call Binder::CreateBinder
 	Binder(bool i_know_what_i_am_doing, ClientContext &context, shared_ptr parent, bool inherit_ctes);
 };
diff --git a/src/include/duckdb/storage/object_cache.hpp b/src/include/duckdb/storage/object_cache.hpp
index 06a5c2d3a767..5ef1a6733201 100644
--- a/src/include/duckdb/storage/object_cache.hpp
+++ b/src/include/duckdb/storage/object_cache.hpp
@@ -52,7 +52,7 @@ class ObjectCache {
 		auto entry = cache.find(key);
 
 		if (entry == cache.end()) {
-			auto value = make_refcounted(args...);
+			auto value = make_shared_ptr(args...);
 			cache[key] = value;
 			return value;
 		}
diff --git a/src/include/duckdb/storage/serialization/types.json b/src/include/duckdb/storage/serialization/types.json
index 5433f50a9d15..8f3adec0837e 100644
--- a/src/include/duckdb/storage/serialization/types.json
+++ b/src/include/duckdb/storage/serialization/types.json
@@ -155,7 +155,7 @@
     "class": "GenericTypeInfo",
     "base": "ExtraTypeInfo",
     "enum": "GENERIC_TYPE_INFO",
-    "custom_switch_code": "result = make_refcounted(type);\nbreak;"
+    "custom_switch_code": "result = make_shared_ptr(type);\nbreak;"
   },
   {
     "class": "AnyTypeInfo",
diff --git a/src/main/capi/table_function-c.cpp b/src/main/capi/table_function-c.cpp
index 57be51d01f4b..a6916e8e2007 100644
--- a/src/main/capi/table_function-c.cpp
+++ b/src/main/capi/table_function-c.cpp
@@ -179,7 +179,7 @@ void CTableFunction(ClientContext &context, TableFunctionInput &data_p, DataChun
 duckdb_table_function duckdb_create_table_function() {
 	auto function = new duckdb::TableFunction("", {}, duckdb::CTableFunction, duckdb::CTableFunctionBind,
 	                                          duckdb::CTableFunctionInit, duckdb::CTableFunctionLocalInit);
-	function->function_info = duckdb::make_refcounted();
+	function->function_info = duckdb::make_shared_ptr();
 	function->cardinality = duckdb::CTableFunctionCardinality;
 	return function;
 }
diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp
index 49275d1dd7d0..515ab5fc732b 100644
--- a/src/main/client_context.cpp
+++ b/src/main/client_context.cpp
@@ -312,7 +312,7 @@ ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const st
                                                unique_ptr statement,
                                                optional_ptr> values) {
 	StatementType statement_type = statement->type;
-	auto result = make_refcounted(statement_type);
+	auto result = make_shared_ptr(statement_type);
 
 	auto &profiler = QueryProfiler::Get(*this);
 	profiler.StartQuery(query, IsExplainAnalyze(statement.get()), true);
diff --git a/src/main/client_data.cpp b/src/main/client_data.cpp
index 4dda60a7e053..54c5292cf339 100644
--- a/src/main/client_data.cpp
+++ b/src/main/client_data.cpp
@@ -35,8 +35,8 @@ class ClientFileSystem : public OpenerFileSystem {
 ClientData::ClientData(ClientContext &context) : catalog_search_path(make_uniq(context)) {
 	auto &db = DatabaseInstance::GetDatabase(context);
-	profiler = make_refcounted(context);
-	temporary_objects = make_refcounted(db, AttachedDatabaseType::TEMP_DATABASE);
+	profiler = make_shared_ptr(context);
+	temporary_objects = make_shared_ptr(db, AttachedDatabaseType::TEMP_DATABASE);
 	temporary_objects->oid = DatabaseManager::Get(db).ModifyCatalog();
 	random_engine = make_uniq();
 	file_opener = make_uniq(context);
diff --git a/src/main/connection.cpp b/src/main/connection.cpp
index b76d440647f7..11346a3a9846 100644
--- a/src/main/connection.cpp
+++ b/src/main/connection.cpp
@@ -19,7 +19,7 @@ namespace duckdb {
 
 Connection::Connection(DatabaseInstance &database)
-    : context(make_refcounted(database.shared_from_this())) {
+    : context(make_shared_ptr(database.shared_from_this())) {
 	ConnectionManager::Get(database).AddConnection(*context);
 #ifdef DEBUG
 	EnableProfiling();
@@ -187,7 +187,7 @@ shared_ptr Connection::Table(const string &schema_name, const string &
 	if (!table_info) {
 		throw CatalogException("Table '%s' does not exist!", table_name);
 	}
-	return make_refcounted(context, std::move(table_info));
+	return make_shared_ptr(context, std::move(table_info));
 }
 
 shared_ptr Connection::View(const string &tname) {
@@ -195,7 +195,7 @@ shared_ptr Connection::View(const string &tname) {
 }
 
 shared_ptr Connection::View(const string &schema_name, const string &table_name) {
-	return make_refcounted(context, schema_name, table_name);
+	return make_shared_ptr(context, schema_name, table_name);
 }
 
 shared_ptr Connection::TableFunction(const string &fname) {
@@ -206,11 +206,11 @@ shared_ptr Connection::TableFunction(const string &fname) {
 
 shared_ptr Connection::TableFunction(const string &fname, const vector &values,
                                      const named_parameter_map_t &named_parameters) {
-	return make_refcounted(context, fname, values, named_parameters);
+	return make_shared_ptr(context, fname, values, named_parameters);
 }
 
 shared_ptr Connection::TableFunction(const string &fname, const vector &values) {
-	return make_refcounted(context, fname, values);
+	return make_shared_ptr(context, fname, values);
 }
 
 shared_ptr Connection::Values(const vector> &values) {
@@ -220,7 +220,7 @@ shared_ptr Connection::Values(const vector> &values) {
 
 shared_ptr Connection::Values(const vector> &values, const vector &column_names,
                               const string &alias) {
-	return make_refcounted(context, values, column_names, alias);
+	return make_shared_ptr(context, values, column_names, alias);
 }
 
 shared_ptr Connection::Values(const string &values) {
@@ -229,7 +229,7 @@ shared_ptr Connection::Values(const string &values) {
 }
 
 shared_ptr Connection::Values(const string &values, const vector &column_names, const string &alias) {
-	return make_refcounted(context, values, column_names, alias);
+	return make_shared_ptr(context, values, column_names, alias);
 }
 
 shared_ptr Connection::ReadCSV(const string &csv_file) {
@@ -238,7 +238,7 @@ shared_ptr Connection::ReadCSV(const string &csv_file) {
 }
 
 shared_ptr Connection::ReadCSV(const vector &csv_input, named_parameter_map_t &&options) {
-	return make_refcounted(context, csv_input, std::move(options));
+	return make_shared_ptr(context, csv_input, std::move(options));
 }
 
 shared_ptr Connection::ReadCSV(const string &csv_input, named_parameter_map_t &&options) {
@@ -259,7 +259,7 @@ shared_ptr Connection::ReadCSV(const string &csv_file, const vector files {csv_file};
-	return make_refcounted(context, files, std::move(options));
+	return make_shared_ptr(context, files, std::move(options));
 }
 
 shared_ptr Connection::ReadParquet(const string &parquet_file, bool binary_as_string) {
@@ -278,7 +278,7 @@ shared_ptr Connection::RelationFromQuery(const string &query, const st
 }
 
 shared_ptr Connection::RelationFromQuery(unique_ptr select_stmt, const string &alias) {
-	return make_refcounted(context, std::move(select_stmt), alias);
+	return make_shared_ptr(context, std::move(select_stmt), alias);
 }
 
 void Connection::BeginTransaction() {
diff --git a/src/main/database.cpp b/src/main/database.cpp
index 4fb6ef547d99..5568ca26bffc 100644
--- a/src/main/database.cpp
+++ b/src/main/database.cpp
@@ -272,7 +272,7 @@ void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_conf
 	scheduler->RelaunchThreads();
 }
 
-DuckDB::DuckDB(const char *path, DBConfig *new_config) : instance(make_refcounted()) {
+DuckDB::DuckDB(const char *path, DBConfig *new_config) : instance(make_shared_ptr()) {
 	instance->Initialize(path, new_config);
 	if (instance->config.options.load_extensions) {
 		ExtensionHelper::LoadAllExtensions(*this);
@@ -380,7 +380,7 @@ void DatabaseInstance::Configure(DBConfig &new_config) {
 	if (new_config.buffer_pool) {
 		config.buffer_pool = std::move(new_config.buffer_pool);
 	} else {
-		config.buffer_pool = make_refcounted(config.options.maximum_memory);
+		config.buffer_pool = make_shared_ptr(config.options.maximum_memory);
 	}
 }
 
diff --git a/src/main/db_instance_cache.cpp b/src/main/db_instance_cache.cpp
index 6105066c6c88..a72592fd8346 100644
--- a/src/main/db_instance_cache.cpp
+++ b/src/main/db_instance_cache.cpp
@@ -66,7 +66,7 @@ shared_ptr DBInstanceCache::CreateInstanceInternal(const string &databas
 	if (abs_database_path.rfind(IN_MEMORY_PATH, 0) == 0) {
 		instance_path = IN_MEMORY_PATH;
 	}
-	auto db_instance = make_refcounted(instance_path, &config);
+	auto db_instance = make_shared_ptr(instance_path, &config);
 	if (cache_instance) {
 		db_instances[abs_database_path] = db_instance;
 	}
diff --git a/src/main/relation.cpp b/src/main/relation.cpp
index f315906bd200..46841d4e70ed 100644
--- a/src/main/relation.cpp
+++ b/src/main/relation.cpp
@@ -39,7 +39,7 @@ shared_ptr Relation::Project(const string &expression, const string &a
 
 shared_ptr Relation::Project(const string &select_list, const vector &aliases) {
 	auto expressions = Parser::ParseExpressionList(select_list, context.GetContext()->GetParserOptions());
-	return make_refcounted(shared_from_this(), std::move(expressions), aliases);
+	return make_shared_ptr(shared_from_this(), std::move(expressions), aliases);
 }
 
 shared_ptr Relation::Project(const vector &expressions) {
@@ -49,7 +49,7 @@ shared_ptr Relation::Project(const vector &expressions) {
 
 shared_ptr Relation::Project(vector> expressions,
                              const vector &aliases) {
-	return make_refcounted(shared_from_this(), std::move(expressions), aliases);
+	return make_shared_ptr(shared_from_this(), std::move(expressions), aliases);
 }
 
 static vector> StringListToExpressionList(ClientContext &context,
@@ -70,7 +70,7 @@ static vector> StringListToExpressionList(ClientCon
 shared_ptr Relation::Project(const vector &expressions, const vector &aliases) {
 	auto result_list = StringListToExpressionList(*context.GetContext(), expressions);
-	return make_refcounted(shared_from_this(), std::move(result_list), aliases);
+	return make_shared_ptr(shared_from_this(), std::move(result_list), aliases);
 }
 
 shared_ptr Relation::Filter(const string &expression) {
@@ -82,7 +82,7 @@ shared_ptr Relation::Filter(const string &expression) {
 }
 
 shared_ptr Relation::Filter(unique_ptr expression) {
-	return make_refcounted(shared_from_this(), std::move(expression));
+	return make_shared_ptr(shared_from_this(), std::move(expression));
 }
 
 shared_ptr Relation::Filter(const vector &expressions) {
@@ -95,11 +95,11 @@ shared_ptr Relation::Filter(const vector &expressions) {
 		expr = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(expr),
 		                 std::move(expression_list[i]));
 	}
-	return make_refcounted(shared_from_this(), std::move(expr));
+	return make_shared_ptr(shared_from_this(), std::move(expr));
 }
 
 shared_ptr Relation::Limit(int64_t limit, int64_t offset) {
-	return make_refcounted(shared_from_this(), limit, offset);
+	return make_shared_ptr(shared_from_this(), limit, offset);
 }
 
 shared_ptr Relation::Order(const string &expression) {
@@ -108,7 +108,7 @@ shared_ptr Relation::Order(const string &expression) {
 }
 
 shared_ptr Relation::Order(vector order_list) {
-	return make_refcounted(shared_from_this(), std::move(order_list));
+	return make_shared_ptr(shared_from_this(), std::move(order_list));
 }
 
 shared_ptr Relation::Order(const vector &expressions) {
@@ -149,51 +149,51 @@ shared_ptr Relation::Join(const shared_ptr &other,
 			}
 			using_columns.push_back(colref.column_names[0]);
 		}
-		return make_refcounted(shared_from_this(), other, std::move(using_columns), type, ref_type);
+		return make_shared_ptr(shared_from_this(), other, std::move(using_columns), type, ref_type);
 	} else {
 		// single expression that is not a column reference: use the expression as a join condition
-		return make_refcounted(shared_from_this(), other, std::move(expression_list[0]), type, ref_type);
+		return make_shared_ptr(shared_from_this(), other, std::move(expression_list[0]), type, ref_type);
 	}
 }
 
 shared_ptr Relation::CrossProduct(const shared_ptr &other, JoinRefType join_ref_type) {
-	return make_refcounted(shared_from_this(), other, join_ref_type);
+	return make_shared_ptr(shared_from_this(), other, join_ref_type);
 }
 
 shared_ptr Relation::Union(const shared_ptr &other) {
-	return make_refcounted(shared_from_this(), other, SetOperationType::UNION, true);
+	return make_shared_ptr(shared_from_this(), other, SetOperationType::UNION, true);
 }
 
 shared_ptr Relation::Except(const shared_ptr &other) {
-	return make_refcounted(shared_from_this(), other, SetOperationType::EXCEPT, true);
+	return make_shared_ptr(shared_from_this(), other, SetOperationType::EXCEPT, true);
 }
 
 shared_ptr Relation::Intersect(const shared_ptr &other) {
-	return make_refcounted(shared_from_this(), other, SetOperationType::INTERSECT, true);
+	return make_shared_ptr(shared_from_this(), other, SetOperationType::INTERSECT, true);
 }
 
 shared_ptr Relation::Distinct() {
-	return make_refcounted(shared_from_this());
+	return make_shared_ptr(shared_from_this());
 }
 
 shared_ptr Relation::Alias(const string &alias) {
-	return make_refcounted(shared_from_this(), alias);
+	return make_shared_ptr(shared_from_this(), alias);
 }
 
 shared_ptr Relation::Aggregate(const string &aggregate_list) {
 	auto expression_list = Parser::ParseExpressionList(aggregate_list, context.GetContext()->GetParserOptions());
-	return make_refcounted(shared_from_this(), std::move(expression_list));
+	return make_shared_ptr(shared_from_this(), std::move(expression_list));
 }
 
 shared_ptr Relation::Aggregate(const string &aggregate_list, const string &group_list) {
 	auto expression_list = Parser::ParseExpressionList(aggregate_list, context.GetContext()->GetParserOptions());
 	auto groups = Parser::ParseGroupByList(group_list, context.GetContext()->GetParserOptions());
-	return make_refcounted(shared_from_this(), std::move(expression_list), std::move(groups));
+	return make_shared_ptr(shared_from_this(), std::move(expression_list), std::move(groups));
 }
 
 shared_ptr Relation::Aggregate(const vector &aggregates) {
 	auto aggregate_list = StringListToExpressionList(*context.GetContext(), aggregates);
-	return make_refcounted(shared_from_this(), std::move(aggregate_list));
+	return make_shared_ptr(shared_from_this(), std::move(aggregate_list));
 }
 
 shared_ptr Relation::Aggregate(const vector &aggregates, const vector &groups) {
@@ -204,7 +204,7 @@ shared_ptr Relation::Aggregate(const vector &aggregates, const
 
 shared_ptr Relation::Aggregate(vector> expressions, const string &group_list) {
 	auto groups = Parser::ParseGroupByList(group_list, context.GetContext()->GetParserOptions());
-	return make_refcounted(shared_from_this(), std::move(expressions), std::move(groups));
+	return make_shared_ptr(shared_from_this(), std::move(expressions), std::move(groups));
 }
 
 string Relation::GetAlias() {
@@ -237,7 +237,7 @@ BoundStatement Relation::Bind(Binder &binder) {
 }
 
 shared_ptr Relation::InsertRel(const string &schema_name, const string &table_name) {
-	return make_refcounted(shared_from_this(), schema_name, table_name);
+	return make_shared_ptr(shared_from_this(), schema_name, table_name);
 }
 
 void Relation::Insert(const string &table_name) {
@@ -255,12 +255,12 @@ void Relation::Insert(const string &schema_name, const string &table_name) {
 
 void Relation::Insert(const vector> &values) {
 	vector column_names;
-	auto rel = make_refcounted(context.GetContext(), values, std::move(column_names), "values");
+	auto rel = make_shared_ptr(context.GetContext(), values, std::move(column_names), "values");
 	rel->Insert(GetAlias());
 }
 
 shared_ptr Relation::CreateRel(const string &schema_name, const string &table_name) {
-	return make_refcounted(shared_from_this(), schema_name, table_name);
+	return make_shared_ptr(shared_from_this(), schema_name, table_name);
 }
 
 void Relation::Create(const string &table_name) {
@@ -277,7 +277,7 @@ void Relation::Create(const string &schema_name, const string &table_name) {
 }
 
 shared_ptr Relation::WriteCSVRel(const string &csv_file, case_insensitive_map_t> options) {
-	return make_refcounted(shared_from_this(), csv_file, std::move(options));
+	return make_shared_ptr(shared_from_this(), csv_file, std::move(options));
 }
 
 void Relation::WriteCSV(const string &csv_file, case_insensitive_map_t> options) {
@@ -292,7 +292,7 @@ void Relation::WriteCSV(const string &csv_file, case_insensitive_map_t
 Relation::WriteParquetRel(const string &parquet_file, case_insensitive_map_t> options) {
 	auto write_parquet =
-	    make_refcounted(shared_from_this(), parquet_file, std::move(options));
+	    make_shared_ptr(shared_from_this(), parquet_file, std::move(options));
 	return std::move(write_parquet);
 }
 
@@ -310,7 +310,7 @@ shared_ptr Relation::CreateView(const string &name, bool replace, bool
 }
 
 shared_ptr Relation::CreateView(const string &schema_name, const string &name, bool replace, bool temporary) {
-	auto view = make_refcounted(shared_from_this(), schema_name, name, replace, temporary);
+	auto view = make_shared_ptr(shared_from_this(), schema_name, name, replace, temporary);
 	auto res = view->Execute();
 	if (res->HasError()) {
 		const string prepended_message = "Failed to create view '" + name + "': ";
@@ -329,7 +329,7 @@ unique_ptr Relation::Query(const string &name, const string &sql) {
 }
 
 unique_ptr Relation::Explain(ExplainType type) {
-	auto explain = make_refcounted(shared_from_this(), type);
+	auto explain = make_shared_ptr(shared_from_this(), type);
 	return explain->Execute();
 }
 
@@ -343,12 +343,12 @@ void Relation::Delete(const string &condition) {
 
 shared_ptr Relation::TableFunction(const std::string &fname, const vector &values,
                                    const named_parameter_map_t &named_parameters) {
-	return make_refcounted(context.GetContext(), fname, values, named_parameters,
+	return make_shared_ptr(context.GetContext(), fname, values, named_parameters,
 	                       shared_from_this());
}
 
 shared_ptr Relation::TableFunction(const std::string &fname, const vector &values) {
-	return make_refcounted(context.GetContext(), fname, values, shared_from_this());
+	return make_shared_ptr(context.GetContext(), fname, values, shared_from_this());
 }
 
 string Relation::ToString() {
diff --git a/src/main/relation/read_csv_relation.cpp b/src/main/relation/read_csv_relation.cpp
index f63d535c13ab..cce7dba4ec8d 100644
--- a/src/main/relation/read_csv_relation.cpp
+++ b/src/main/relation/read_csv_relation.cpp
@@ -56,7 +56,7 @@ ReadCSVRelation::ReadCSVRelation(const shared_ptr &context, const
 	shared_ptr buffer_manager;
 	context->RunFunctionInTransaction([&]() {
-		buffer_manager = make_refcounted(*context, csv_options, files[0], 0);
+		buffer_manager = make_shared_ptr(*context, csv_options, files[0], 0);
 		CSVSniffer sniffer(csv_options, buffer_manager, CSVStateMachineCache::Get(*context));
 		auto sniffer_result = sniffer.SniffCSV();
 		auto &types = sniffer_result.return_types;
diff --git a/src/main/relation/table_relation.cpp b/src/main/relation/table_relation.cpp
index c37c88507849..4a0ff6e0e50d 100644
--- a/src/main/relation/table_relation.cpp
+++ b/src/main/relation/table_relation.cpp
@@ -56,14 +56,14 @@ void TableRelation::Update(const string &update_list, const string &condition) {
 	vector> expressions;
 	auto cond = ParseCondition(*context.GetContext(), condition);
 	Parser::ParseUpdateList(update_list, update_columns, expressions, context.GetContext()->GetParserOptions());
-	auto update = make_refcounted(context, std::move(cond), description->schema, description->table,
+	auto update = make_shared_ptr(context, std::move(cond), description->schema, description->table,
 	                              std::move(update_columns), std::move(expressions));
 	update->Execute();
 }
 
 void TableRelation::Delete(const string &condition) {
 	auto cond = ParseCondition(*context.GetContext(), condition);
-	auto del = make_refcounted(context, std::move(cond), description->schema, description->table);
+	auto del = make_shared_ptr(context, std::move(cond), description->schema, description->table);
 	del->Execute();
 }
diff --git a/src/parallel/executor.cpp b/src/parallel/executor.cpp
index a3be9b315843..827890080bea 100644
--- a/src/parallel/executor.cpp
+++ b/src/parallel/executor.cpp
@@ -73,11 +73,11 @@ void Executor::SchedulePipeline(const shared_ptr &meta_pipeline, S
 
 	// create events/stack for the base pipeline
 	auto base_pipeline = meta_pipeline->GetBasePipeline();
-	auto base_initialize_event = make_refcounted(base_pipeline);
-	auto base_event = make_refcounted(base_pipeline);
-	auto base_finish_event = make_refcounted(base_pipeline);
+	auto base_initialize_event = make_shared_ptr(base_pipeline);
+	auto base_event = make_shared_ptr(base_pipeline);
+	auto base_finish_event = make_shared_ptr(base_pipeline);
 	auto base_complete_event =
-	    make_refcounted(base_pipeline->executor, event_data.initial_schedule);
+	    make_shared_ptr(base_pipeline->executor, event_data.initial_schedule);
 	PipelineEventStack base_stack(*base_initialize_event, *base_event, *base_finish_event, *base_complete_event);
 	events.push_back(std::move(base_initialize_event));
 	events.push_back(std::move(base_event));
@@ -97,7 +97,7 @@ void Executor::SchedulePipeline(const shared_ptr &meta_pipeline, S
 		D_ASSERT(pipeline);
 
 		// create events/stack for this pipeline
-		auto pipeline_event = make_refcounted(pipeline);
+		auto pipeline_event = make_shared_ptr(pipeline);
 
 		auto finish_group = meta_pipeline->GetFinishGroup(*pipeline);
 		if (finish_group) {
@@ -116,7 +116,7 @@ void Executor::SchedulePipeline(const shared_ptr &meta_pipeline, S
 			event_map.insert(make_pair(reference(*pipeline), pipeline_stack));
 		} else if (meta_pipeline->HasFinishEvent(*pipeline)) {
 			// this pipeline has its own finish event (despite going into the same sink - Finalize twice!)
- auto pipeline_finish_event = make_refcounted(pipeline); + auto pipeline_finish_event = make_shared_ptr(pipeline); PipelineEventStack pipeline_stack(base_stack.pipeline_initialize_event, *pipeline_event, *pipeline_finish_event, base_stack.pipeline_complete_event); events.push_back(std::move(pipeline_finish_event)); @@ -360,7 +360,7 @@ void Executor::InitializeInternal(PhysicalOperator &plan) { // build and ready the pipelines PipelineBuildState state; - auto root_pipeline = make_refcounted(*this, state, nullptr); + auto root_pipeline = make_shared_ptr(*this, state, nullptr); root_pipeline->Build(*physical_plan); root_pipeline->Ready(); @@ -571,7 +571,7 @@ shared_ptr Executor::CreateChildPipeline(Pipeline ¤t, PhysicalOp D_ASSERT(op.IsSource()); // found another operator that is a source, schedule a child pipeline // 'op' is the source, and the sink is the same - auto child_pipeline = make_refcounted(*this); + auto child_pipeline = make_shared_ptr(*this); child_pipeline->sink = current.sink; child_pipeline->source = &op; diff --git a/src/parallel/meta_pipeline.cpp b/src/parallel/meta_pipeline.cpp index b73515d8e694..d383fe9592f8 100644 --- a/src/parallel/meta_pipeline.cpp +++ b/src/parallel/meta_pipeline.cpp @@ -82,7 +82,7 @@ void MetaPipeline::Ready() { } MetaPipeline &MetaPipeline::CreateChildMetaPipeline(Pipeline ¤t, PhysicalOperator &op) { - children.push_back(make_refcounted(executor, state, &op)); + children.push_back(make_shared_ptr(executor, state, &op)); auto child_meta_pipeline = children.back().get(); // child MetaPipeline must finish completely before this MetaPipeline can start current.AddDependency(child_meta_pipeline->GetBasePipeline()); @@ -92,7 +92,7 @@ MetaPipeline &MetaPipeline::CreateChildMetaPipeline(Pipeline ¤t, PhysicalO } Pipeline &MetaPipeline::CreatePipeline() { - pipelines.emplace_back(make_refcounted(executor)); + pipelines.emplace_back(make_shared_ptr(executor)); state.SetPipelineSink(*pipelines.back(), sink, next_batch_index++); return *pipelines.back(); } diff --git a/src/planner/bind_context.cpp b/src/planner/bind_context.cpp index 611c7b34414b..9e7da9c79bd2 100644 --- a/src/planner/bind_context.cpp +++ b/src/planner/bind_context.cpp @@ -514,13 +514,13 @@ void BindContext::AddGenericBinding(idx_t index, const string &alias, const vect void BindContext::AddCTEBinding(idx_t index, const string &alias, const vector &names, const vector &types) { - auto binding = make_refcounted(BindingType::BASE, alias, types, names, index); + auto binding = make_shared_ptr(BindingType::BASE, alias, types, names, index); if (cte_bindings.find(alias) != cte_bindings.end()) { throw BinderException("Duplicate alias \"%s\" in query!", alias); } cte_bindings[alias] = std::move(binding); - cte_references[alias] = make_refcounted(0); + cte_references[alias] = make_shared_ptr(0); } void BindContext::AddContext(BindContext other) { diff --git a/src/planner/binder.cpp b/src/planner/binder.cpp index 9316c6f85f63..66d0949ddd5f 100644 --- a/src/planner/binder.cpp +++ b/src/planner/binder.cpp @@ -47,7 +47,7 @@ shared_ptr Binder::CreateBinder(ClientContext &context, optional_ptr(true, context, parent ? parent->shared_from_this() : nullptr, inherit_ctes); + return make_shared_ptr(true, context, parent ? 
parent->shared_from_this() : nullptr, inherit_ctes); } Binder::Binder(bool, ClientContext &context, shared_ptr parent_p, bool inherit_ctes_p) diff --git a/src/planner/bound_parameter_map.cpp b/src/planner/bound_parameter_map.cpp index 61571cd8543f..aac752302bae 100644 --- a/src/planner/bound_parameter_map.cpp +++ b/src/planner/bound_parameter_map.cpp @@ -33,7 +33,7 @@ shared_ptr BoundParameterMap::CreateOrGetData(const string & auto entry = parameters.find(identifier); if (entry == parameters.end()) { // no entry yet: create a new one - auto data = make_refcounted(); + auto data = make_shared_ptr(); data->return_type = GetReturnType(identifier); CreateNewParameter(identifier, data); diff --git a/src/planner/planner.cpp b/src/planner/planner.cpp index 37381c2b6ec9..0ec47a7bfbd0 100644 --- a/src/planner/planner.cpp +++ b/src/planner/planner.cpp @@ -101,7 +101,7 @@ shared_ptr Planner::PrepareSQLStatement(unique_ptr(copied_statement->type); + auto prepared_data = make_shared_ptr(copied_statement->type); prepared_data->unbound_statement = std::move(copied_statement); prepared_data->names = names; prepared_data->types = types; diff --git a/src/storage/buffer/block_manager.cpp b/src/storage/buffer/block_manager.cpp index fdead1bc1e6b..e5e21b0e3cbd 100644 --- a/src/storage/buffer/block_manager.cpp +++ b/src/storage/buffer/block_manager.cpp @@ -23,7 +23,7 @@ shared_ptr BlockManager::RegisterBlock(block_id_t block_id) { } } // create a new block pointer for this block - auto result = make_refcounted(*this, block_id, MemoryTag::BASE_TABLE); + auto result = make_shared_ptr(*this, block_id, MemoryTag::BASE_TABLE); // register the block pointer in the set of blocks as a weak pointer blocks[block_id] = weak_ptr(result); return result; diff --git a/src/storage/checkpoint_manager.cpp b/src/storage/checkpoint_manager.cpp index fe91431d986b..429e0b9d1d53 100644 --- a/src/storage/checkpoint_manager.cpp +++ b/src/storage/checkpoint_manager.cpp @@ -425,7 +425,7 @@ void CheckpointReader::ReadIndex(ClientContext &context, Deserializer &deseriali // now we can look for the index in the catalog and assign the table info auto &index = catalog.CreateIndex(context, info)->Cast(); - index.info = make_refcounted(table.GetStorage().info, info.index_name); + index.info = make_shared_ptr(table.GetStorage().info, info.index_name); // insert the parsed expressions into the index so that we can (de)serialize them during consecutive checkpoints for (auto &parsed_expr : info.parsed_expressions) { diff --git a/src/storage/data_table.cpp b/src/storage/data_table.cpp index a1695d079b2e..96c670da95e3 100644 --- a/src/storage/data_table.cpp +++ b/src/storage/data_table.cpp @@ -45,12 +45,12 @@ bool DataTableInfo::IsTemporary() const { DataTable::DataTable(AttachedDatabase &db, shared_ptr table_io_manager_p, const string &schema, const string &table, vector column_definitions_p, unique_ptr data) - : info(make_refcounted(db, std::move(table_io_manager_p), schema, table)), + : info(make_shared_ptr(db, std::move(table_io_manager_p), schema, table)), column_definitions(std::move(column_definitions_p)), db(db), is_root(true) { // initialize the table with the existing data from disk, if any auto types = GetTypes(); this->row_groups = - make_refcounted(info, TableIOManager::Get(*this).GetBlockManagerForRowData(), types, 0); + make_shared_ptr(info, TableIOManager::Get(*this).GetBlockManagerForRowData(), types, 0); if (data && data->row_group_count > 0) { this->row_groups->Initialize(*data); } else { diff --git a/src/storage/local_storage.cpp 
b/src/storage/local_storage.cpp index 55d3a746a97f..1ba86783bbb5 100644 --- a/src/storage/local_storage.cpp +++ b/src/storage/local_storage.cpp @@ -18,7 +18,7 @@ LocalTableStorage::LocalTableStorage(DataTable &table) : table_ref(table), allocator(Allocator::Get(table.db)), deleted_rows(0), optimistic_writer(table), merged_storage(false) { auto types = table.GetTypes(); - row_groups = make_refcounted(table.info, TableIOManager::Get(table).GetBlockManagerForRowData(), + row_groups = make_shared_ptr(table.info, TableIOManager::Get(table).GetBlockManagerForRowData(), types, MAX_ROW_ID, 0); row_groups->InitializeEmpty(); @@ -250,7 +250,7 @@ LocalTableStorage &LocalTableManager::GetOrCreateStorage(DataTable &table) { lock_guard l(table_storage_lock); auto entry = table_storage.find(table); if (entry == table_storage.end()) { - auto new_storage = make_refcounted(table); + auto new_storage = make_shared_ptr(table); auto storage = new_storage.get(); table_storage.insert(make_pair(reference(table), std::move(new_storage))); return *storage; @@ -531,7 +531,7 @@ void LocalStorage::AddColumn(DataTable &old_dt, DataTable &new_dt, ColumnDefinit if (!storage) { return; } - auto new_storage = make_refcounted(context, new_dt, *storage, new_column, default_value); + auto new_storage = make_shared_ptr(context, new_dt, *storage, new_column, default_value); table_manager.InsertEntry(new_dt, std::move(new_storage)); } @@ -541,7 +541,7 @@ void LocalStorage::DropColumn(DataTable &old_dt, DataTable &new_dt, idx_t remove if (!storage) { return; } - auto new_storage = make_refcounted(new_dt, *storage, removed_column); + auto new_storage = make_shared_ptr(new_dt, *storage, removed_column); table_manager.InsertEntry(new_dt, std::move(new_storage)); } @@ -552,7 +552,7 @@ void LocalStorage::ChangeType(DataTable &old_dt, DataTable &new_dt, idx_t change if (!storage) { return; } - auto new_storage = make_refcounted(context, new_dt, *storage, changed_idx, target_type, + auto new_storage = make_shared_ptr(context, new_dt, *storage, changed_idx, target_type, bound_columns, cast_expr); table_manager.InsertEntry(new_dt, std::move(new_storage)); } diff --git a/src/storage/serialization/serialize_types.cpp b/src/storage/serialization/serialize_types.cpp index 1b3b37d87cd9..f937c0e0df73 100644 --- a/src/storage/serialization/serialize_types.cpp +++ b/src/storage/serialization/serialize_types.cpp @@ -35,7 +35,7 @@ shared_ptr ExtraTypeInfo::Deserialize(Deserializer &deserializer) result = EnumTypeInfo::Deserialize(deserializer); break; case ExtraTypeInfoType::GENERIC_TYPE_INFO: - result = make_refcounted(type); + result = make_shared_ptr(type); break; case ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO: result = IntegerLiteralTypeInfo::Deserialize(deserializer); diff --git a/src/storage/standard_buffer_manager.cpp b/src/storage/standard_buffer_manager.cpp index f0055e8cb3ef..647a283b03e8 100644 --- a/src/storage/standard_buffer_manager.cpp +++ b/src/storage/standard_buffer_manager.cpp @@ -98,7 +98,7 @@ shared_ptr StandardBufferManager::RegisterSmallMemory(idx_t block_s auto buffer = ConstructManagedBuffer(block_size, nullptr, FileBufferType::TINY_BUFFER); // create a new block pointer for this block - auto result = make_refcounted(*temp_block_manager, ++temporary_id, MemoryTag::BASE_TABLE, + auto result = make_shared_ptr(*temp_block_manager, ++temporary_id, MemoryTag::BASE_TABLE, std::move(buffer), false, block_size, std::move(reservation)); #ifdef DUCKDB_DEBUG_DESTROY_BLOCKS // Initialize the memory with garbage data @@ -118,7 +118,7 
@@ shared_ptr StandardBufferManager::RegisterMemory(MemoryTag tag, idx auto buffer = ConstructManagedBuffer(block_size, std::move(reusable_buffer)); // create a new block pointer for this block - return make_refcounted(*temp_block_manager, ++temporary_id, tag, std::move(buffer), can_destroy, + return make_shared_ptr(*temp_block_manager, ++temporary_id, tag, std::move(buffer), can_destroy, alloc_size, std::move(res)); } diff --git a/src/storage/statistics/column_statistics.cpp b/src/storage/statistics/column_statistics.cpp index 67d67417b671..8b3ac243e963 100644 --- a/src/storage/statistics/column_statistics.cpp +++ b/src/storage/statistics/column_statistics.cpp @@ -14,7 +14,7 @@ ColumnStatistics::ColumnStatistics(BaseStatistics stats_p, unique_ptr ColumnStatistics::CreateEmptyStats(const LogicalType &type) { - return make_refcounted(BaseStatistics::CreateEmpty(type)); + return make_shared_ptr(BaseStatistics::CreateEmpty(type)); } void ColumnStatistics::Merge(ColumnStatistics &other) { @@ -53,7 +53,7 @@ void ColumnStatistics::UpdateDistinctStatistics(Vector &v, idx_t count) { } shared_ptr ColumnStatistics::Copy() const { - return make_refcounted(stats.Copy(), distinct_stats ? distinct_stats->Copy() : nullptr); + return make_shared_ptr(stats.Copy(), distinct_stats ? distinct_stats->Copy() : nullptr); } void ColumnStatistics::Serialize(Serializer &serializer) const { @@ -65,7 +65,7 @@ shared_ptr ColumnStatistics::Deserialize(Deserializer &deseria auto stats = deserializer.ReadProperty(100, "statistics"); auto distinct_stats = deserializer.ReadPropertyWithDefault>( 101, "distinct", unique_ptr()); - return make_refcounted(std::move(stats), std::move(distinct_stats)); + return make_shared_ptr(std::move(stats), std::move(distinct_stats)); } } // namespace duckdb diff --git a/src/storage/table/row_group.cpp b/src/storage/table/row_group.cpp index 4d232e3f53c9..24c05e50463d 100644 --- a/src/storage/table/row_group.cpp +++ b/src/storage/table/row_group.cpp @@ -624,7 +624,7 @@ shared_ptr &RowGroup::GetOrCreateVersionInfoPtr() { if (!vinfo) { lock_guard lock(row_group_lock); if (!version_info) { - version_info = make_refcounted(start); + version_info = make_shared_ptr(start); } } return version_info; diff --git a/src/storage/table/row_group_collection.cpp b/src/storage/table/row_group_collection.cpp index c333ea6b1c10..e46a0e7663ff 100644 --- a/src/storage/table/row_group_collection.cpp +++ b/src/storage/table/row_group_collection.cpp @@ -55,7 +55,7 @@ RowGroupCollection::RowGroupCollection(shared_ptr info_p, BlockMa vector types_p, idx_t row_start_p, idx_t total_rows_p) : block_manager(block_manager), total_rows(total_rows_p), info(std::move(info_p)), types(std::move(types_p)), row_start(row_start_p), allocation_size(0) { - row_groups = make_refcounted(*this); + row_groups = make_shared_ptr(*this); } idx_t RowGroupCollection::GetTotalRows() const { @@ -1031,7 +1031,7 @@ shared_ptr RowGroupCollection::AddColumn(ClientContext &cont auto new_types = types; new_types.push_back(new_column.GetType()); auto result = - make_refcounted(info, block_manager, std::move(new_types), row_start, total_rows.load()); + make_shared_ptr(info, block_manager, std::move(new_types), row_start, total_rows.load()); ExpressionExecutor executor(context); DataChunk dummy_chunk; @@ -1059,7 +1059,7 @@ shared_ptr RowGroupCollection::RemoveColumn(idx_t col_idx) { new_types.erase(new_types.begin() + col_idx); auto result = - make_refcounted(info, block_manager, std::move(new_types), row_start, total_rows.load()); + 
make_shared_ptr(info, block_manager, std::move(new_types), row_start, total_rows.load()); result->stats.InitializeRemoveColumn(stats, col_idx); for (auto ¤t_row_group : row_groups->Segments()) { @@ -1077,7 +1077,7 @@ shared_ptr RowGroupCollection::AlterType(ClientContext &cont new_types[changed_idx] = target_type; auto result = - make_refcounted(info, block_manager, std::move(new_types), row_start, total_rows.load()); + make_shared_ptr(info, block_manager, std::move(new_types), row_start, total_rows.load()); result->stats.InitializeAlterType(stats, changed_idx, target_type); vector scan_types; diff --git a/src/storage/table/row_version_manager.cpp b/src/storage/table/row_version_manager.cpp index 711daa7d1881..e20d9b06a836 100644 --- a/src/storage/table/row_version_manager.cpp +++ b/src/storage/table/row_version_manager.cpp @@ -212,7 +212,7 @@ shared_ptr RowVersionManager::Deserialize(MetaBlockPointer de if (!delete_pointer.IsValid()) { return nullptr; } - auto version_info = make_refcounted(start); + auto version_info = make_shared_ptr(start); MetadataReader source(manager, delete_pointer, &version_info->storage_pointers); auto chunk_count = source.Read(); D_ASSERT(chunk_count > 0); diff --git a/src/storage/wal_replay.cpp b/src/storage/wal_replay.cpp index b3fea7dae501..04e38bc9d4aa 100644 --- a/src/storage/wal_replay.cpp +++ b/src/storage/wal_replay.cpp @@ -580,7 +580,7 @@ void WriteAheadLogDeserializer::ReplayCreateIndex() { // create the index in the catalog auto &table = catalog.GetEntry(context, create_info->schema, info.table).Cast(); auto &index = catalog.CreateIndex(context, info)->Cast(); - index.info = make_refcounted(table.GetStorage().info, index.name); + index.info = make_shared_ptr(table.GetStorage().info, index.name); // insert the parsed expressions into the index so that we can (de)serialize them during consecutive checkpoints for (auto &parsed_expr : info.parsed_expressions) { diff --git a/test/api/test_object_cache.cpp b/test/api/test_object_cache.cpp index 04f4ac7c91de..eebc48c94281 100644 --- a/test/api/test_object_cache.cpp +++ b/test/api/test_object_cache.cpp @@ -42,7 +42,7 @@ TEST_CASE("Test ObjectCache", "[api]") { auto &cache = ObjectCache::GetObjectCache(context); REQUIRE(cache.GetObject("test") == nullptr); - cache.Put("test", make_refcounted(42)); + cache.Put("test", make_shared_ptr(42)); REQUIRE(cache.GetObject("test") != nullptr); diff --git a/tools/odbc/include/duckdb_odbc.hpp b/tools/odbc/include/duckdb_odbc.hpp index 14b92ff1df5c..f7752d3b0682 100644 --- a/tools/odbc/include/duckdb_odbc.hpp +++ b/tools/odbc/include/duckdb_odbc.hpp @@ -42,7 +42,7 @@ struct OdbcHandleEnv : public OdbcHandle { OdbcHandleEnv() : OdbcHandle(OdbcHandleType::ENV) { duckdb::DBConfig ODBC_CONFIG; ODBC_CONFIG.SetOptionByName("duckdb_api", "odbc"); - db = make_refcounted(nullptr, &ODBC_CONFIG); + db = make_shared_ptr(nullptr, &ODBC_CONFIG); }; shared_ptr db; diff --git a/tools/pythonpkg/src/pyconnection.cpp b/tools/pythonpkg/src/pyconnection.cpp index 548d83a0331d..5d98994fc1a0 100644 --- a/tools/pythonpkg/src/pyconnection.cpp +++ b/tools/pythonpkg/src/pyconnection.cpp @@ -640,7 +640,7 @@ void DuckDBPyConnection::RegisterArrowObject(const py::object &arrow_object, con } vector> dependencies; dependencies.push_back( - make_refcounted(make_uniq(std::move(stream_factory), arrow_object))); + make_shared_ptr(make_uniq(std::move(stream_factory), arrow_object))); connection->context->external_dependencies[name] = std::move(dependencies); } @@ -665,7 +665,7 @@ shared_ptr 
DuckDBPyConnection::RegisterPythonObject(const st // keep a reference vector> dependencies; - dependencies.push_back(make_refcounted(make_uniq(python_object), + dependencies.push_back(make_shared_ptr(make_uniq(python_object), make_uniq(new_df))); connection->context->external_dependencies[name] = std::move(dependencies); } @@ -776,7 +776,7 @@ unique_ptr DuckDBPyConnection::ReadJSON(const string &name, co } auto read_json_relation = - make_refcounted(connection->context, name, std::move(options), auto_detect); + make_shared_ptr(connection->context, name, std::move(options), auto_detect); if (read_json_relation == nullptr) { throw BinderException("read_json can only be used when the JSON extension is (statically) loaded"); } @@ -1319,7 +1319,7 @@ shared_ptr DuckDBPyConnection::Cursor() { if (!connection) { throw ConnectionException("Connection has already been closed"); } - auto res = make_refcounted(); + auto res = make_shared_ptr(); res->database = database; res->connection = make_uniq(*res->database); cursors.push_back(res); @@ -1598,7 +1598,7 @@ static void SetDefaultConfigArguments(ClientContext &context) { } static shared_ptr FetchOrCreateInstance(const string &database, DBConfig &config) { - auto res = make_refcounted(); + auto res = make_shared_ptr(); res->database = instance_cache.GetInstance(database, config); if (!res->database) { //! No cached database, we must create a new instance @@ -1676,7 +1676,7 @@ shared_ptr DuckDBPyConnection::DefaultConnection() { PythonImportCache *DuckDBPyConnection::ImportCache() { if (!import_cache) { - import_cache = make_refcounted(); + import_cache = make_shared_ptr(); } return import_cache.get(); } @@ -1690,7 +1690,7 @@ ModifiedMemoryFileSystem &DuckDBPyConnection::GetObjectFileSystem() { throw InvalidInputException( "This operation could not be completed because required module 'fsspec' is not installed"); } - internal_object_filesystem = make_refcounted(modified_memory_fs()); + internal_object_filesystem = make_shared_ptr(modified_memory_fs()); auto &abstract_fs = reinterpret_cast(*internal_object_filesystem); RegisterFilesystem(abstract_fs); } diff --git a/tools/pythonpkg/src/pyconnection/type_creation.cpp b/tools/pythonpkg/src/pyconnection/type_creation.cpp index 91860e7f936e..1cf2e324c9e8 100644 --- a/tools/pythonpkg/src/pyconnection/type_creation.cpp +++ b/tools/pythonpkg/src/pyconnection/type_creation.cpp @@ -5,17 +5,17 @@ namespace duckdb { shared_ptr DuckDBPyConnection::MapType(const shared_ptr &key_type, const shared_ptr &value_type) { auto map_type = LogicalType::MAP(key_type->Type(), value_type->Type()); - return make_refcounted(map_type); + return make_shared_ptr(map_type); } shared_ptr DuckDBPyConnection::ListType(const shared_ptr &type) { auto array_type = LogicalType::LIST(type->Type()); - return make_refcounted(array_type); + return make_shared_ptr(array_type); } shared_ptr DuckDBPyConnection::ArrayType(const shared_ptr &type, idx_t size) { auto array_type = LogicalType::ARRAY(type->Type(), size); - return make_refcounted(array_type); + return make_shared_ptr(array_type); } static child_list_t GetChildList(const py::object &container) { @@ -59,7 +59,7 @@ shared_ptr DuckDBPyConnection::StructType(const py::object &fields throw InvalidInputException("Can not create an empty struct type!"); } auto struct_type = LogicalType::STRUCT(std::move(types)); - return make_refcounted(struct_type); + return make_shared_ptr(struct_type); } shared_ptr DuckDBPyConnection::UnionType(const py::object &members) { @@ -69,7 +69,7 @@ shared_ptr 
DuckDBPyConnection::UnionType(const py::object &members throw InvalidInputException("Can not create an empty union type!"); } auto union_type = LogicalType::UNION(std::move(types)); - return make_refcounted(union_type); + return make_shared_ptr(union_type); } shared_ptr DuckDBPyConnection::EnumType(const string &name, const shared_ptr &type, @@ -79,7 +79,7 @@ shared_ptr DuckDBPyConnection::EnumType(const string &name, const shared_ptr DuckDBPyConnection::DecimalType(int width, int scale) { auto decimal_type = LogicalType::DECIMAL(width, scale); - return make_refcounted(decimal_type); + return make_shared_ptr(decimal_type); } shared_ptr DuckDBPyConnection::StringType(const string &collation) { @@ -89,14 +89,14 @@ shared_ptr DuckDBPyConnection::StringType(const string &collation) } else { type = LogicalType::VARCHAR_COLLATION(collation); } - return make_refcounted(type); + return make_shared_ptr(type); } shared_ptr DuckDBPyConnection::Type(const string &type_str) { if (!connection) { throw ConnectionException("Connection already closed!"); } - return make_refcounted(TransformStringToLogicalType(type_str, *connection->context)); + return make_shared_ptr(TransformStringToLogicalType(type_str, *connection->context)); } } // namespace duckdb diff --git a/tools/pythonpkg/src/pyexpression.cpp b/tools/pythonpkg/src/pyexpression.cpp index 09031706acdc..3bc2c43ec47a 100644 --- a/tools/pythonpkg/src/pyexpression.cpp +++ b/tools/pythonpkg/src/pyexpression.cpp @@ -34,19 +34,19 @@ const ParsedExpression &DuckDBPyExpression::GetExpression() const { shared_ptr DuckDBPyExpression::Copy() const { auto expr = GetExpression().Copy(); - return make_refcounted(std::move(expr), order_type, null_order); + return make_shared_ptr(std::move(expr), order_type, null_order); } shared_ptr DuckDBPyExpression::SetAlias(const string &name) const { auto copied_expression = GetExpression().Copy(); copied_expression->alias = name; - return make_refcounted(std::move(copied_expression)); + return make_shared_ptr(std::move(copied_expression)); } shared_ptr DuckDBPyExpression::Cast(const DuckDBPyType &type) const { auto copied_expression = GetExpression().Copy(); auto case_expr = make_uniq(type.Type(), std::move(copied_expression)); - return make_refcounted(std::move(case_expr)); + return make_shared_ptr(std::move(case_expr)); } // Case Expression modifiers @@ -64,7 +64,7 @@ shared_ptr DuckDBPyExpression::InternalWhen(unique_ptrcase_checks.push_back(std::move(check)); - return make_refcounted(std::move(expr)); + return make_shared_ptr(std::move(expr)); } shared_ptr DuckDBPyExpression::When(const DuckDBPyExpression &condition, @@ -82,7 +82,7 @@ shared_ptr DuckDBPyExpression::Else(const DuckDBPyExpression auto expr = unique_ptr_cast(std::move(expr_p)); expr->else_expr = value.GetExpression().Copy(); - return make_refcounted(std::move(expr)); + return make_shared_ptr(std::move(expr)); } // Binary operators @@ -181,7 +181,7 @@ shared_ptr DuckDBPyExpression::In(const py::args &args) { expressions.push_back(std::move(expr)); } auto operator_expr = make_uniq(ExpressionType::COMPARE_IN, std::move(expressions)); - return make_refcounted(std::move(operator_expr)); + return make_shared_ptr(std::move(operator_expr)); } shared_ptr DuckDBPyExpression::NotIn(const py::args &args) { @@ -249,7 +249,7 @@ shared_ptr DuckDBPyExpression::StarExpression(const py::list case_insensitive_set_t exclude; auto star = make_uniq(); PopulateExcludeList(star->exclude_list, exclude_list); - return make_refcounted(std::move(star)); + return 
make_shared_ptr(std::move(star)); } shared_ptr DuckDBPyExpression::ColumnExpression(const string &column_name) { @@ -267,7 +267,7 @@ shared_ptr DuckDBPyExpression::ColumnExpression(const string } column_names.push_back(qualified_name.name); - return make_refcounted(make_uniq(std::move(column_names))); + return make_shared_ptr(make_uniq(std::move(column_names))); } shared_ptr DuckDBPyExpression::ConstantExpression(const py::object &value) { @@ -292,14 +292,14 @@ DuckDBPyExpression::InternalFunctionExpression(const string &function_name, vector> children, bool is_operator) { auto function_expression = make_uniq(function_name, std::move(children), nullptr, nullptr, false, is_operator); - return make_refcounted(std::move(function_expression)); + return make_shared_ptr(std::move(function_expression)); } shared_ptr DuckDBPyExpression::InternalUnaryOperator(ExpressionType type, const DuckDBPyExpression &arg) { auto expr = arg.GetExpression().Copy(); auto operator_expression = make_uniq(type, std::move(expr)); - return make_refcounted(std::move(operator_expression)); + return make_shared_ptr(std::move(operator_expression)); } shared_ptr DuckDBPyExpression::InternalConjunction(ExpressionType type, @@ -311,11 +311,11 @@ shared_ptr DuckDBPyExpression::InternalConjunction(Expressio children.push_back(other.GetExpression().Copy()); auto operator_expression = make_uniq(type, std::move(children)); - return make_refcounted(std::move(operator_expression)); + return make_shared_ptr(std::move(operator_expression)); } shared_ptr DuckDBPyExpression::InternalConstantExpression(Value val) { - return make_refcounted(make_uniq(std::move(val))); + return make_shared_ptr(make_uniq(std::move(val))); } shared_ptr DuckDBPyExpression::ComparisonExpression(ExpressionType type, @@ -323,7 +323,7 @@ shared_ptr DuckDBPyExpression::ComparisonExpression(Expressi const DuckDBPyExpression &right_p) { auto left = left_p.GetExpression().Copy(); auto right = right_p.GetExpression().Copy(); - return make_refcounted( + return make_shared_ptr( make_uniq(type, std::move(left), std::move(right))); } diff --git a/tools/pythonpkg/src/pyrelation.cpp b/tools/pythonpkg/src/pyrelation.cpp index 2809d2e41089..157eff1b2894 100644 --- a/tools/pythonpkg/src/pyrelation.cpp +++ b/tools/pythonpkg/src/pyrelation.cpp @@ -157,7 +157,7 @@ unique_ptr DuckDBPyRelation::EmptyResult(const shared_ptr> single_row(1, dummy_values); auto values_relation = - make_uniq(make_refcounted(context, single_row, std::move(names))); + make_uniq(make_shared_ptr(context, single_row, std::move(names))); // Add a filter on an impossible condition return values_relation->FilterFromExpression("true = false"); } @@ -1236,7 +1236,7 @@ unique_ptr DuckDBPyRelation::Query(const string &view_name, co if (statement.type == StatementType::SELECT_STATEMENT) { auto select_statement = unique_ptr_cast(std::move(parser.statements[0])); auto query_relation = - make_refcounted(rel->context.GetContext(), std::move(select_statement), "query_relation"); + make_shared_ptr(rel->context.GetContext(), std::move(select_statement), "query_relation"); return make_uniq(std::move(query_relation)); } else if (IsDescribeStatement(statement)) { auto query = PragmaShow(view_name); diff --git a/tools/pythonpkg/src/typing/pytype.cpp b/tools/pythonpkg/src/typing/pytype.cpp index 00edd97af4f9..5fa24fab9939 100644 --- a/tools/pythonpkg/src/typing/pytype.cpp +++ b/tools/pythonpkg/src/typing/pytype.cpp @@ -56,20 +56,20 @@ shared_ptr DuckDBPyType::GetAttribute(const string &name) const { for (idx_t i = 0; i < 
children.size(); i++) { auto &child = children[i]; if (StringUtil::CIEquals(child.first, name)) { - return make_refcounted(StructType::GetChildType(type, i)); + return make_shared_ptr(StructType::GetChildType(type, i)); } } } if (type.id() == LogicalTypeId::LIST && StringUtil::CIEquals(name, "child")) { - return make_refcounted(ListType::GetChildType(type)); + return make_shared_ptr(ListType::GetChildType(type)); } if (type.id() == LogicalTypeId::MAP) { auto is_key = StringUtil::CIEquals(name, "key"); auto is_value = StringUtil::CIEquals(name, "value"); if (is_key) { - return make_refcounted(MapType::KeyType(type)); + return make_shared_ptr(MapType::KeyType(type)); } else if (is_value) { - return make_refcounted(MapType::ValueType(type)); + return make_shared_ptr(MapType::ValueType(type)); } else { throw py::attribute_error(StringUtil::Format("Tried to get a child from a map by the name of '%s', but " "this type only has 'key' and 'value' children", @@ -313,19 +313,19 @@ void DuckDBPyType::Initialize(py::handle &m) { type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { auto ltype = FromString(type_str, std::move(connection)); - return make_refcounted(ltype); + return make_shared_ptr(ltype); })); type_module.def(py::init<>([](const PyGenericAlias &obj) { auto ltype = FromGenericAlias(obj); - return make_refcounted(ltype); + return make_shared_ptr(ltype); })); type_module.def(py::init<>([](const PyUnionType &obj) { auto ltype = FromUnionType(obj); - return make_refcounted(ltype); + return make_shared_ptr(ltype); })); type_module.def(py::init<>([](const py::object &obj) { auto ltype = FromObject(obj); - return make_refcounted(ltype); + return make_shared_ptr(ltype); })); type_module.def("__getattr__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); @@ -357,7 +357,7 @@ py::list DuckDBPyType::Children() const { py::list children; auto id = type.id(); if (id == LogicalTypeId::LIST) { - children.append(py::make_tuple("child", make_refcounted(ListType::GetChildType(type)))); + children.append(py::make_tuple("child", make_shared_ptr(ListType::GetChildType(type)))); return children; } // FIXME: where is ARRAY?? 
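// Regarding the FIXME above: a hypothetical ARRAY branch mirroring the LIST
// case. Illustration only -- ArrayType::GetChildType and ArrayType::GetSize
// are assumed from DuckDB's type API, and this code is not part of the patch:
if (id == LogicalTypeId::ARRAY) {
	children.append(py::make_tuple("child", make_shared_ptr<DuckDBPyType>(ArrayType::GetChildType(type))));
	// the fixed size is plain data, not a nested type
	children.append(py::make_tuple("size", py::cast(ArrayType::GetSize(type))));
	return children;
}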
@@ -367,13 +367,13 @@ py::list DuckDBPyType::Children() const { for (idx_t i = 0; i < struct_children.size(); i++) { auto &child = struct_children[i]; children.append( - py::make_tuple(child.first, make_refcounted(StructType::GetChildType(type, i)))); + py::make_tuple(child.first, make_shared_ptr(StructType::GetChildType(type, i)))); } return children; } if (id == LogicalTypeId::MAP) { - children.append(py::make_tuple("key", make_refcounted(MapType::KeyType(type)))); - children.append(py::make_tuple("value", make_refcounted(MapType::ValueType(type)))); + children.append(py::make_tuple("key", make_shared_ptr(MapType::KeyType(type)))); + children.append(py::make_tuple("value", make_shared_ptr(MapType::ValueType(type)))); return children; } if (id == LogicalTypeId::DECIMAL) { diff --git a/tools/pythonpkg/src/typing/typing.cpp b/tools/pythonpkg/src/typing/typing.cpp index 8064acfc22e1..c0e2675ecf2d 100644 --- a/tools/pythonpkg/src/typing/typing.cpp +++ b/tools/pythonpkg/src/typing/typing.cpp @@ -4,38 +4,38 @@ namespace duckdb { static void DefineBaseTypes(py::handle &m) { - m.attr("SQLNULL") = make_refcounted(LogicalType::SQLNULL); - m.attr("BOOLEAN") = make_refcounted(LogicalType::BOOLEAN); - m.attr("TINYINT") = make_refcounted(LogicalType::TINYINT); - m.attr("UTINYINT") = make_refcounted(LogicalType::UTINYINT); - m.attr("SMALLINT") = make_refcounted(LogicalType::SMALLINT); - m.attr("USMALLINT") = make_refcounted(LogicalType::USMALLINT); - m.attr("INTEGER") = make_refcounted(LogicalType::INTEGER); - m.attr("UINTEGER") = make_refcounted(LogicalType::UINTEGER); - m.attr("BIGINT") = make_refcounted(LogicalType::BIGINT); - m.attr("UBIGINT") = make_refcounted(LogicalType::UBIGINT); - m.attr("HUGEINT") = make_refcounted(LogicalType::HUGEINT); - m.attr("UHUGEINT") = make_refcounted(LogicalType::UHUGEINT); - m.attr("UUID") = make_refcounted(LogicalType::UUID); - m.attr("FLOAT") = make_refcounted(LogicalType::FLOAT); - m.attr("DOUBLE") = make_refcounted(LogicalType::DOUBLE); - m.attr("DATE") = make_refcounted(LogicalType::DATE); - - m.attr("TIMESTAMP") = make_refcounted(LogicalType::TIMESTAMP); - m.attr("TIMESTAMP_MS") = make_refcounted(LogicalType::TIMESTAMP_MS); - m.attr("TIMESTAMP_NS") = make_refcounted(LogicalType::TIMESTAMP_NS); - m.attr("TIMESTAMP_S") = make_refcounted(LogicalType::TIMESTAMP_S); - - m.attr("TIME") = make_refcounted(LogicalType::TIME); - - m.attr("TIME_TZ") = make_refcounted(LogicalType::TIME_TZ); - m.attr("TIMESTAMP_TZ") = make_refcounted(LogicalType::TIMESTAMP_TZ); - - m.attr("VARCHAR") = make_refcounted(LogicalType::VARCHAR); - - m.attr("BLOB") = make_refcounted(LogicalType::BLOB); - m.attr("BIT") = make_refcounted(LogicalType::BIT); - m.attr("INTERVAL") = make_refcounted(LogicalType::INTERVAL); + m.attr("SQLNULL") = make_shared_ptr(LogicalType::SQLNULL); + m.attr("BOOLEAN") = make_shared_ptr(LogicalType::BOOLEAN); + m.attr("TINYINT") = make_shared_ptr(LogicalType::TINYINT); + m.attr("UTINYINT") = make_shared_ptr(LogicalType::UTINYINT); + m.attr("SMALLINT") = make_shared_ptr(LogicalType::SMALLINT); + m.attr("USMALLINT") = make_shared_ptr(LogicalType::USMALLINT); + m.attr("INTEGER") = make_shared_ptr(LogicalType::INTEGER); + m.attr("UINTEGER") = make_shared_ptr(LogicalType::UINTEGER); + m.attr("BIGINT") = make_shared_ptr(LogicalType::BIGINT); + m.attr("UBIGINT") = make_shared_ptr(LogicalType::UBIGINT); + m.attr("HUGEINT") = make_shared_ptr(LogicalType::HUGEINT); + m.attr("UHUGEINT") = make_shared_ptr(LogicalType::UHUGEINT); + m.attr("UUID") = make_shared_ptr(LogicalType::UUID); 
+ m.attr("FLOAT") = make_shared_ptr(LogicalType::FLOAT); + m.attr("DOUBLE") = make_shared_ptr(LogicalType::DOUBLE); + m.attr("DATE") = make_shared_ptr(LogicalType::DATE); + + m.attr("TIMESTAMP") = make_shared_ptr(LogicalType::TIMESTAMP); + m.attr("TIMESTAMP_MS") = make_shared_ptr(LogicalType::TIMESTAMP_MS); + m.attr("TIMESTAMP_NS") = make_shared_ptr(LogicalType::TIMESTAMP_NS); + m.attr("TIMESTAMP_S") = make_shared_ptr(LogicalType::TIMESTAMP_S); + + m.attr("TIME") = make_shared_ptr(LogicalType::TIME); + + m.attr("TIME_TZ") = make_shared_ptr(LogicalType::TIME_TZ); + m.attr("TIMESTAMP_TZ") = make_shared_ptr(LogicalType::TIMESTAMP_TZ); + + m.attr("VARCHAR") = make_shared_ptr(LogicalType::VARCHAR); + + m.attr("BLOB") = make_shared_ptr(LogicalType::BLOB); + m.attr("BIT") = make_shared_ptr(LogicalType::BIT); + m.attr("INTERVAL") = make_shared_ptr(LogicalType::INTERVAL); } void DuckDBPyTyping::Initialize(py::module_ &parent) { From 8e5178e1266d738d04dc9cf2bba66dd57ffd6e8d Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 12 Apr 2024 11:11:24 +0200 Subject: [PATCH 103/201] add patch for azure --- .../patches/extensions/azure/shared_ptr.patch | 400 ++++++++++++++++++ 1 file changed, 400 insertions(+) create mode 100644 .github/patches/extensions/azure/shared_ptr.patch diff --git a/.github/patches/extensions/azure/shared_ptr.patch b/.github/patches/extensions/azure/shared_ptr.patch new file mode 100644 index 000000000000..bee19cfb4b8a --- /dev/null +++ b/.github/patches/extensions/azure/shared_ptr.patch @@ -0,0 +1,400 @@ +diff --git a/src/azure_blob_filesystem.cpp b/src/azure_blob_filesystem.cpp +index bc34eb9..42f9323 100644 +--- a/src/azure_blob_filesystem.cpp ++++ b/src/azure_blob_filesystem.cpp +@@ -3,6 +3,8 @@ + #include "azure_storage_account_client.hpp" + #include "duckdb.hpp" + #include "duckdb/common/exception.hpp" ++#include "duckdb/common/helper.hpp" ++#include "duckdb/common/shared_ptr.hpp" + #include "duckdb/common/http_state.hpp" + #include "duckdb/common/file_opener.hpp" + #include "duckdb/common/string_util.hpp" +@@ -201,13 +203,13 @@ void AzureBlobStorageFileSystem::ReadRange(AzureFileHandle &handle, idx_t file_o + } + } + +-std::shared_ptr AzureBlobStorageFileSystem::CreateStorageContext(optional_ptr opener, +- const string &path, +- const AzureParsedUrl &parsed_url) { ++shared_ptr AzureBlobStorageFileSystem::CreateStorageContext(optional_ptr opener, ++ const string &path, ++ const AzureParsedUrl &parsed_url) { + auto azure_read_options = ParseAzureReadOptions(opener); + +- return std::make_shared(ConnectToBlobStorageAccount(opener, path, parsed_url), +- azure_read_options); ++ return make_shared_ptr(ConnectToBlobStorageAccount(opener, path, parsed_url), ++ azure_read_options); + } + + } // namespace duckdb +diff --git a/src/azure_dfs_filesystem.cpp b/src/azure_dfs_filesystem.cpp +index 5ccbed0..739078c 100644 +--- a/src/azure_dfs_filesystem.cpp ++++ b/src/azure_dfs_filesystem.cpp +@@ -1,6 +1,8 @@ + #include "azure_dfs_filesystem.hpp" + #include "azure_storage_account_client.hpp" + #include "duckdb/common/exception.hpp" ++#include "duckdb/common/helper.hpp" ++#include "duckdb/common/shared_ptr.hpp" + #include "duckdb/function/scalar/string_functions.hpp" + #include + #include +@@ -185,13 +187,13 @@ void AzureDfsStorageFileSystem::ReadRange(AzureFileHandle &handle, idx_t file_of + } + } + +-std::shared_ptr AzureDfsStorageFileSystem::CreateStorageContext(optional_ptr opener, +- const string &path, +- const AzureParsedUrl &parsed_url) { ++shared_ptr 
AzureDfsStorageFileSystem::CreateStorageContext(optional_ptr opener, ++ const string &path, ++ const AzureParsedUrl &parsed_url) { + auto azure_read_options = ParseAzureReadOptions(opener); + +- return std::make_shared(ConnectToDfsStorageAccount(opener, path, parsed_url), +- azure_read_options); ++ return make_shared_ptr(ConnectToDfsStorageAccount(opener, path, parsed_url), ++ azure_read_options); + } + + } // namespace duckdb +diff --git a/src/azure_filesystem.cpp b/src/azure_filesystem.cpp +index bbf5275..6175421 100644 +--- a/src/azure_filesystem.cpp ++++ b/src/azure_filesystem.cpp +@@ -1,5 +1,6 @@ + #include "azure_filesystem.hpp" + #include "duckdb/common/exception.hpp" ++#include "duckdb/common/shared_ptr.hpp" + #include "duckdb/common/types/value.hpp" + #include "duckdb/main/client_context.hpp" + #include +@@ -53,8 +54,8 @@ void AzureStorageFileSystem::LoadFileInfo(AzureFileHandle &handle) { + } + } + +-unique_ptr AzureStorageFileSystem::OpenFile(const string &path,FileOpenFlags flags, +- optional_ptr opener) { ++unique_ptr AzureStorageFileSystem::OpenFile(const string &path, FileOpenFlags flags, ++ optional_ptr opener) { + D_ASSERT(flags.Compression() == FileCompressionType::UNCOMPRESSED); + + if (flags.OpenForWriting()) { +@@ -153,16 +154,16 @@ int64_t AzureStorageFileSystem::Read(FileHandle &handle, void *buffer, int64_t n + return nr_bytes; + } + +-std::shared_ptr AzureStorageFileSystem::GetOrCreateStorageContext(optional_ptr opener, +- const string &path, +- const AzureParsedUrl &parsed_url) { ++shared_ptr AzureStorageFileSystem::GetOrCreateStorageContext(optional_ptr opener, ++ const string &path, ++ const AzureParsedUrl &parsed_url) { + Value value; + bool azure_context_caching = true; + if (FileOpener::TryGetCurrentSetting(opener, "azure_context_caching", value)) { + azure_context_caching = value.GetValue(); + } + +- std::shared_ptr result; ++ shared_ptr result; + if (azure_context_caching) { + auto client_context = FileOpener::TryGetClientContext(opener); + +@@ -182,7 +183,7 @@ std::shared_ptr AzureStorageFileSystem::GetOrCreateStorageCon + result = CreateStorageContext(opener, path, parsed_url); + registered_state[context_key] = result; + } else { +- result = std::shared_ptr(storage_account_it->second, azure_context_state); ++ result = shared_ptr(storage_account_it->second, azure_context_state); + } + } + } else { +diff --git a/src/azure_storage_account_client.cpp b/src/azure_storage_account_client.cpp +index 5a22e60..11ad859 100644 +--- a/src/azure_storage_account_client.cpp ++++ b/src/azure_storage_account_client.cpp +@@ -3,6 +3,8 @@ + #include "duckdb/catalog/catalog_transaction.hpp" + #include "duckdb/common/enums/statement_type.hpp" + #include "duckdb/common/exception.hpp" ++#include "duckdb/common/shared_ptr.hpp" ++#include "duckdb/common/helper.hpp" + #include "duckdb/common/file_opener.hpp" + #include "duckdb/common/string_util.hpp" + #include "duckdb/main/client_context.hpp" +@@ -75,12 +77,12 @@ static std::string AccountUrl(const AzureParsedUrl &azure_parsed_url) { + + template + static T ToClientOptions(const Azure::Core::Http::Policies::TransportOptions &transport_options, +- std::shared_ptr http_state) { ++ shared_ptr http_state) { + static_assert(std::is_base_of::value, + "type parameter must be an Azure ClientOptions"); + T options; + options.Transport = transport_options; +- if (nullptr != http_state) { ++ if (http_state != nullptr) { + // Because we mainly want to have stats on what has been needed and not on + // what has been used on the network, we 
register the policy on `PerOperationPolicies` + // part and not the `PerRetryPolicies`. Network issues will result in retry that can +@@ -92,13 +94,13 @@ static T ToClientOptions(const Azure::Core::Http::Policies::TransportOptions &tr + + static Azure::Storage::Blobs::BlobClientOptions + ToBlobClientOptions(const Azure::Core::Http::Policies::TransportOptions &transport_options, +- std::shared_ptr http_state) { ++ shared_ptr http_state) { + return ToClientOptions(transport_options, std::move(http_state)); + } + + static Azure::Storage::Files::DataLake::DataLakeClientOptions + ToDfsClientOptions(const Azure::Core::Http::Policies::TransportOptions &transport_options, +- std::shared_ptr http_state) { ++ shared_ptr http_state) { + return ToClientOptions(transport_options, + std::move(http_state)); + } +@@ -110,14 +112,14 @@ ToTokenCredentialOptions(const Azure::Core::Http::Policies::TransportOptions &tr + return options; + } + +-static std::shared_ptr GetHttpState(optional_ptr opener) { ++static shared_ptr GetHttpState(optional_ptr opener) { + Value value; + bool enable_http_stats = false; + if (FileOpener::TryGetCurrentSetting(opener, "azure_http_stats", value)) { + enable_http_stats = value.GetValue(); + } + +- std::shared_ptr http_state; ++ shared_ptr http_state; + if (enable_http_stats) { + http_state = HTTPState::TryGetState(opener); + } +@@ -168,7 +170,7 @@ CreateClientCredential(const std::string &tenant_id, const std::string &client_i + auto credential_options = ToTokenCredentialOptions(transport_options); + if (!client_secret.empty()) { + return std::make_shared(tenant_id, client_id, client_secret, +- credential_options); ++ credential_options); + } else if (!client_certificate_path.empty()) { + return std::make_shared( + tenant_id, client_id, client_certificate_path, credential_options); +@@ -408,8 +410,9 @@ GetDfsStorageAccountClientFromServicePrincipalProvider(optional_ptr + return Azure::Storage::Files::DataLake::DataLakeServiceClient(account_url, token_credential, dfs_options); + } + +-static Azure::Storage::Blobs::BlobServiceClient +-GetBlobStorageAccountClient(optional_ptr opener, const KeyValueSecret &secret, const AzureParsedUrl &azure_parsed_url) { ++static Azure::Storage::Blobs::BlobServiceClient GetBlobStorageAccountClient(optional_ptr opener, ++ const KeyValueSecret &secret, ++ const AzureParsedUrl &azure_parsed_url) { + auto &provider = secret.GetProvider(); + // default provider + if (provider == "config") { +@@ -424,7 +427,8 @@ GetBlobStorageAccountClient(optional_ptr opener, const KeyValueSecre + } + + static Azure::Storage::Files::DataLake::DataLakeServiceClient +-GetDfsStorageAccountClient(optional_ptr opener, const KeyValueSecret &secret, const AzureParsedUrl &azure_parsed_url) { ++GetDfsStorageAccountClient(optional_ptr opener, const KeyValueSecret &secret, ++ const AzureParsedUrl &azure_parsed_url) { + auto &provider = secret.GetProvider(); + // default provider + if (provider == "config") { +@@ -505,7 +509,8 @@ const SecretMatch LookupSecret(optional_ptr opener, const std::strin + return {}; + } + +-Azure::Storage::Blobs::BlobServiceClient ConnectToBlobStorageAccount(optional_ptr opener, const std::string &path, ++Azure::Storage::Blobs::BlobServiceClient ConnectToBlobStorageAccount(optional_ptr opener, ++ const std::string &path, + const AzureParsedUrl &azure_parsed_url) { + + auto secret_match = LookupSecret(opener, path); +@@ -519,7 +524,8 @@ Azure::Storage::Blobs::BlobServiceClient ConnectToBlobStorageAccount(optional_pt + } + + 
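// NOTE on the aliasing construction seen in azure_filesystem.cpp above,
// shared_ptr(storage_account_it->second, azure_context_state): the aliasing
// constructor shares ownership with its first argument while pointing at the
// object named by the second. A self-contained standard-library analogy
// (illustrative names, not code from this extension):
#include <memory>
struct ContextState {
	int storage_context = 0;
};
int main() {
	auto state = std::make_shared<ContextState>();
	// shares state's control block but points at the member; the whole
	// ContextState stays alive for as long as `view` does
	std::shared_ptr<int> view(state, &state->storage_context);
	return *view;
}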
Azure::Storage::Files::DataLake::DataLakeServiceClient +-ConnectToDfsStorageAccount(optional_ptr opener, const std::string &path, const AzureParsedUrl &azure_parsed_url) { ++ConnectToDfsStorageAccount(optional_ptr opener, const std::string &path, ++ const AzureParsedUrl &azure_parsed_url) { + auto secret_match = LookupSecret(opener, path); + if (secret_match.HasMatch()) { + const auto &base_secret = secret_match.GetSecret(); +diff --git a/src/http_state_policy.cpp b/src/http_state_policy.cpp +index be2d61d..baa3f97 100644 +--- a/src/http_state_policy.cpp ++++ b/src/http_state_policy.cpp +@@ -1,5 +1,6 @@ + #include "http_state_policy.hpp" + #include ++#include "duckdb/common/shared_ptr.hpp" + #include + #include + #include +@@ -8,7 +9,7 @@ const static std::string CONTENT_LENGTH = "content-length"; + + namespace duckdb { + +-HttpStatePolicy::HttpStatePolicy(std::shared_ptr http_state) : http_state(std::move(http_state)) { ++HttpStatePolicy::HttpStatePolicy(shared_ptr http_state) : http_state(std::move(http_state)) { + } + + std::unique_ptr +@@ -34,12 +35,12 @@ HttpStatePolicy::Send(Azure::Core::Http::Request &request, Azure::Core::Http::Po + } + + const auto *body_stream = request.GetBodyStream(); +- if (nullptr != body_stream) { ++ if (body_stream != nullptr) { + http_state->total_bytes_sent += body_stream->Length(); + } + + auto result = next_policy.Send(request, context); +- if (nullptr != result) { ++ if (result != nullptr) { + const auto &response_body = result->GetBody(); + if (response_body.size() != 0) { + http_state->total_bytes_received += response_body.size(); +diff --git a/src/include/azure_blob_filesystem.hpp b/src/include/azure_blob_filesystem.hpp +index 4d10ebe..638864a 100644 +--- a/src/include/azure_blob_filesystem.hpp ++++ b/src/include/azure_blob_filesystem.hpp +@@ -1,6 +1,8 @@ + #pragma once + + #include "duckdb.hpp" ++#include "duckdb/common/shared_ptr.hpp" ++#include "duckdb/common/unique_ptr.hpp" + #include "azure_parsed_url.hpp" + #include "azure_filesystem.hpp" + #include +@@ -57,10 +59,10 @@ protected: + const string &GetContextPrefix() const override { + return PATH_PREFIX; + } +- std::shared_ptr CreateStorageContext(optional_ptr opener, const string &path, +- const AzureParsedUrl &parsed_url) override; +- duckdb::unique_ptr CreateHandle(const string &path, FileOpenFlags flags, +- optional_ptr opener) override; ++ shared_ptr CreateStorageContext(optional_ptr opener, const string &path, ++ const AzureParsedUrl &parsed_url) override; ++ unique_ptr CreateHandle(const string &path, FileOpenFlags flags, ++ optional_ptr opener) override; + + void ReadRange(AzureFileHandle &handle, idx_t file_offset, char *buffer_out, idx_t buffer_out_len) override; + }; +diff --git a/src/include/azure_dfs_filesystem.hpp b/src/include/azure_dfs_filesystem.hpp +index cdcdb23..8f35dc7 100644 +--- a/src/include/azure_dfs_filesystem.hpp ++++ b/src/include/azure_dfs_filesystem.hpp +@@ -2,6 +2,8 @@ + + #include "azure_filesystem.hpp" + #include "duckdb/common/file_opener.hpp" ++#include "duckdb/common/shared_ptr.hpp" ++#include "duckdb/common/unique_ptr.hpp" + #include + #include + #include +@@ -55,10 +57,10 @@ protected: + const string &GetContextPrefix() const override { + return PATH_PREFIX; + } +- std::shared_ptr CreateStorageContext(optional_ptr opener, const string &path, +- const AzureParsedUrl &parsed_url) override; +- duckdb::unique_ptr CreateHandle(const string &path, FileOpenFlags flags, +- optional_ptr opener) override; ++ shared_ptr CreateStorageContext(optional_ptr opener, const 
string &path, ++ const AzureParsedUrl &parsed_url) override; ++ unique_ptr CreateHandle(const string &path, FileOpenFlags flags, ++ optional_ptr opener) override; + + void ReadRange(AzureFileHandle &handle, idx_t file_offset, char *buffer_out, idx_t buffer_out_len) override; + }; +diff --git a/src/include/azure_filesystem.hpp b/src/include/azure_filesystem.hpp +index 338c744..02ed44b 100644 +--- a/src/include/azure_filesystem.hpp ++++ b/src/include/azure_filesystem.hpp +@@ -3,6 +3,7 @@ + #include "azure_parsed_url.hpp" + #include "duckdb/common/assert.hpp" + #include "duckdb/common/file_opener.hpp" ++#include "duckdb/common/shared_ptr.hpp" + #include "duckdb/common/file_system.hpp" + #include "duckdb/main/client_context_state.hpp" + #include +@@ -99,14 +100,14 @@ public: + + protected: + virtual duckdb::unique_ptr CreateHandle(const string &path, FileOpenFlags flags, +- optional_ptr opener) = 0; ++ optional_ptr opener) = 0; + virtual void ReadRange(AzureFileHandle &handle, idx_t file_offset, char *buffer_out, idx_t buffer_out_len) = 0; + + virtual const string &GetContextPrefix() const = 0; +- std::shared_ptr GetOrCreateStorageContext(optional_ptr opener, const string &path, +- const AzureParsedUrl &parsed_url); +- virtual std::shared_ptr CreateStorageContext(optional_ptr opener, const string &path, +- const AzureParsedUrl &parsed_url) = 0; ++ shared_ptr GetOrCreateStorageContext(optional_ptr opener, const string &path, ++ const AzureParsedUrl &parsed_url); ++ virtual shared_ptr CreateStorageContext(optional_ptr opener, const string &path, ++ const AzureParsedUrl &parsed_url) = 0; + + virtual void LoadRemoteFileInfo(AzureFileHandle &handle) = 0; + static AzureReadOptions ParseAzureReadOptions(optional_ptr opener); +diff --git a/src/include/azure_storage_account_client.hpp b/src/include/azure_storage_account_client.hpp +index 600fa10..aa9a6e5 100644 +--- a/src/include/azure_storage_account_client.hpp ++++ b/src/include/azure_storage_account_client.hpp +@@ -8,10 +8,12 @@ + + namespace duckdb { + +-Azure::Storage::Blobs::BlobServiceClient ConnectToBlobStorageAccount(optional_ptr opener, const std::string &path, ++Azure::Storage::Blobs::BlobServiceClient ConnectToBlobStorageAccount(optional_ptr opener, ++ const std::string &path, + const AzureParsedUrl &azure_parsed_url); + + Azure::Storage::Files::DataLake::DataLakeServiceClient +-ConnectToDfsStorageAccount(optional_ptr opener, const std::string &path, const AzureParsedUrl &azure_parsed_url); ++ConnectToDfsStorageAccount(optional_ptr opener, const std::string &path, ++ const AzureParsedUrl &azure_parsed_url); + + } // namespace duckdb +diff --git a/src/include/http_state_policy.hpp b/src/include/http_state_policy.hpp +index 310b9c3..9db73b6 100644 +--- a/src/include/http_state_policy.hpp ++++ b/src/include/http_state_policy.hpp +@@ -1,6 +1,7 @@ + #pragma once + + #include "duckdb/common/http_state.hpp" ++#include "duckdb/common/shared_ptr.hpp" + #include + #include + #include +@@ -11,7 +12,7 @@ namespace duckdb { + + class HttpStatePolicy : public Azure::Core::Http::Policies::HttpPolicy { + public: +- HttpStatePolicy(std::shared_ptr http_state); ++ HttpStatePolicy(shared_ptr http_state); + + std::unique_ptr Send(Azure::Core::Http::Request &request, + Azure::Core::Http::Policies::NextHttpPolicy next_policy, +@@ -20,7 +21,7 @@ public: + std::unique_ptr Clone() const override; + + private: +- std::shared_ptr http_state; ++ shared_ptr http_state; + }; + + } // namespace duckdb From 5e204ab3761d67e7bb15c3624a426176520f1358 Mon Sep 17 00:00:00 2001 
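The aws patch below goes the other way from the renames so far: it pins every make_shared call to std::make_shared. A plausible reason, stated here as an assumption since the commit message does not spell it out, is that once DuckDB stops re-exporting the standard factory inside its own namespace, unqualified calls in extension code no longer resolve. A minimal reproduction sketch:

#include <memory>
namespace duckdb {
// previously something like `using std::make_shared;` was visible here
void Example() {
	// auto p = make_shared<int>(42);    // stops compiling once the using-declaration is gone
	auto p = std::make_shared<int>(42);  // explicit qualification, as the patch applies
	(void)p;
}
} // namespace duckdb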
From: Tishj Date: Fri, 12 Apr 2024 11:19:22 +0200 Subject: [PATCH 104/201] add patch for aws --- .github/config/out_of_tree_extensions.cmake | 1 + .../patches/extensions/aws/shared_ptr.patch | 37 +++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 .github/patches/extensions/aws/shared_ptr.patch diff --git a/.github/config/out_of_tree_extensions.cmake b/.github/config/out_of_tree_extensions.cmake index 0f919af9b1e8..d17b2e7ce508 100644 --- a/.github/config/out_of_tree_extensions.cmake +++ b/.github/config/out_of_tree_extensions.cmake @@ -28,6 +28,7 @@ if (NOT MINGW) LOAD_TESTS GIT_URL https://github.com/duckdb/duckdb_aws GIT_TAG f7b8729f1cce5ada5d4add70e1486de50763fb97 + APPLY_PATCHES ) endif() diff --git a/.github/patches/extensions/aws/shared_ptr.patch b/.github/patches/extensions/aws/shared_ptr.patch new file mode 100644 index 000000000000..7cd90185e268 --- /dev/null +++ b/.github/patches/extensions/aws/shared_ptr.patch @@ -0,0 +1,37 @@ +diff --git a/src/aws_secret.cpp b/src/aws_secret.cpp +index 75062b9..179cc29 100644 +--- a/src/aws_secret.cpp ++++ b/src/aws_secret.cpp +@@ -40,24 +40,24 @@ public: + + for (const auto &item : chain_list) { + if (item == "sts") { +- AddProvider(make_shared()); ++ AddProvider(std::make_shared()); + } else if (item == "sso") { + if (profile.empty()) { +- AddProvider(make_shared()); ++ AddProvider(std::make_shared()); + } else { +- AddProvider(make_shared(profile)); ++ AddProvider(std::make_shared(profile)); + } + } else if (item == "env") { +- AddProvider(make_shared()); ++ AddProvider(std::make_shared()); + } else if (item == "instance") { +- AddProvider(make_shared()); ++ AddProvider(std::make_shared()); + } else if (item == "process") { +- AddProvider(make_shared()); ++ AddProvider(std::make_shared()); + } else if (item == "config") { + if (profile.empty()) { +- AddProvider(make_shared()); ++ AddProvider(std::make_shared()); + } else { +- AddProvider(make_shared(profile.c_str())); ++ AddProvider(std::make_shared(profile.c_str())); + } + } else { + throw InvalidInputException("Unknown provider found while parsing AWS credential chain string: '%s'", From 51f89479c773642c87b859553307571347de9b15 Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 12 Apr 2024 12:23:40 +0200 Subject: [PATCH 105/201] add patch for postgres_scanner --- .github/config/out_of_tree_extensions.cmake | 1 + .../postgres_scanner/shared_ptr.patch | 214 ++++++++++++++++++ 2 files changed, 215 insertions(+) create mode 100644 .github/patches/extensions/postgres_scanner/shared_ptr.patch diff --git a/.github/config/out_of_tree_extensions.cmake b/.github/config/out_of_tree_extensions.cmake index d17b2e7ce508..a96f35e1d7be 100644 --- a/.github/config/out_of_tree_extensions.cmake +++ b/.github/config/out_of_tree_extensions.cmake @@ -67,6 +67,7 @@ if (NOT MINGW) DONT_LINK GIT_URL https://github.com/duckdb/postgres_scanner GIT_TAG 96206f41d5ca7015920a66b54e936c986fe0b0f8 + APPLY_PATCHES ) endif() diff --git a/.github/patches/extensions/postgres_scanner/shared_ptr.patch b/.github/patches/extensions/postgres_scanner/shared_ptr.patch new file mode 100644 index 000000000000..98d351393f23 --- /dev/null +++ b/.github/patches/extensions/postgres_scanner/shared_ptr.patch @@ -0,0 +1,214 @@ +diff --git a/src/include/postgres_binary_copy.hpp b/src/include/postgres_binary_copy.hpp +index 4d49b00..f69e4fa 100644 +--- a/src/include/postgres_binary_copy.hpp ++++ b/src/include/postgres_binary_copy.hpp +@@ -19,18 +19,19 @@ public: + PostgresBinaryCopyFunction(); + + static unique_ptr 
PostgresBinaryWriteBind(ClientContext &context, CopyFunctionBindInput &input, +- const vector &names, const vector &sql_types); +- +- static unique_ptr PostgresBinaryWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data, +- const string &file_path); +- static unique_ptr PostgresBinaryWriteInitializeLocal(ExecutionContext &context, FunctionData &bind_data_p); +- static void PostgresBinaryWriteSink(ExecutionContext &context, FunctionData &bind_data_p, GlobalFunctionData &gstate, +- LocalFunctionData &lstate, DataChunk &input); +- static void PostgresBinaryWriteCombine(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, +- LocalFunctionData &lstate); +- static void PostgresBinaryWriteFinalize(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate); ++ const vector &names, ++ const vector &sql_types); ++ ++ static unique_ptr ++ PostgresBinaryWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data, const string &file_path); ++ static unique_ptr PostgresBinaryWriteInitializeLocal(ExecutionContext &context, ++ FunctionData &bind_data_p); ++ static void PostgresBinaryWriteSink(ExecutionContext &context, FunctionData &bind_data_p, ++ GlobalFunctionData &gstate, LocalFunctionData &lstate, DataChunk &input); ++ static void PostgresBinaryWriteCombine(ExecutionContext &context, FunctionData &bind_data, ++ GlobalFunctionData &gstate, LocalFunctionData &lstate); ++ static void PostgresBinaryWriteFinalize(ClientContext &context, FunctionData &bind_data, ++ GlobalFunctionData &gstate); + }; + +- +- + } // namespace duckdb +diff --git a/src/include/postgres_connection.hpp b/src/include/postgres_connection.hpp +index e4595e2..08fd19c 100644 +--- a/src/include/postgres_connection.hpp ++++ b/src/include/postgres_connection.hpp +@@ -10,6 +10,7 @@ + + #include "postgres_utils.hpp" + #include "postgres_result.hpp" ++#include "duckdb/common/shared_ptr.hpp" + + namespace duckdb { + class PostgresBinaryWriter; +diff --git a/src/include/storage/postgres_catalog_set.hpp b/src/include/storage/postgres_catalog_set.hpp +index e20a803..4fe45f6 100644 +--- a/src/include/storage/postgres_catalog_set.hpp ++++ b/src/include/storage/postgres_catalog_set.hpp +@@ -11,6 +11,7 @@ + #include "duckdb/transaction/transaction.hpp" + #include "duckdb/common/case_insensitive_map.hpp" + #include "duckdb/common/mutex.hpp" ++#include "duckdb/common/shared_ptr.hpp" + + namespace duckdb { + struct DropInfo; +diff --git a/src/postgres_binary_copy.cpp b/src/postgres_binary_copy.cpp +index f0d86a3..4c89c3f 100644 +--- a/src/postgres_binary_copy.cpp ++++ b/src/postgres_binary_copy.cpp +@@ -5,8 +5,7 @@ + + namespace duckdb { + +-PostgresBinaryCopyFunction::PostgresBinaryCopyFunction() : +- CopyFunction("postgres_binary") { ++PostgresBinaryCopyFunction::PostgresBinaryCopyFunction() : CopyFunction("postgres_binary") { + + copy_to_bind = PostgresBinaryWriteBind; + copy_to_initialize_global = PostgresBinaryWriteInitializeGlobal; +@@ -54,16 +53,18 @@ struct PostgresBinaryCopyGlobalState : public GlobalFunctionData { + } + }; + +-struct PostgresBinaryWriteBindData : public TableFunctionData { +-}; ++struct PostgresBinaryWriteBindData : public TableFunctionData {}; + +-unique_ptr PostgresBinaryCopyFunction::PostgresBinaryWriteBind(ClientContext &context, CopyFunctionBindInput &input, +- const vector &names, const vector &sql_types) { ++unique_ptr PostgresBinaryCopyFunction::PostgresBinaryWriteBind(ClientContext &context, ++ CopyFunctionBindInput &input, ++ const vector 
&names, ++ const vector &sql_types) { + return make_uniq(); + } + +-unique_ptr PostgresBinaryCopyFunction::PostgresBinaryWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data, +- const string &file_path) { ++unique_ptr ++PostgresBinaryCopyFunction::PostgresBinaryWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data, ++ const string &file_path) { + auto result = make_uniq(); + auto &fs = FileSystem::GetFileSystem(context); + result->file_writer = make_uniq(fs, file_path); +@@ -72,25 +73,27 @@ unique_ptr PostgresBinaryCopyFunction::PostgresBinaryWriteIn + return std::move(result); + } + +-unique_ptr PostgresBinaryCopyFunction::PostgresBinaryWriteInitializeLocal(ExecutionContext &context, FunctionData &bind_data_p) { ++unique_ptr ++PostgresBinaryCopyFunction::PostgresBinaryWriteInitializeLocal(ExecutionContext &context, FunctionData &bind_data_p) { + return make_uniq(); + } + +-void PostgresBinaryCopyFunction::PostgresBinaryWriteSink(ExecutionContext &context, FunctionData &bind_data_p, GlobalFunctionData &gstate_p, +- LocalFunctionData &lstate, DataChunk &input) { ++void PostgresBinaryCopyFunction::PostgresBinaryWriteSink(ExecutionContext &context, FunctionData &bind_data_p, ++ GlobalFunctionData &gstate_p, LocalFunctionData &lstate, ++ DataChunk &input) { + auto &gstate = gstate_p.Cast(); + gstate.WriteChunk(input); + } + +-void PostgresBinaryCopyFunction::PostgresBinaryWriteCombine(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, +- LocalFunctionData &lstate) { ++void PostgresBinaryCopyFunction::PostgresBinaryWriteCombine(ExecutionContext &context, FunctionData &bind_data, ++ GlobalFunctionData &gstate, LocalFunctionData &lstate) { + } + +-void PostgresBinaryCopyFunction::PostgresBinaryWriteFinalize(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate_p) { ++void PostgresBinaryCopyFunction::PostgresBinaryWriteFinalize(ClientContext &context, FunctionData &bind_data, ++ GlobalFunctionData &gstate_p) { + auto &gstate = gstate_p.Cast(); + // write the footer and close the file + gstate.Flush(); + } + +- +-} +\ No newline at end of file ++} // namespace duckdb +\ No newline at end of file +diff --git a/src/postgres_connection.cpp b/src/postgres_connection.cpp +index 6372055..18fbd77 100644 +--- a/src/postgres_connection.cpp ++++ b/src/postgres_connection.cpp +@@ -3,6 +3,8 @@ + #include "duckdb/parser/parser.hpp" + #include "postgres_connection.hpp" + #include "duckdb/common/types/uuid.hpp" ++#include "duckdb/common/shared_ptr.hpp" ++#include "duckdb/common/helper.hpp" + + namespace duckdb { + +@@ -40,7 +42,7 @@ PostgresConnection &PostgresConnection::operator=(PostgresConnection &&other) no + + PostgresConnection PostgresConnection::Open(const string &connection_string) { + PostgresConnection result; +- result.connection = make_shared(PostgresUtils::PGConnect(connection_string)); ++ result.connection = make_shared_ptr(PostgresUtils::PGConnect(connection_string)); + result.dsn = connection_string; + return result; + } +diff --git a/src/postgres_extension.cpp b/src/postgres_extension.cpp +index 34d46d0..95988f2 100644 +--- a/src/postgres_extension.cpp ++++ b/src/postgres_extension.cpp +@@ -9,6 +9,7 @@ + #include "duckdb/catalog/catalog.hpp" + #include "duckdb/parser/parsed_data/create_table_function_info.hpp" + #include "duckdb/main/extension_util.hpp" ++#include "duckdb/common/helper.hpp" + #include "duckdb/main/database_manager.hpp" + #include "duckdb/main/attached_database.hpp" + #include 
"storage/postgres_catalog.hpp" +@@ -47,7 +48,7 @@ public: + class PostgresExtensionCallback : public ExtensionCallback { + public: + void OnConnectionOpened(ClientContext &context) override { +- context.registered_state.insert(make_pair("postgres_extension", make_shared())); ++ context.registered_state.insert(make_pair("postgres_extension", make_shared_ptr())); + } + }; + +@@ -123,7 +124,7 @@ static void LoadInternal(DatabaseInstance &db) { + + config.extension_callbacks.push_back(make_uniq()); + for (auto &connection : ConnectionManager::Get(db).GetConnectionList()) { +- connection->registered_state.insert(make_pair("postgres_extension", make_shared())); ++ connection->registered_state.insert(make_pair("postgres_extension", make_shared_ptr())); + } + } + +diff --git a/src/postgres_scanner.cpp b/src/postgres_scanner.cpp +index 449df0b..75c029f 100644 +--- a/src/postgres_scanner.cpp ++++ b/src/postgres_scanner.cpp +@@ -3,6 +3,8 @@ + #include + + #include "duckdb/main/extension_util.hpp" ++#include "duckdb/common/shared_ptr.hpp" ++#include "duckdb/common/helper.hpp" + #include "duckdb/parser/parsed_data/create_table_function_info.hpp" + #include "postgres_filter_pushdown.hpp" + #include "postgres_scanner.hpp" +diff --git a/src/storage/postgres_schema_set.cpp b/src/storage/postgres_schema_set.cpp +index 93c3f28..cd3b46f 100644 +--- a/src/storage/postgres_schema_set.cpp ++++ b/src/storage/postgres_schema_set.cpp +@@ -6,6 +6,7 @@ + #include "duckdb/parser/parsed_data/create_schema_info.hpp" + #include "storage/postgres_table_set.hpp" + #include "storage/postgres_catalog.hpp" ++#include "duckdb/common/shared_ptr.hpp" + + namespace duckdb { + From dd40c1c7392723a3f8b001bcea9aa193be8d99ed Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 12 Apr 2024 12:46:26 +0200 Subject: [PATCH 106/201] add patch for substrait --- .github/config/out_of_tree_extensions.cmake | 1 + .../extensions/substrait/shared_ptr.patch | 130 ++++++++++++++++++ 2 files changed, 131 insertions(+) create mode 100644 .github/patches/extensions/substrait/shared_ptr.patch diff --git a/.github/config/out_of_tree_extensions.cmake b/.github/config/out_of_tree_extensions.cmake index a96f35e1d7be..3c20819cfd5d 100644 --- a/.github/config/out_of_tree_extensions.cmake +++ b/.github/config/out_of_tree_extensions.cmake @@ -102,5 +102,6 @@ if (NOT WIN32) LOAD_TESTS DONT_LINK GIT_URL https://github.com/duckdb/substrait GIT_TAG 1116fb580edd3e26e675436dbdbdf4a0aa5e456e + APPLY_PATCHES ) endif() diff --git a/.github/patches/extensions/substrait/shared_ptr.patch b/.github/patches/extensions/substrait/shared_ptr.patch new file mode 100644 index 000000000000..d04cfb9655af --- /dev/null +++ b/.github/patches/extensions/substrait/shared_ptr.patch @@ -0,0 +1,130 @@ +diff --git a/src/from_substrait.cpp b/src/from_substrait.cpp +index 566e21d..afbbb0b 100644 +--- a/src/from_substrait.cpp ++++ b/src/from_substrait.cpp +@@ -14,6 +14,8 @@ + #include "duckdb/main/connection.hpp" + #include "duckdb/parser/parser.hpp" + #include "duckdb/common/exception.hpp" ++#include "duckdb/common/helper.hpp" ++#include "duckdb/common/shared_ptr.hpp" + #include "duckdb/common/types.hpp" + #include "duckdb/common/enums/set_operation_type.hpp" + +@@ -404,25 +406,25 @@ shared_ptr SubstraitToDuckDB::TransformJoinOp(const substrait::Rel &so + throw InternalException("Unsupported join type"); + } + unique_ptr join_condition = TransformExpr(sjoin.expression()); +- return make_shared(TransformOp(sjoin.left())->Alias("left"), ++ return 
make_shared_ptr(TransformOp(sjoin.left())->Alias("left"), + TransformOp(sjoin.right())->Alias("right"), std::move(join_condition), djointype); + } + + shared_ptr SubstraitToDuckDB::TransformCrossProductOp(const substrait::Rel &sop) { + auto &sub_cross = sop.cross(); + +- return make_shared(TransformOp(sub_cross.left())->Alias("left"), ++ return make_shared_ptr(TransformOp(sub_cross.left())->Alias("left"), + TransformOp(sub_cross.right())->Alias("right")); + } + + shared_ptr SubstraitToDuckDB::TransformFetchOp(const substrait::Rel &sop) { + auto &slimit = sop.fetch(); +- return make_shared(TransformOp(slimit.input()), slimit.count(), slimit.offset()); ++ return make_shared_ptr(TransformOp(slimit.input()), slimit.count(), slimit.offset()); + } + + shared_ptr SubstraitToDuckDB::TransformFilterOp(const substrait::Rel &sop) { + auto &sfilter = sop.filter(); +- return make_shared(TransformOp(sfilter.input()), TransformExpr(sfilter.condition())); ++ return make_shared_ptr(TransformOp(sfilter.input()), TransformExpr(sfilter.condition())); + } + + shared_ptr SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop) { +@@ -435,7 +437,7 @@ shared_ptr SubstraitToDuckDB::TransformProjectOp(const substrait::Rel + for (size_t i = 0; i < expressions.size(); i++) { + mock_aliases.push_back("expr_" + to_string(i)); + } +- return make_shared(TransformOp(sop.project().input()), std::move(expressions), ++ return make_shared_ptr(TransformOp(sop.project().input()), std::move(expressions), + std::move(mock_aliases)); + } + +@@ -463,7 +465,7 @@ shared_ptr SubstraitToDuckDB::TransformAggregateOp(const substrait::Re + expressions.push_back(make_uniq(RemapFunctionName(function_name), std::move(children))); + } + +- return make_shared(TransformOp(sop.aggregate().input()), std::move(expressions), ++ return make_shared_ptr(TransformOp(sop.aggregate().input()), std::move(expressions), + std::move(groups)); + } + +@@ -502,7 +504,7 @@ shared_ptr SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so + } + + if (sget.has_filter()) { +- scan = make_shared(std::move(scan), TransformExpr(sget.filter())); ++ scan = make_shared_ptr(std::move(scan), TransformExpr(sget.filter())); + } + + if (sget.has_projection()) { +@@ -516,7 +518,7 @@ shared_ptr SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so + expressions.push_back(make_uniq(sproj.field() + 1)); + } + +- scan = make_shared(std::move(scan), std::move(expressions), std::move(aliases)); ++ scan = make_shared_ptr(std::move(scan), std::move(expressions), std::move(aliases)); + } + + return scan; +@@ -527,7 +529,7 @@ shared_ptr SubstraitToDuckDB::TransformSortOp(const substrait::Rel &so + for (auto &sordf : sop.sort().sorts()) { + order_nodes.push_back(TransformOrder(sordf)); + } +- return make_shared(TransformOp(sop.sort().input()), std::move(order_nodes)); ++ return make_shared_ptr(TransformOp(sop.sort().input()), std::move(order_nodes)); + } + + static duckdb::SetOperationType TransformSetOperationType(substrait::SetRel_SetOp setop) { +@@ -562,7 +564,7 @@ shared_ptr SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop + auto lhs = TransformOp(inputs[0]); + auto rhs = TransformOp(inputs[1]); + +- return make_shared(std::move(lhs), std::move(rhs), type); ++ return make_shared_ptr(std::move(lhs), std::move(rhs), type); + } + + shared_ptr SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) { +@@ -599,7 +601,7 @@ shared_ptr SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot + aliases.push_back(column_name); + 
expressions.push_back(make_uniq(id++)); + } +- return make_shared(TransformOp(sop.input()), std::move(expressions), aliases); ++ return make_shared_ptr(TransformOp(sop.input()), std::move(expressions), aliases); + } + + shared_ptr SubstraitToDuckDB::TransformPlan() { +diff --git a/src/include/from_substrait.hpp b/src/include/from_substrait.hpp +index 8ea96cd..3a632ce 100644 +--- a/src/include/from_substrait.hpp ++++ b/src/include/from_substrait.hpp +@@ -5,6 +5,7 @@ + #include + #include "substrait/plan.pb.h" + #include "duckdb/main/connection.hpp" ++#include "duckdb/common/shared_ptr.hpp" + + namespace duckdb { + class SubstraitToDuckDB { +diff --git a/src/substrait_extension.cpp b/src/substrait_extension.cpp +index fae645c..6422ebd 100644 +--- a/src/substrait_extension.cpp ++++ b/src/substrait_extension.cpp +@@ -6,6 +6,7 @@ + + #ifndef DUCKDB_AMALGAMATION + #include "duckdb/common/enums/optimizer_type.hpp" ++#include "duckdb/common/shared_ptr.hpp" + #include "duckdb/function/table_function.hpp" + #include "duckdb/parser/parsed_data/create_table_function_info.hpp" + #include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" From 8247c57d95b736b0b258b8677cd753b48d0593c0 Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 12 Apr 2024 12:59:49 +0200 Subject: [PATCH 107/201] add patch for arrow --- .github/config/out_of_tree_extensions.cmake | 1 + .../patches/extensions/arrow/shared_ptr.patch | 1134 +++++++++++++++++ 2 files changed, 1135 insertions(+) create mode 100644 .github/patches/extensions/arrow/shared_ptr.patch diff --git a/.github/config/out_of_tree_extensions.cmake b/.github/config/out_of_tree_extensions.cmake index 3c20819cfd5d..e1bcd5e2aec9 100644 --- a/.github/config/out_of_tree_extensions.cmake +++ b/.github/config/out_of_tree_extensions.cmake @@ -20,6 +20,7 @@ duckdb_extension_load(arrow LOAD_TESTS DONT_LINK GIT_URL https://github.com/duckdb/arrow GIT_TAG 9e10240da11f61ea7fbfe3fc9988ffe672ccd40f + APPLY_PATCHES ) ################## AWS diff --git a/.github/patches/extensions/arrow/shared_ptr.patch b/.github/patches/extensions/arrow/shared_ptr.patch new file mode 100644 index 000000000000..d523ffe01119 --- /dev/null +++ b/.github/patches/extensions/arrow/shared_ptr.patch @@ -0,0 +1,1134 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 72b1370..c95486c 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -7,56 +7,55 @@ set(EXTENSION_NAME ${TARGET_NAME}_extension) + project(${TARGET_NAME}) + include_directories(src/include) + +-set(EXTENSION_SOURCES +- src/arrow_extension.cpp +- src/arrow_stream_buffer.cpp +- src/arrow_scan_ipc.cpp +- src/arrow_to_ipc.cpp) ++set(EXTENSION_SOURCES src/arrow_extension.cpp src/arrow_stream_buffer.cpp ++ src/arrow_scan_ipc.cpp src/arrow_to_ipc.cpp) + + if(NOT "${OSX_BUILD_ARCH}" STREQUAL "") +- set(OSX_ARCH_FLAG -DCMAKE_OSX_ARCHITECTURES=${OSX_BUILD_ARCH}) ++ set(OSX_ARCH_FLAG -DCMAKE_OSX_ARCHITECTURES=${OSX_BUILD_ARCH}) + else() +- set(OSX_ARCH_FLAG "") ++ set(OSX_ARCH_FLAG "") + endif() + + # Building Arrow + include(ExternalProject) + ExternalProject_Add( +- ARROW_EP +- GIT_REPOSITORY "https://github.com/apache/arrow" +- GIT_TAG ea6875fd2a3ac66547a9a33c5506da94f3ff07f2 +- PREFIX "${CMAKE_BINARY_DIR}/third_party/arrow" +- INSTALL_DIR "${CMAKE_BINARY_DIR}/third_party/arrow/install" +- BUILD_BYPRODUCTS /lib/libarrow.a +- CONFIGURE_COMMAND +- ${CMAKE_COMMAND} -G${CMAKE_GENERATOR} ${OSX_ARCH_FLAG} +- -DCMAKE_BUILD_TYPE=Release +- -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/third_party/arrow/install +- -DCMAKE_INSTALL_LIBDIR=lib 
-DARROW_BUILD_STATIC=ON -DARROW_BUILD_SHARED=OFF +- -DARROW_NO_DEPRECATED_API=ON -DARROW_POSITION_INDEPENDENT_CODE=ON +- -DARROW_SIMD_LEVEL=NONE -DARROW_ENABLE_TIMING_TESTS=OFF -DARROW_IPC=ON +- -DARROW_JEMALLOC=OFF -DARROW_DEPENDENCY_SOURCE=BUNDLED +- -DARROW_VERBOSE_THIRDPARTY_BUILD=OFF -DARROW_DEPENDENCY_USE_SHARED=OFF +- -DARROW_BOOST_USE_SHARED=OFF -DARROW_BROTLI_USE_SHARED=OFF +- -DARROW_BZ2_USE_SHARED=OFF -DARROW_GFLAGS_USE_SHARED=OFF +- -DARROW_GRPC_USE_SHARED=OFF -DARROW_JEMALLOC_USE_SHARED=OFF +- -DARROW_LZ4_USE_SHARED=OFF -DARROW_OPENSSL_USE_SHARED=OFF +- -DARROW_PROTOBUF_USE_SHARED=OFF -DARROW_SNAPPY_USE_SHARED=OFF +- -DARROW_THRIFT_USE_SHARED=OFF -DARROW_UTF8PROC_USE_SHARED=OFF +- -DARROW_ZSTD_USE_SHARED=OFF -DARROW_USE_GLOG=OFF -DARROW_WITH_BACKTRACE=OFF +- -DARROW_WITH_OPENTELEMETRY=OFF -DARROW_WITH_BROTLI=OFF -DARROW_WITH_BZ2=OFF +- -DARROW_WITH_LZ4=OFF -DARROW_WITH_SNAPPY=OFF -DARROW_WITH_ZLIB=OFF +- -DARROW_WITH_ZSTD=OFF -DARROW_WITH_UCX=OFF -DARROW_WITH_UTF8PROC=OFF +- -DARROW_WITH_RE2=OFF /cpp +- CMAKE_ARGS -Wno-dev +- UPDATE_COMMAND "") ++ ARROW_EP ++ GIT_REPOSITORY "https://github.com/apache/arrow" ++ GIT_TAG ea6875fd2a3ac66547a9a33c5506da94f3ff07f2 ++ PREFIX "${CMAKE_BINARY_DIR}/third_party/arrow" ++ INSTALL_DIR "${CMAKE_BINARY_DIR}/third_party/arrow/install" ++ BUILD_BYPRODUCTS /lib/libarrow.a ++ CONFIGURE_COMMAND ++ ${CMAKE_COMMAND} -G${CMAKE_GENERATOR} ${OSX_ARCH_FLAG} ++ -DCMAKE_BUILD_TYPE=Release ++ -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/third_party/arrow/install ++ -DCMAKE_INSTALL_LIBDIR=lib -DARROW_BUILD_STATIC=ON -DARROW_BUILD_SHARED=OFF ++ -DARROW_NO_DEPRECATED_API=ON -DARROW_POSITION_INDEPENDENT_CODE=ON ++ -DARROW_SIMD_LEVEL=NONE -DARROW_ENABLE_TIMING_TESTS=OFF -DARROW_IPC=ON ++ -DARROW_JEMALLOC=OFF -DARROW_DEPENDENCY_SOURCE=BUNDLED ++ -DARROW_VERBOSE_THIRDPARTY_BUILD=OFF -DARROW_DEPENDENCY_USE_SHARED=OFF ++ -DARROW_BOOST_USE_SHARED=OFF -DARROW_BROTLI_USE_SHARED=OFF ++ -DARROW_BZ2_USE_SHARED=OFF -DARROW_GFLAGS_USE_SHARED=OFF ++ -DARROW_GRPC_USE_SHARED=OFF -DARROW_JEMALLOC_USE_SHARED=OFF ++ -DARROW_LZ4_USE_SHARED=OFF -DARROW_OPENSSL_USE_SHARED=OFF ++ -DARROW_PROTOBUF_USE_SHARED=OFF -DARROW_SNAPPY_USE_SHARED=OFF ++ -DARROW_THRIFT_USE_SHARED=OFF -DARROW_UTF8PROC_USE_SHARED=OFF ++ -DARROW_ZSTD_USE_SHARED=OFF -DARROW_USE_GLOG=OFF -DARROW_WITH_BACKTRACE=OFF ++ -DARROW_WITH_OPENTELEMETRY=OFF -DARROW_WITH_BROTLI=OFF -DARROW_WITH_BZ2=OFF ++ -DARROW_WITH_LZ4=OFF -DARROW_WITH_SNAPPY=OFF -DARROW_WITH_ZLIB=OFF ++ -DARROW_WITH_ZSTD=OFF -DARROW_WITH_UCX=OFF -DARROW_WITH_UTF8PROC=OFF ++ -DARROW_WITH_RE2=OFF /cpp ++ CMAKE_ARGS -Wno-dev ++ UPDATE_COMMAND "") + + ExternalProject_Get_Property(ARROW_EP install_dir) + add_library(arrow STATIC IMPORTED GLOBAL) + if(WIN32) +- set_target_properties(arrow PROPERTIES IMPORTED_LOCATION ${install_dir}/lib/arrow_static.lib) ++ set_target_properties(arrow PROPERTIES IMPORTED_LOCATION ++ ${install_dir}/lib/arrow_static.lib) + else() +- set_target_properties(arrow PROPERTIES IMPORTED_LOCATION ${install_dir}/lib/libarrow.a) ++ set_target_properties(arrow PROPERTIES IMPORTED_LOCATION ++ ${install_dir}/lib/libarrow.a) + endif() + + # create static library +@@ -71,12 +70,14 @@ build_loadable_extension(${TARGET_NAME} ${PARAMETERS} ${EXTENSION_SOURCES}) + add_dependencies(${TARGET_NAME}_loadable_extension ARROW_EP) + target_link_libraries(${TARGET_NAME}_loadable_extension arrow) + if(WIN32) +- target_compile_definitions(${TARGET_NAME}_loadable_extension PUBLIC ARROW_STATIC) ++ target_compile_definitions(${TARGET_NAME}_loadable_extension ++ 
PUBLIC ARROW_STATIC) + endif() +-target_include_directories(${TARGET_NAME}_loadable_extension PRIVATE ${install_dir}/include) ++target_include_directories(${TARGET_NAME}_loadable_extension ++ PRIVATE ${install_dir}/include) + + install( +- TARGETS ${EXTENSION_NAME} +- EXPORT "${DUCKDB_EXPORT_SET}" +- LIBRARY DESTINATION "${INSTALL_LIB_DIR}" +- ARCHIVE DESTINATION "${INSTALL_LIB_DIR}") +\ No newline at end of file ++ TARGETS ${EXTENSION_NAME} ++ EXPORT "${DUCKDB_EXPORT_SET}" ++ LIBRARY DESTINATION "${INSTALL_LIB_DIR}" ++ ARCHIVE DESTINATION "${INSTALL_LIB_DIR}") +diff --git a/src/arrow_extension.cpp b/src/arrow_extension.cpp +index e4daf26..6fadec0 100644 +--- a/src/arrow_extension.cpp ++++ b/src/arrow_extension.cpp +@@ -18,27 +18,24 @@ + namespace duckdb { + + static void LoadInternal(DatabaseInstance &instance) { +- ExtensionUtil::RegisterFunction(instance, ToArrowIPCFunction::GetFunction()); +- ExtensionUtil::RegisterFunction(instance, ArrowIPCTableFunction::GetFunction()); ++ ExtensionUtil::RegisterFunction(instance, ToArrowIPCFunction::GetFunction()); ++ ExtensionUtil::RegisterFunction(instance, ++ ArrowIPCTableFunction::GetFunction()); + } + +-void ArrowExtension::Load(DuckDB &db) { +- LoadInternal(*db.instance); +-} +-std::string ArrowExtension::Name() { +- return "arrow"; +-} ++void ArrowExtension::Load(DuckDB &db) { LoadInternal(*db.instance); } ++std::string ArrowExtension::Name() { return "arrow"; } + + } // namespace duckdb + + extern "C" { + + DUCKDB_EXTENSION_API void arrow_init(duckdb::DatabaseInstance &db) { +- LoadInternal(db); ++ LoadInternal(db); + } + + DUCKDB_EXTENSION_API const char *arrow_version() { +- return duckdb::DuckDB::LibraryVersion(); ++ return duckdb::DuckDB::LibraryVersion(); + } + } + +diff --git a/src/arrow_scan_ipc.cpp b/src/arrow_scan_ipc.cpp +index 7d5b2ff..a60d255 100644 +--- a/src/arrow_scan_ipc.cpp ++++ b/src/arrow_scan_ipc.cpp +@@ -3,111 +3,131 @@ + namespace duckdb { + + TableFunction ArrowIPCTableFunction::GetFunction() { +- child_list_t make_buffer_struct_children{{"ptr", LogicalType::UBIGINT}, +- {"size", LogicalType::UBIGINT}}; +- +- TableFunction scan_arrow_ipc_func( +- "scan_arrow_ipc", {LogicalType::LIST(LogicalType::STRUCT(make_buffer_struct_children))}, +- ArrowIPCTableFunction::ArrowScanFunction, ArrowIPCTableFunction::ArrowScanBind, +- ArrowTableFunction::ArrowScanInitGlobal, ArrowTableFunction::ArrowScanInitLocal); +- +- scan_arrow_ipc_func.cardinality = ArrowTableFunction::ArrowScanCardinality; +- scan_arrow_ipc_func.get_batch_index = nullptr; // TODO implement +- scan_arrow_ipc_func.projection_pushdown = true; +- scan_arrow_ipc_func.filter_pushdown = false; +- +- return scan_arrow_ipc_func; ++ child_list_t make_buffer_struct_children{ ++ {"ptr", LogicalType::UBIGINT}, {"size", LogicalType::UBIGINT}}; ++ ++ TableFunction scan_arrow_ipc_func( ++ "scan_arrow_ipc", ++ {LogicalType::LIST(LogicalType::STRUCT(make_buffer_struct_children))}, ++ ArrowIPCTableFunction::ArrowScanFunction, ++ ArrowIPCTableFunction::ArrowScanBind, ++ ArrowTableFunction::ArrowScanInitGlobal, ++ ArrowTableFunction::ArrowScanInitLocal); ++ ++ scan_arrow_ipc_func.cardinality = ArrowTableFunction::ArrowScanCardinality; ++ scan_arrow_ipc_func.get_batch_index = nullptr; // TODO implement ++ scan_arrow_ipc_func.projection_pushdown = true; ++ scan_arrow_ipc_func.filter_pushdown = false; ++ ++ return scan_arrow_ipc_func; + } + +-unique_ptr ArrowIPCTableFunction::ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, +- vector &return_types, vector &names) { 
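// A sketch of the bind flow that the reflowed hunk below preserves, with the
// decoder's template argument (omitted in the generic signatures above)
// filled in as an assumption from arrow_stream_buffer.cpp: each element of
// input.inputs[0] is a (ptr, size) struct locating one Arrow IPC buffer;
// every buffer is fed to a single stream decoder, the stream must have
// reached end-of-stream, and the decoded buffer is then handed to DuckDB's
// generic arrow scan via the CreateStream/GetSchema factory callbacks.
//
//	auto stream_decoder = make_uniq<BufferingArrowIPCStreamDecoder>();
//	for (auto &entry : ListValue::GetChildren(input.inputs[0])) {
//		auto unpacked = StructValue::GetChildren(entry);
//		uint64_t ptr = unpacked[0].GetValue<uint64_t>();
//		uint64_t size = unpacked[1].GetValue<uint64_t>();
//		if (!stream_decoder->Consume((const uint8_t *)ptr, size).ok()) {
//			throw IOException("Invalid IPC stream");
//		}
//	}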
+- auto stream_decoder = make_uniq(); ++unique_ptr ArrowIPCTableFunction::ArrowScanBind( ++ ClientContext &context, TableFunctionBindInput &input, ++ vector &return_types, vector &names) { ++ auto stream_decoder = make_uniq(); + +- // Decode buffer ptr list +- auto buffer_ptr_list = ListValue::GetChildren(input.inputs[0]); +- for (auto &buffer_ptr_struct: buffer_ptr_list) { +- auto unpacked = StructValue::GetChildren(buffer_ptr_struct); +- uint64_t ptr = unpacked[0].GetValue(); +- uint64_t size = unpacked[1].GetValue(); ++ // Decode buffer ptr list ++ auto buffer_ptr_list = ListValue::GetChildren(input.inputs[0]); ++ for (auto &buffer_ptr_struct : buffer_ptr_list) { ++ auto unpacked = StructValue::GetChildren(buffer_ptr_struct); ++ uint64_t ptr = unpacked[0].GetValue(); ++ uint64_t size = unpacked[1].GetValue(); + +- // Feed stream into decoder +- auto res = stream_decoder->Consume((const uint8_t *) ptr, size); ++ // Feed stream into decoder ++ auto res = stream_decoder->Consume((const uint8_t *)ptr, size); + +- if (!res.ok()) { +- throw IOException("Invalid IPC stream"); +- } ++ if (!res.ok()) { ++ throw IOException("Invalid IPC stream"); + } +- +- if (!stream_decoder->buffer()->is_eos()) { +- throw IOException("IPC buffers passed to arrow scan should contain entire stream"); ++ } ++ ++ if (!stream_decoder->buffer()->is_eos()) { ++ throw IOException( ++ "IPC buffers passed to arrow scan should contain entire stream"); ++ } ++ ++ // These are the params I need to produce from the ipc buffers using the ++ // WebDB.cc code ++ auto stream_factory_ptr = (uintptr_t)&stream_decoder->buffer(); ++ auto stream_factory_produce = ++ (stream_factory_produce_t)&ArrowIPCStreamBufferReader::CreateStream; ++ auto stream_factory_get_schema = ++ (stream_factory_get_schema_t)&ArrowIPCStreamBufferReader::GetSchema; ++ auto res = make_uniq(stream_factory_produce, ++ stream_factory_ptr); ++ ++ // Store decoder ++ res->stream_decoder = std::move(stream_decoder); ++ ++ // TODO Everything below this is identical to the bind in ++ // duckdb/src/function/table/arrow.cpp ++ auto &data = *res; ++ stream_factory_get_schema((ArrowArrayStream *)stream_factory_ptr, ++ data.schema_root.arrow_schema); ++ for (idx_t col_idx = 0; ++ col_idx < (idx_t)data.schema_root.arrow_schema.n_children; col_idx++) { ++ auto &schema = *data.schema_root.arrow_schema.children[col_idx]; ++ if (!schema.release) { ++ throw InvalidInputException("arrow_scan: released schema passed"); + } +- +- // These are the params I need to produce from the ipc buffers using the WebDB.cc code +- auto stream_factory_ptr = (uintptr_t) & stream_decoder->buffer(); +- auto stream_factory_produce = (stream_factory_produce_t) & ArrowIPCStreamBufferReader::CreateStream; +- auto stream_factory_get_schema = (stream_factory_get_schema_t) & ArrowIPCStreamBufferReader::GetSchema; +- auto res = make_uniq(stream_factory_produce, stream_factory_ptr); +- +- // Store decoder +- res->stream_decoder = std::move(stream_decoder); +- +- // TODO Everything below this is identical to the bind in duckdb/src/function/table/arrow.cpp +- auto &data = *res; +- stream_factory_get_schema((ArrowArrayStream *) stream_factory_ptr, data.schema_root.arrow_schema); +- for (idx_t col_idx = 0; col_idx < (idx_t) data.schema_root.arrow_schema.n_children; col_idx++) { +- auto &schema = *data.schema_root.arrow_schema.children[col_idx]; +- if (!schema.release) { +- throw InvalidInputException("arrow_scan: released schema passed"); +- } +- auto arrow_type = GetArrowLogicalType(schema); +- if 
(schema.dictionary) { +- auto dictionary_type = GetArrowLogicalType(*schema.dictionary); +- return_types.emplace_back(dictionary_type->GetDuckType()); +- arrow_type->SetDictionary(std::move(dictionary_type)); +- } else { +- return_types.emplace_back(arrow_type->GetDuckType()); +- } +- res->arrow_table.AddColumn(col_idx, std::move(arrow_type)); +- auto format = string(schema.format); +- auto name = string(schema.name); +- if (name.empty()) { +- name = string("v") + to_string(col_idx); +- } +- names.push_back(name); ++ auto arrow_type = GetArrowLogicalType(schema); ++ if (schema.dictionary) { ++ auto dictionary_type = GetArrowLogicalType(*schema.dictionary); ++ return_types.emplace_back(dictionary_type->GetDuckType()); ++ arrow_type->SetDictionary(std::move(dictionary_type)); ++ } else { ++ return_types.emplace_back(arrow_type->GetDuckType()); + } +- QueryResult::DeduplicateColumns(names); +- return std::move(res); ++ res->arrow_table.AddColumn(col_idx, std::move(arrow_type)); ++ auto format = string(schema.format); ++ auto name = string(schema.name); ++ if (name.empty()) { ++ name = string("v") + to_string(col_idx); ++ } ++ names.push_back(name); ++ } ++ QueryResult::DeduplicateColumns(names); ++ return std::move(res); + } + +-// Same as regular arrow scan, except ArrowToDuckDB call TODO: refactor to allow nicely overriding this +-void ArrowIPCTableFunction::ArrowScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { +- if (!data_p.local_state) { +- return; +- } +- auto &data = data_p.bind_data->CastNoConst(); +- auto &state = data_p.local_state->Cast(); +- auto &global_state = data_p.global_state->Cast(); +- +- //! Out of tuples in this chunk +- if (state.chunk_offset >= (idx_t)state.chunk->arrow_array.length) { +- if (!ArrowScanParallelStateNext(context, data_p.bind_data.get(), state, global_state)) { +- return; +- } ++// Same as regular arrow scan, except ArrowToDuckDB call TODO: refactor to allow ++// nicely overriding this ++void ArrowIPCTableFunction::ArrowScanFunction(ClientContext &context, ++ TableFunctionInput &data_p, ++ DataChunk &output) { ++ if (!data_p.local_state) { ++ return; ++ } ++ auto &data = data_p.bind_data->CastNoConst(); ++ auto &state = data_p.local_state->Cast(); ++ auto &global_state = data_p.global_state->Cast(); ++ ++ //! 
Out of tuples in this chunk ++ if (state.chunk_offset >= (idx_t)state.chunk->arrow_array.length) { ++ if (!ArrowScanParallelStateNext(context, data_p.bind_data.get(), state, ++ global_state)) { ++ return; + } +- int64_t output_size = MinValue(STANDARD_VECTOR_SIZE, state.chunk->arrow_array.length - state.chunk_offset); +- data.lines_read += output_size; +- if (global_state.CanRemoveFilterColumns()) { +- state.all_columns.Reset(); +- state.all_columns.SetCardinality(output_size); +- ArrowToDuckDB(state, data.arrow_table.GetColumns(), state.all_columns, data.lines_read - output_size, false); +- output.ReferenceColumns(state.all_columns, global_state.projection_ids); +- } else { +- output.SetCardinality(output_size); +- ArrowToDuckDB(state, data.arrow_table.GetColumns(), output, data.lines_read - output_size, false); +- } +- +- output.Verify(); +- state.chunk_offset += output.size(); ++ } ++ int64_t output_size = ++ MinValue(STANDARD_VECTOR_SIZE, ++ state.chunk->arrow_array.length - state.chunk_offset); ++ data.lines_read += output_size; ++ if (global_state.CanRemoveFilterColumns()) { ++ state.all_columns.Reset(); ++ state.all_columns.SetCardinality(output_size); ++ ArrowToDuckDB(state, data.arrow_table.GetColumns(), state.all_columns, ++ data.lines_read - output_size, false); ++ output.ReferenceColumns(state.all_columns, global_state.projection_ids); ++ } else { ++ output.SetCardinality(output_size); ++ ArrowToDuckDB(state, data.arrow_table.GetColumns(), output, ++ data.lines_read - output_size, false); ++ } ++ ++ output.Verify(); ++ state.chunk_offset += output.size(); + } + + } // namespace duckdb +\ No newline at end of file +diff --git a/src/arrow_stream_buffer.cpp b/src/arrow_stream_buffer.cpp +index f097ca1..c9791e4 100644 +--- a/src/arrow_stream_buffer.cpp ++++ b/src/arrow_stream_buffer.cpp +@@ -1,95 +1,108 @@ + #include "arrow_stream_buffer.hpp" + + #include ++#include + + /// File copied from + /// https://github.com/duckdb/duckdb-wasm/blob/0ad10e7db4ef4025f5f4120be37addc4ebe29618/lib/src/arrow_stream_buffer.cc + namespace duckdb { + + /// Constructor +-ArrowIPCStreamBuffer::ArrowIPCStreamBuffer() : schema_(nullptr), batches_(), is_eos_(false) { +-} ++ArrowIPCStreamBuffer::ArrowIPCStreamBuffer() ++ : schema_(nullptr), batches_(), is_eos_(false) {} + /// Decoded a schema +-arrow::Status ArrowIPCStreamBuffer::OnSchemaDecoded(std::shared_ptr s) { +- schema_ = s; +- return arrow::Status::OK(); ++arrow::Status ++ArrowIPCStreamBuffer::OnSchemaDecoded(std::shared_ptr s) { ++ schema_ = s; ++ return arrow::Status::OK(); + } + /// Decoded a record batch +-arrow::Status ArrowIPCStreamBuffer::OnRecordBatchDecoded(std::shared_ptr batch) { +- batches_.push_back(batch); +- return arrow::Status::OK(); ++arrow::Status ArrowIPCStreamBuffer::OnRecordBatchDecoded( ++ std::shared_ptr batch) { ++ batches_.push_back(batch); ++ return arrow::Status::OK(); + } + /// Reached end of stream + arrow::Status ArrowIPCStreamBuffer::OnEOS() { +- is_eos_ = true; +- return arrow::Status::OK(); ++ is_eos_ = true; ++ return arrow::Status::OK(); + } + + /// Constructor +-ArrowIPCStreamBufferReader::ArrowIPCStreamBufferReader(std::shared_ptr buffer) +- : buffer_(buffer), next_batch_id_(0) { +-} ++ArrowIPCStreamBufferReader::ArrowIPCStreamBufferReader( ++ std::shared_ptr buffer) ++ : buffer_(buffer), next_batch_id_(0) {} + + /// Get the schema + std::shared_ptr ArrowIPCStreamBufferReader::schema() const { +- return buffer_->schema(); ++ return buffer_->schema(); + } + /// Read the next record batch in the stream. 
Return null for batch when + /// reaching end of stream +-arrow::Status ArrowIPCStreamBufferReader::ReadNext(std::shared_ptr *batch) { +- if (next_batch_id_ >= buffer_->batches().size()) { +- *batch = nullptr; +- return arrow::Status::OK(); +- } +- *batch = buffer_->batches()[next_batch_id_++]; +- return arrow::Status::OK(); ++arrow::Status ArrowIPCStreamBufferReader::ReadNext( ++ std::shared_ptr *batch) { ++ if (next_batch_id_ >= buffer_->batches().size()) { ++ *batch = nullptr; ++ return arrow::Status::OK(); ++ } ++ *batch = buffer_->batches()[next_batch_id_++]; ++ return arrow::Status::OK(); + } + + /// Arrow array stream factory function + duckdb::unique_ptr +-ArrowIPCStreamBufferReader::CreateStream(uintptr_t buffer_ptr, ArrowStreamParameters ¶meters) { +- assert(buffer_ptr != 0); +- auto buffer = reinterpret_cast *>(buffer_ptr); +- auto reader = std::make_shared(*buffer); ++ArrowIPCStreamBufferReader::CreateStream(uintptr_t buffer_ptr, ++ ArrowStreamParameters ¶meters) { ++ assert(buffer_ptr != 0); ++ auto buffer = ++ reinterpret_cast *>(buffer_ptr); ++ auto reader = std::make_shared(*buffer); + +- // Create arrow stream +- auto stream_wrapper = duckdb::make_uniq(); +- stream_wrapper->arrow_array_stream.release = nullptr; +- auto maybe_ok = arrow::ExportRecordBatchReader(reader, &stream_wrapper->arrow_array_stream); +- if (!maybe_ok.ok()) { +- if (stream_wrapper->arrow_array_stream.release) { +- stream_wrapper->arrow_array_stream.release(&stream_wrapper->arrow_array_stream); +- } +- return nullptr; +- } ++ // Create arrow stream ++ auto stream_wrapper = duckdb::make_uniq(); ++ stream_wrapper->arrow_array_stream.release = nullptr; ++ auto maybe_ok = arrow::ExportRecordBatchReader( ++ reader, &stream_wrapper->arrow_array_stream); ++ if (!maybe_ok.ok()) { ++ if (stream_wrapper->arrow_array_stream.release) { ++ stream_wrapper->arrow_array_stream.release( ++ &stream_wrapper->arrow_array_stream); ++ } ++ return nullptr; ++ } + +- // Release the stream +- return stream_wrapper; ++ // Release the stream ++ return stream_wrapper; + } + +-void ArrowIPCStreamBufferReader::GetSchema(uintptr_t buffer_ptr, duckdb::ArrowSchemaWrapper &schema) { +- assert(buffer_ptr != 0); +- auto buffer = reinterpret_cast *>(buffer_ptr); +- auto reader = std::make_shared(*buffer); ++void ArrowIPCStreamBufferReader::GetSchema(uintptr_t buffer_ptr, ++ duckdb::ArrowSchemaWrapper &schema) { ++ assert(buffer_ptr != 0); ++ auto buffer = ++ reinterpret_cast *>(buffer_ptr); ++ auto reader = std::make_shared(*buffer); + +- // Create arrow stream +- auto stream_wrapper = duckdb::make_uniq(); +- stream_wrapper->arrow_array_stream.release = nullptr; +- auto maybe_ok = arrow::ExportRecordBatchReader(reader, &stream_wrapper->arrow_array_stream); +- if (!maybe_ok.ok()) { +- if (stream_wrapper->arrow_array_stream.release) { +- stream_wrapper->arrow_array_stream.release(&stream_wrapper->arrow_array_stream); +- } +- return; +- } ++ // Create arrow stream ++ auto stream_wrapper = duckdb::make_uniq(); ++ stream_wrapper->arrow_array_stream.release = nullptr; ++ auto maybe_ok = arrow::ExportRecordBatchReader( ++ reader, &stream_wrapper->arrow_array_stream); ++ if (!maybe_ok.ok()) { ++ if (stream_wrapper->arrow_array_stream.release) { ++ stream_wrapper->arrow_array_stream.release( ++ &stream_wrapper->arrow_array_stream); ++ } ++ return; ++ } + +- // Pass ownership to caller +- stream_wrapper->arrow_array_stream.get_schema(&stream_wrapper->arrow_array_stream, &schema.arrow_schema); ++ // Pass ownership to caller ++ 
stream_wrapper->arrow_array_stream.get_schema( ++ &stream_wrapper->arrow_array_stream, &schema.arrow_schema); + } + + /// Constructor +-BufferingArrowIPCStreamDecoder::BufferingArrowIPCStreamDecoder(std::shared_ptr buffer) +- : arrow::ipc::StreamDecoder(buffer), buffer_(buffer) { +-} ++BufferingArrowIPCStreamDecoder::BufferingArrowIPCStreamDecoder( ++ std::shared_ptr buffer) ++ : arrow::ipc::StreamDecoder(buffer), buffer_(buffer) {} + + } // namespace duckdb +diff --git a/src/arrow_to_ipc.cpp b/src/arrow_to_ipc.cpp +index e282612..c316d85 100644 +--- a/src/arrow_to_ipc.cpp ++++ b/src/arrow_to_ipc.cpp +@@ -15,6 +15,8 @@ + #include "arrow/type_fwd.h" + #include "arrow/c/bridge.h" + ++#include ++ + #include "duckdb.hpp" + #ifndef DUCKDB_AMALGAMATION + #include "duckdb/common/arrow/result_arrow_wrapper.hpp" +@@ -28,165 +30,180 @@ + namespace duckdb { + + struct ToArrowIpcFunctionData : public TableFunctionData { +- ToArrowIpcFunctionData() { +- } +- shared_ptr schema; +- idx_t chunk_size; ++ ToArrowIpcFunctionData() {} ++ std::shared_ptr schema; ++ idx_t chunk_size; + }; + + struct ToArrowIpcGlobalState : public GlobalTableFunctionState { +- ToArrowIpcGlobalState() : sent_schema(false) { +- } +- atomic sent_schema; +- mutex lock; ++ ToArrowIpcGlobalState() : sent_schema(false) {} ++ atomic sent_schema; ++ mutex lock; + }; + + struct ToArrowIpcLocalState : public LocalTableFunctionState { +- unique_ptr appender; +- idx_t current_count = 0; +- bool checked_schema = false; ++ unique_ptr appender; ++ idx_t current_count = 0; ++ bool checked_schema = false; + }; + +- +-unique_ptr ToArrowIPCFunction::InitLocal(ExecutionContext &context, TableFunctionInitInput &input, +- GlobalTableFunctionState *global_state) { +- return make_uniq(); ++unique_ptr ++ToArrowIPCFunction::InitLocal(ExecutionContext &context, ++ TableFunctionInitInput &input, ++ GlobalTableFunctionState *global_state) { ++ return make_uniq(); + } + +-unique_ptr ToArrowIPCFunction::InitGlobal(ClientContext &context, +- TableFunctionInitInput &input) { +- return make_uniq(); ++unique_ptr ++ToArrowIPCFunction::InitGlobal(ClientContext &context, ++ TableFunctionInitInput &input) { ++ return make_uniq(); + } + +-unique_ptr ToArrowIPCFunction::Bind(ClientContext &context, TableFunctionBindInput &input, +- vector &return_types, vector &names) { +- auto result = make_uniq(); ++unique_ptr ++ToArrowIPCFunction::Bind(ClientContext &context, TableFunctionBindInput &input, ++ vector &return_types, ++ vector &names) { ++ auto result = make_uniq(); + +- result->chunk_size = DEFAULT_CHUNK_SIZE * STANDARD_VECTOR_SIZE; ++ result->chunk_size = DEFAULT_CHUNK_SIZE * STANDARD_VECTOR_SIZE; + +- // Set return schema +- return_types.emplace_back(LogicalType::BLOB); +- names.emplace_back("ipc"); +- return_types.emplace_back(LogicalType::BOOLEAN); +- names.emplace_back("header"); ++ // Set return schema ++ return_types.emplace_back(LogicalType::BLOB); ++ names.emplace_back("ipc"); ++ return_types.emplace_back(LogicalType::BOOLEAN); ++ names.emplace_back("header"); + +- // Create the Arrow schema +- ArrowSchema schema; +- ArrowConverter::ToArrowSchema(&schema, input.input_table_types, input.input_table_names, context.GetClientProperties()); +- result->schema = arrow::ImportSchema(&schema).ValueOrDie(); ++ // Create the Arrow schema ++ ArrowSchema schema; ++ ArrowConverter::ToArrowSchema(&schema, input.input_table_types, ++ input.input_table_names, ++ context.GetClientProperties()); ++ result->schema = arrow::ImportSchema(&schema).ValueOrDie(); + +- return 
std::move(result); ++ return std::move(result); + } + +-OperatorResultType ToArrowIPCFunction::Function(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, +- DataChunk &output) { +- std::shared_ptr arrow_serialized_ipc_buffer; +- auto &data = (ToArrowIpcFunctionData &)*data_p.bind_data; +- auto &local_state = (ToArrowIpcLocalState &)*data_p.local_state; +- auto &global_state = (ToArrowIpcGlobalState &)*data_p.global_state; +- +- bool sending_schema = false; +- +- bool caching_disabled = !PhysicalOperator::OperatorCachingAllowed(context); +- +- if (!local_state.checked_schema) { +- if (!global_state.sent_schema) { +- lock_guard init_lock(global_state.lock); +- if (!global_state.sent_schema) { +- // This run will send the schema, other threads can just send the buffers +- global_state.sent_schema = true; +- sending_schema = true; +- } +- } +- local_state.checked_schema = true; ++OperatorResultType ToArrowIPCFunction::Function(ExecutionContext &context, ++ TableFunctionInput &data_p, ++ DataChunk &input, ++ DataChunk &output) { ++ std::shared_ptr arrow_serialized_ipc_buffer; ++ auto &data = (ToArrowIpcFunctionData &)*data_p.bind_data; ++ auto &local_state = (ToArrowIpcLocalState &)*data_p.local_state; ++ auto &global_state = (ToArrowIpcGlobalState &)*data_p.global_state; ++ ++ bool sending_schema = false; ++ ++ bool caching_disabled = !PhysicalOperator::OperatorCachingAllowed(context); ++ ++ if (!local_state.checked_schema) { ++ if (!global_state.sent_schema) { ++ lock_guard init_lock(global_state.lock); ++ if (!global_state.sent_schema) { ++ // This run will send the schema, other threads can just send the ++ // buffers ++ global_state.sent_schema = true; ++ sending_schema = true; ++ } ++ } ++ local_state.checked_schema = true; ++ } ++ ++ if (sending_schema) { ++ auto result = arrow::ipc::SerializeSchema(*data.schema); ++ arrow_serialized_ipc_buffer = result.ValueOrDie(); ++ output.data[1].SetValue(0, Value::BOOLEAN(1)); ++ } else { ++ if (!local_state.appender) { ++ local_state.appender = ++ make_uniq(input.GetTypes(), data.chunk_size, ++ context.client.GetClientProperties()); + } + +- if (sending_schema) { +- auto result = arrow::ipc::SerializeSchema(*data.schema); +- arrow_serialized_ipc_buffer = result.ValueOrDie(); +- output.data[1].SetValue(0, Value::BOOLEAN(1)); ++ // Append input chunk ++ local_state.appender->Append(input, 0, input.size(), input.size()); ++ local_state.current_count += input.size(); ++ ++ // If chunk size is reached, we can flush to IPC blob ++ if (caching_disabled || local_state.current_count >= data.chunk_size) { ++ // Construct record batch from DataChunk ++ ArrowArray arr = local_state.appender->Finalize(); ++ auto record_batch = ++ arrow::ImportRecordBatch(&arr, data.schema).ValueOrDie(); ++ ++ // Serialize recordbatch ++ auto options = arrow::ipc::IpcWriteOptions::Defaults(); ++ auto result = arrow::ipc::SerializeRecordBatch(*record_batch, options); ++ arrow_serialized_ipc_buffer = result.ValueOrDie(); ++ ++ // Reset appender ++ local_state.appender.reset(); ++ local_state.current_count = 0; ++ ++ output.data[1].SetValue(0, Value::BOOLEAN(0)); + } else { +- if (!local_state.appender) { +- local_state.appender = make_uniq(input.GetTypes(), data.chunk_size, context.client.GetClientProperties()); +- } +- +- // Append input chunk +- local_state.appender->Append(input, 0, input.size(), input.size()); +- local_state.current_count += input.size(); +- +- // If chunk size is reached, we can flush to IPC blob +- if (caching_disabled || 
local_state.current_count >= data.chunk_size) { +- // Construct record batch from DataChunk +- ArrowArray arr = local_state.appender->Finalize(); +- auto record_batch = arrow::ImportRecordBatch(&arr, data.schema).ValueOrDie(); +- +- // Serialize recordbatch +- auto options = arrow::ipc::IpcWriteOptions::Defaults(); +- auto result = arrow::ipc::SerializeRecordBatch(*record_batch, options); +- arrow_serialized_ipc_buffer = result.ValueOrDie(); +- +- // Reset appender +- local_state.appender.reset(); +- local_state.current_count = 0; +- +- output.data[1].SetValue(0, Value::BOOLEAN(0)); +- } else { +- return OperatorResultType::NEED_MORE_INPUT; +- } ++ return OperatorResultType::NEED_MORE_INPUT; + } ++ } ++ ++ // TODO clean up ++ auto wrapped_buffer = ++ make_buffer(arrow_serialized_ipc_buffer); ++ auto &vector = output.data[0]; ++ StringVector::AddBuffer(vector, wrapped_buffer); ++ auto data_ptr = (string_t *)vector.GetData(); ++ *data_ptr = string_t((const char *)arrow_serialized_ipc_buffer->data(), ++ arrow_serialized_ipc_buffer->size()); ++ output.SetCardinality(1); ++ ++ if (sending_schema) { ++ return OperatorResultType::HAVE_MORE_OUTPUT; ++ } else { ++ return OperatorResultType::NEED_MORE_INPUT; ++ } ++} + +- // TODO clean up +- auto wrapped_buffer = make_buffer(arrow_serialized_ipc_buffer); ++OperatorFinalizeResultType ToArrowIPCFunction::FunctionFinal( ++ ExecutionContext &context, TableFunctionInput &data_p, DataChunk &output) { ++ auto &data = (ToArrowIpcFunctionData &)*data_p.bind_data; ++ auto &local_state = (ToArrowIpcLocalState &)*data_p.local_state; ++ std::shared_ptr arrow_serialized_ipc_buffer; ++ ++ // TODO clean up ++ if (local_state.appender) { ++ ArrowArray arr = local_state.appender->Finalize(); ++ auto record_batch = ++ arrow::ImportRecordBatch(&arr, data.schema).ValueOrDie(); ++ ++ // Serialize recordbatch ++ auto options = arrow::ipc::IpcWriteOptions::Defaults(); ++ auto result = arrow::ipc::SerializeRecordBatch(*record_batch, options); ++ arrow_serialized_ipc_buffer = result.ValueOrDie(); ++ ++ auto wrapped_buffer = ++ make_buffer(arrow_serialized_ipc_buffer); + auto &vector = output.data[0]; + StringVector::AddBuffer(vector, wrapped_buffer); + auto data_ptr = (string_t *)vector.GetData(); +- *data_ptr = string_t((const char *)arrow_serialized_ipc_buffer->data(), arrow_serialized_ipc_buffer->size()); ++ *data_ptr = string_t((const char *)arrow_serialized_ipc_buffer->data(), ++ arrow_serialized_ipc_buffer->size()); + output.SetCardinality(1); ++ local_state.appender.reset(); ++ output.data[1].SetValue(0, Value::BOOLEAN(0)); ++ } + +- if (sending_schema) { +- return OperatorResultType::HAVE_MORE_OUTPUT; +- } else { +- return OperatorResultType::NEED_MORE_INPUT; +- } +-} +- +-OperatorFinalizeResultType ToArrowIPCFunction::FunctionFinal(ExecutionContext &context, TableFunctionInput &data_p, +- DataChunk &output) { +- auto &data = (ToArrowIpcFunctionData &)*data_p.bind_data; +- auto &local_state = (ToArrowIpcLocalState &)*data_p.local_state; +- std::shared_ptr arrow_serialized_ipc_buffer; +- +- // TODO clean up +- if (local_state.appender) { +- ArrowArray arr = local_state.appender->Finalize(); +- auto record_batch = arrow::ImportRecordBatch(&arr, data.schema).ValueOrDie(); +- +- // Serialize recordbatch +- auto options = arrow::ipc::IpcWriteOptions::Defaults(); +- auto result = arrow::ipc::SerializeRecordBatch(*record_batch, options); +- arrow_serialized_ipc_buffer = result.ValueOrDie(); +- +- auto wrapped_buffer = make_buffer(arrow_serialized_ipc_buffer); +- auto 
&vector = output.data[0]; +- StringVector::AddBuffer(vector, wrapped_buffer); +- auto data_ptr = (string_t *)vector.GetData(); +- *data_ptr = string_t((const char *)arrow_serialized_ipc_buffer->data(), arrow_serialized_ipc_buffer->size()); +- output.SetCardinality(1); +- local_state.appender.reset(); +- output.data[1].SetValue(0, Value::BOOLEAN(0)); +- } +- +- return OperatorFinalizeResultType::FINISHED; ++ return OperatorFinalizeResultType::FINISHED; + } + +- + TableFunction ToArrowIPCFunction::GetFunction() { +- TableFunction fun("to_arrow_ipc", {LogicalType::TABLE}, nullptr, ToArrowIPCFunction::Bind, +- ToArrowIPCFunction::InitGlobal,ToArrowIPCFunction::InitLocal); +- fun.in_out_function = ToArrowIPCFunction::Function; +- fun.in_out_function_final = ToArrowIPCFunction::FunctionFinal; ++ TableFunction fun("to_arrow_ipc", {LogicalType::TABLE}, nullptr, ++ ToArrowIPCFunction::Bind, ToArrowIPCFunction::InitGlobal, ++ ToArrowIPCFunction::InitLocal); ++ fun.in_out_function = ToArrowIPCFunction::Function; ++ fun.in_out_function_final = ToArrowIPCFunction::FunctionFinal; + +- return fun; ++ return fun; + } + + } // namespace duckdb +\ No newline at end of file +diff --git a/src/include/arrow_extension.hpp b/src/include/arrow_extension.hpp +index 8ad174e..7d600d1 100644 +--- a/src/include/arrow_extension.hpp ++++ b/src/include/arrow_extension.hpp +@@ -6,8 +6,8 @@ namespace duckdb { + + class ArrowExtension : public Extension { + public: +- void Load(DuckDB &db) override; +- std::string Name() override; ++ void Load(DuckDB &db) override; ++ std::string Name() override; + }; + + } // namespace duckdb +diff --git a/src/include/arrow_scan_ipc.hpp b/src/include/arrow_scan_ipc.hpp +index 4ec1b9f..66a7827 100644 +--- a/src/include/arrow_scan_ipc.hpp ++++ b/src/include/arrow_scan_ipc.hpp +@@ -9,20 +9,23 @@ namespace duckdb { + + struct ArrowIPCScanFunctionData : public ArrowScanFunctionData { + public: +- using ArrowScanFunctionData::ArrowScanFunctionData; +- unique_ptr stream_decoder = nullptr; ++ using ArrowScanFunctionData::ArrowScanFunctionData; ++ unique_ptr stream_decoder = nullptr; + }; + +-// IPC Table scan is identical to ArrowTableFunction arrow scan except instead of CDataInterface header pointers, it +-// takes a bunch of pointers pointing to buffers containing data in Arrow IPC format ++// IPC Table scan is identical to ArrowTableFunction arrow scan except instead ++// of CDataInterface header pointers, it takes a bunch of pointers pointing to ++// buffers containing data in Arrow IPC format + struct ArrowIPCTableFunction : public ArrowTableFunction { + public: +- static TableFunction GetFunction(); ++ static TableFunction GetFunction(); + + private: +- static unique_ptr ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, +- vector &return_types, vector &names); +- static void ArrowScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output); ++ static unique_ptr ++ ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, ++ vector &return_types, vector &names); ++ static void ArrowScanFunction(ClientContext &context, ++ TableFunctionInput &data_p, DataChunk &output); + }; + + } // namespace duckdb +diff --git a/src/include/arrow_stream_buffer.hpp b/src/include/arrow_stream_buffer.hpp +index a4cbe97..e486c72 100644 +--- a/src/include/arrow_stream_buffer.hpp ++++ b/src/include/arrow_stream_buffer.hpp +@@ -14,6 +14,7 @@ + #include + #include + #include ++#include + + /// File copied from + /// 
https://github.com/duckdb/duckdb-wasm/blob/0ad10e7db4ef4025f5f4120be37addc4ebe29618/lib/include/duckdb/web/arrow_stream_buffer.h +@@ -21,76 +22,72 @@ namespace duckdb { + + struct ArrowIPCStreamBuffer : public arrow::ipc::Listener { + protected: +- /// The schema +- std::shared_ptr schema_; +- /// The batches +- std::vector> batches_; +- /// Is eos? +- bool is_eos_; ++ /// The schema ++ std::shared_ptr schema_; ++ /// The batches ++ std::vector> batches_; ++ /// Is eos? ++ bool is_eos_; + +- /// Decoded a record batch +- arrow::Status OnSchemaDecoded(std::shared_ptr schema); +- /// Decoded a record batch +- arrow::Status OnRecordBatchDecoded(std::shared_ptr record_batch); +- /// Reached end of stream +- arrow::Status OnEOS(); ++ /// Decoded a record batch ++ arrow::Status OnSchemaDecoded(std::shared_ptr schema); ++ /// Decoded a record batch ++ arrow::Status ++ OnRecordBatchDecoded(std::shared_ptr record_batch); ++ /// Reached end of stream ++ arrow::Status OnEOS(); + + public: +- /// Constructor +- ArrowIPCStreamBuffer(); ++ /// Constructor ++ ArrowIPCStreamBuffer(); + +- /// Is end of stream? +- bool is_eos() const { +- return is_eos_; +- } +- /// Return the schema +- std::shared_ptr &schema() { +- return schema_; +- } +- /// Return the batches +- std::vector> &batches() { +- return batches_; +- } ++ /// Is end of stream? ++ bool is_eos() const { return is_eos_; } ++ /// Return the schema ++ std::shared_ptr &schema() { return schema_; } ++ /// Return the batches ++ std::vector> &batches() { ++ return batches_; ++ } + }; + + struct ArrowIPCStreamBufferReader : public arrow::RecordBatchReader { + protected: +- /// The buffer +- std::shared_ptr buffer_; +- /// The batch index +- size_t next_batch_id_; ++ /// The buffer ++ std::shared_ptr buffer_; ++ /// The batch index ++ size_t next_batch_id_; + + public: +- /// Constructor +- ArrowIPCStreamBufferReader(std::shared_ptr buffer); +- /// Destructor +- ~ArrowIPCStreamBufferReader() = default; ++ /// Constructor ++ ArrowIPCStreamBufferReader(std::shared_ptr buffer); ++ /// Destructor ++ ~ArrowIPCStreamBufferReader() = default; + +- /// Get the schema +- std::shared_ptr schema() const override; +- /// Read the next record batch in the stream. Return null for batch when reaching end of stream +- arrow::Status ReadNext(std::shared_ptr *batch) override; ++ /// Get the schema ++ std::shared_ptr schema() const override; ++ /// Read the next record batch in the stream. 
Return null for batch when ++ /// reaching end of stream ++ arrow::Status ReadNext(std::shared_ptr *batch) override; + +- /// Create arrow array stream wrapper +- static duckdb::unique_ptr CreateStream(uintptr_t buffer_ptr, +- ArrowStreamParameters ¶meters); +- /// Create arrow array stream wrapper +- static void GetSchema(uintptr_t buffer_ptr, ArrowSchemaWrapper &schema); ++ /// Create arrow array stream wrapper ++ static duckdb::unique_ptr ++ CreateStream(uintptr_t buffer_ptr, ArrowStreamParameters ¶meters); ++ /// Create arrow array stream wrapper ++ static void GetSchema(uintptr_t buffer_ptr, ArrowSchemaWrapper &schema); + }; + + struct BufferingArrowIPCStreamDecoder : public arrow::ipc::StreamDecoder { + protected: +- /// The buffer +- std::shared_ptr buffer_; ++ /// The buffer ++ std::shared_ptr buffer_; + + public: +- /// Constructor +- BufferingArrowIPCStreamDecoder( +- std::shared_ptr buffer = std::make_shared()); +- /// Get the buffer +- std::shared_ptr &buffer() { +- return buffer_; +- } ++ /// Constructor ++ BufferingArrowIPCStreamDecoder(std::shared_ptr buffer = ++ std::make_shared()); ++ /// Get the buffer ++ std::shared_ptr &buffer() { return buffer_; } + }; + + } // namespace duckdb +diff --git a/src/include/arrow_to_ipc.hpp b/src/include/arrow_to_ipc.hpp +index b4eb9d4..6c8995a 100644 +--- a/src/include/arrow_to_ipc.hpp ++++ b/src/include/arrow_to_ipc.hpp +@@ -3,36 +3,42 @@ + #include "arrow/buffer.h" + #include "duckdb.hpp" + ++#include ++ + namespace duckdb { + + class ArrowStringVectorBuffer : public VectorBuffer { + public: +- explicit ArrowStringVectorBuffer(std::shared_ptr buffer_p) +- : VectorBuffer(VectorBufferType::OPAQUE_BUFFER), buffer(std::move(buffer_p)) { +- } ++ explicit ArrowStringVectorBuffer(std::shared_ptr buffer_p) ++ : VectorBuffer(VectorBufferType::OPAQUE_BUFFER), ++ buffer(std::move(buffer_p)) {} + + private: +- std::shared_ptr buffer; ++ std::shared_ptr buffer; + }; + +- + class ToArrowIPCFunction { + public: +- //! note: this is the number of vectors per chunk +- static constexpr idx_t DEFAULT_CHUNK_SIZE = 120; ++ //! 
note: this is the number of vectors per chunk ++ static constexpr idx_t DEFAULT_CHUNK_SIZE = 120; + +- static TableFunction GetFunction(); ++ static TableFunction GetFunction(); + + private: +- static unique_ptr InitLocal(ExecutionContext &context, TableFunctionInitInput &input, +- GlobalTableFunctionState *global_state); +- static unique_ptr InitGlobal(ClientContext &context, +- TableFunctionInitInput &input); +- static unique_ptr Bind(ClientContext &context, TableFunctionBindInput &input, +- vector &return_types, vector &names); +- static OperatorResultType Function(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, +- DataChunk &output); +- static OperatorFinalizeResultType FunctionFinal(ExecutionContext &context, TableFunctionInput &data_p, +- DataChunk &output); ++ static unique_ptr ++ InitLocal(ExecutionContext &context, TableFunctionInitInput &input, ++ GlobalTableFunctionState *global_state); ++ static unique_ptr ++ InitGlobal(ClientContext &context, TableFunctionInitInput &input); ++ static unique_ptr Bind(ClientContext &context, ++ TableFunctionBindInput &input, ++ vector &return_types, ++ vector &names); ++ static OperatorResultType Function(ExecutionContext &context, ++ TableFunctionInput &data_p, ++ DataChunk &input, DataChunk &output); ++ static OperatorFinalizeResultType FunctionFinal(ExecutionContext &context, ++ TableFunctionInput &data_p, ++ DataChunk &output); + }; // namespace duckdb +-} ++} // namespace duckdb From 5e4938b7b34c438237b3e10162e391c1ae80b23b Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 12 Apr 2024 13:19:08 +0200 Subject: [PATCH 108/201] create an assertion out of this, TemporaryDirectoryHandle should never be made if the path is empty --- src/storage/temporary_file_manager.cpp | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp index aa17eaaea0a0..d90641bf4ad7 100644 --- a/src/storage/temporary_file_manager.cpp +++ b/src/storage/temporary_file_manager.cpp @@ -190,11 +190,10 @@ idx_t TemporaryFileHandle::GetPositionInFile(idx_t index) { TemporaryDirectoryHandle::TemporaryDirectoryHandle(DatabaseInstance &db, string path_p, optional_idx max_swap_space) : db(db), temp_directory(std::move(path_p)), temp_file(make_uniq(db, temp_directory)) { auto &fs = FileSystem::GetFileSystem(db); - if (!temp_directory.empty()) { - if (!fs.DirectoryExists(temp_directory)) { - fs.CreateDirectory(temp_directory); - created_directory = true; - } + D_ASSERT(!temp_directory.empty()); + if (!fs.DirectoryExists(temp_directory)) { + fs.CreateDirectory(temp_directory); + created_directory = true; } temp_file->SetMaxSwapSpace(max_swap_space); } @@ -253,14 +252,13 @@ bool TemporaryFileIndex::IsValid() const { //===--------------------------------------------------------------------===// static idx_t GetDefaultMax(const string &path) { - // Use the available disk space + D_ASSERT(!path.empty()); auto disk_space = FileSystem::GetAvailableDiskSpace(path); - idx_t default_value = 0; - if (disk_space.IsValid()) { - // Only use 90% of the available disk space - default_value = static_cast(static_cast(disk_space.GetIndex()) * 0.9); - } - return default_value; + // Use the available disk space + // We have made sure that the file exists before we call this, it shouldn't fail + D_ASSERT(disk_space.IsValid()); + // Only use 90% of the available disk space + return static_cast(static_cast(disk_space.GetIndex()) * 0.9); } 
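// A worked example of the GetDefaultMax computation above, assuming
// FileSystem::GetAvailableDiskSpace reports bytes: with 200 GiB free
// (214748364800 bytes), the temp directory may by default grow to
//	0.9 * 214748364800 = 193273528320 bytes (~180 GiB)
// before hitting the limit; an explicitly configured maximum handed to
// SetMaxSwapSpace presumably takes precedence over this computed default.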
TemporaryFileManager::TemporaryFileManager(DatabaseInstance &db, const string &temp_directory_p) From f94b09f0d03e3835b1d81b17680d881077984b08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Fri, 12 Apr 2024 15:11:37 +0200 Subject: [PATCH 109/201] woo --- .../storage/compression/alp/algorithm/alp.hpp | 4 +-- .../duckdb/storage/table/segment_tree.hpp | 10 +++---- src/storage/table/array_column_data.cpp | 6 ++-- src/storage/table/column_data.cpp | 6 ++-- src/storage/table/column_segment.cpp | 7 +++-- src/storage/table/list_column_data.cpp | 8 +++--- src/storage/table/row_group.cpp | 15 +++++----- src/storage/table/row_group_collection.cpp | 28 ++++++++++--------- src/storage/table/row_version_manager.cpp | 2 +- src/storage/table/struct_column_data.cpp | 2 +- src/storage/table/update_segment.cpp | 12 ++++---- src/transaction/cleanup_state.cpp | 4 +-- src/transaction/commit_state.cpp | 6 ++-- src/transaction/duck_transaction_manager.cpp | 4 +-- 14 files changed, 60 insertions(+), 54 deletions(-) diff --git a/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp b/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp index d216f23e58af..a8c7e91f3e66 100644 --- a/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp +++ b/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp @@ -359,8 +359,8 @@ struct AlpCompression { BitpackingPrimitives::PackBuffer(state.values_encoded, u_encoded_integers, n_values, bit_width); } - state.bit_width = bit_width; // in bits - state.bp_size = bp_size; // in bytes + state.bit_width = bit_width; // in bits + state.bp_size = bp_size; // in bytes state.frame_of_reference = static_cast(min_value); // understood this can be negative } diff --git a/src/include/duckdb/storage/table/segment_tree.hpp b/src/include/duckdb/storage/table/segment_tree.hpp index e2778bc33137..e267d2e351eb 100644 --- a/src/include/duckdb/storage/table/segment_tree.hpp +++ b/src/include/duckdb/storage/table/segment_tree.hpp @@ -83,11 +83,11 @@ class SegmentTree { if (index < 0) { // load all segments LoadAllSegments(l); - index = nodes.size() + index; + index += nodes.size(); if (index < 0) { return nullptr; } - return nodes[index].node.get(); + return nodes[UnsafeNumericCast(index)].node.get(); } else { // lazily load segments until we reach the specific segment while (idx_t(index) >= nodes.size() && LoadNextSegment(l)) { @@ -95,7 +95,7 @@ class SegmentTree { if (idx_t(index) >= nodes.size()) { return nullptr; } - return nodes[index].node.get(); + return nodes[UnsafeNumericCast(index)].node.get(); } } //! Gets the next segment @@ -116,7 +116,7 @@ class SegmentTree { #ifdef DEBUG D_ASSERT(nodes[segment->index].node.get() == segment); #endif - return GetSegmentByIndex(l, segment->index + 1); + return GetSegmentByIndex(l, UnsafeNumericCast(segment->index + 1)); } //! Gets a pointer to the last segment. Useful for appends. @@ -182,7 +182,7 @@ class SegmentTree { if (segment_start >= nodes.size() - 1) { return; } - nodes.erase(nodes.begin() + segment_start + 1, nodes.end()); + nodes.erase(nodes.begin() + UnsafeNumericCast(segment_start) + 1, nodes.end()); } //! 
 //! Get the segment index of the column segment for the given row
diff --git a/src/storage/table/array_column_data.cpp b/src/storage/table/array_column_data.cpp
index 31cb93facf7f..be8b8a219c83 100644
--- a/src/storage/table/array_column_data.cpp
+++ b/src/storage/table/array_column_data.cpp
@@ -120,9 +120,9 @@ void ArrayColumnData::RevertAppend(row_t start_row) {
 validity.RevertAppend(start_row);
 // Revert child column
 auto array_size = ArrayType::GetSize(type);
- child_column->RevertAppend(start_row * array_size);
+ child_column->RevertAppend(start_row * UnsafeNumericCast<row_t>(array_size));

- this->count = start_row - this->start;
+ this->count = UnsafeNumericCast<idx_t>(start_row) - this->start;
 }

 idx_t ArrayColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) {
@@ -162,7 +162,7 @@ void ArrayColumnData::FetchRow(TransactionData transaction, ColumnFetchState &st
 // We need to fetch between [row_id * array_size, (row_id + 1) * array_size)
 auto child_state = make_uniq<ColumnScanState>();
 child_state->Initialize(child_type, nullptr);
- child_column->InitializeScanWithOffset(*child_state, row_id * array_size);
+ child_column->InitializeScanWithOffset(*child_state, UnsafeNumericCast<idx_t>(row_id) * array_size);
 Vector child_scan(child_type, array_size);
 child_column->ScanCount(*child_state, child_scan, array_size);
 VectorOperations::Copy(child_scan, child_vec, array_size, 0, result_idx * array_size);
diff --git a/src/storage/table/column_data.cpp b/src/storage/table/column_data.cpp
index 1ccc4b8a19f9..5e11c148f2fe 100644
--- a/src/storage/table/column_data.cpp
+++ b/src/storage/table/column_data.cpp
@@ -105,7 +105,8 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai
 if (state.scan_options && state.scan_options->force_fetch_row) {
 for (idx_t i = 0; i < scan_count; i++) {
 ColumnFetchState fetch_state;
- state.current->FetchRow(fetch_state, UnsafeNumericCast<row_t>(state.row_index + i), result, result_offset + i);
+ state.current->FetchRow(fetch_state, UnsafeNumericCast<row_t>(state.row_index + i), result,
+ result_offset + i);
 }
 } else {
 state.current->Scan(state, scan_count, result, result_offset,
@@ -342,7 +343,8 @@ idx_t ColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) {
 D_ASSERT(row_id >= 0);
 D_ASSERT(idx_t(row_id) >= start);
 // perform the fetch within the segment
- state.row_index = start + ((UnsafeNumericCast<idx_t>(row_id) - start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE);
+ state.row_index =
+ start + ((UnsafeNumericCast<idx_t>(row_id) - start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE);
 state.current = data.GetSegment(state.row_index);
 state.internal_index = state.current->start;
 return ScanVector(state, result, STANDARD_VECTOR_SIZE, false);
diff --git a/src/storage/table/column_segment.cpp b/src/storage/table/column_segment.cpp
index 175791b612a4..d0f2cd69545c 100644
--- a/src/storage/table/column_segment.cpp
+++ b/src/storage/table/column_segment.cpp
@@ -60,8 +60,8 @@ unique_ptr<ColumnSegment> ColumnSegment::CreateTransientSegment(DatabaseInstance
 } else {
 buffer_manager.Allocate(MemoryTag::IN_MEMORY_TABLE, segment_size, false, &block);
 }
- return make_uniq<ColumnSegment>(db, std::move(block), type, ColumnSegmentType::TRANSIENT, start, 0, *function,
- BaseStatistics::CreateEmpty(type), INVALID_BLOCK, 0, segment_size);
+ return make_uniq<ColumnSegment>(db, std::move(block), type, ColumnSegmentType::TRANSIENT, start, 0U, *function,
+ BaseStatistics::CreateEmpty(type), INVALID_BLOCK, 0U, segment_size);
 }

 //===--------------------------------------------------------------------===//
@@ -132,7 +132,8 @@ void ColumnSegment::ScanPartial(ColumnScanState &state, idx_t scan_count, Vector
 // Fetch
 //===--------------------------------------------------------------------===//
 void ColumnSegment::FetchRow(ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) {
- function.get().fetch_row(*this, state, row_id - this->start, result, result_idx);
+ function.get().fetch_row(*this, state, UnsafeNumericCast<row_t>(UnsafeNumericCast<idx_t>(row_id) - this->start),
+ result, result_idx);
 }

 //===--------------------------------------------------------------------===//
diff --git a/src/storage/table/list_column_data.cpp b/src/storage/table/list_column_data.cpp
index 418a4bd4644e..21aa9b5af410 100644
--- a/src/storage/table/list_column_data.cpp
+++ b/src/storage/table/list_column_data.cpp
@@ -44,7 +44,7 @@ uint64_t ListColumnData::FetchListOffset(idx_t row_idx) {
 auto segment = data.GetSegment(row_idx);
 ColumnFetchState fetch_state;
 Vector result(type, 1);
- segment->FetchRow(fetch_state, row_idx, result, 0);
+ segment->FetchRow(fetch_state, UnsafeNumericCast<row_t>(row_idx), result, 0U);

 // initialize the child scan with the required offset
 return FlatVector::GetData<uint64_t>(result)[0];
@@ -235,7 +235,7 @@ void ListColumnData::RevertAppend(row_t start_row) {
 if (column_count > start) {
 // revert append in the child column
 auto list_offset = FetchListOffset(column_count - 1);
- child_column->RevertAppend(list_offset);
+ child_column->RevertAppend(UnsafeNumericCast<row_t>(list_offset));
 }
 }

@@ -269,8 +269,8 @@ void ListColumnData::FetchRow(TransactionData transaction, ColumnFetchState &sta
 }

 // now perform the fetch within the segment
- auto start_offset = idx_t(row_id) == this->start ? 0 : FetchListOffset(row_id - 1);
- auto end_offset = FetchListOffset(row_id);
+ auto start_offset = idx_t(row_id) == this->start ? 0 : FetchListOffset(UnsafeNumericCast<idx_t>(row_id - 1));
+ auto end_offset = FetchListOffset(UnsafeNumericCast<idx_t>(row_id));
 validity.FetchRow(transaction, *state.child_states[0], row_id, result, result_idx);

 auto &validity = FlatVector::Validity(result);
diff --git a/src/storage/table/row_group.cpp b/src/storage/table/row_group.cpp
index 824827d76a97..2de5faf7f735 100644
--- a/src/storage/table/row_group.cpp
+++ b/src/storage/table/row_group.cpp
@@ -487,7 +487,7 @@ void RowGroup::TemplatedScan(TransactionData transaction, CollectionScanState &s
 if (column == COLUMN_IDENTIFIER_ROW_ID) {
 // scan row id
 D_ASSERT(result.data[i].GetType().InternalType() == ROW_TYPE);
- result.data[i].Sequence(this->start + current_row, 1, count);
+ result.data[i].Sequence(UnsafeNumericCast<int64_t>(this->start + current_row), 1, count);
 } else {
 auto &col_data = GetColumn(column);
 if (TYPE != TableScanType::TABLE_SCAN_REGULAR) {
@@ -551,7 +551,8 @@ void RowGroup::TemplatedScan(TransactionData transaction, CollectionScanState &s
 result.data[i].SetVectorType(VectorType::FLAT_VECTOR);
 auto result_data = FlatVector::GetData<row_t>(result.data[i]);
 for (size_t sel_idx = 0; sel_idx < approved_tuple_count; sel_idx++) {
- result_data[sel_idx] = this->start + current_row + sel.get_index(sel_idx);
+ result_data[sel_idx] =
+ UnsafeNumericCast<row_t>(this->start + current_row + sel.get_index(sel_idx));
 }
 } else {
 auto &col_data = GetColumn(column);
@@ -703,7 +704,7 @@ void RowGroup::RevertAppend(idx_t row_group_start) {
 auto &vinfo = GetOrCreateVersionInfo();
 vinfo.RevertAppend(row_group_start - this->start);
 for (auto &column : columns) {
- column->RevertAppend(row_group_start);
+ column->RevertAppend(UnsafeNumericCast<row_t>(row_group_start));
 }
 this->count = MinValue<idx_t>(row_group_start - this->start, this->count);
 Verify();
@@ -955,7 +956,7 @@ idx_t RowGroup::Delete(TransactionData transaction, DataTable &table, row_t *ids
 for (idx_t i = 0; i < count; i++) {
 D_ASSERT(ids[i] >= 0);
 D_ASSERT(idx_t(ids[i]) >= this->start && idx_t(ids[i]) < this->start + this->count);
- del_state.Delete(ids[i] - this->start);
+ del_state.Delete(ids[i] - UnsafeNumericCast<row_t>(this->start));
 }
 del_state.Flush();
 return del_state.delete_count;
@@ -975,15 +976,15 @@ idx_t RowGroup::DeleteRows(idx_t vector_idx, transaction_t transaction_id, row_t

 void VersionDeleteState::Delete(row_t row_id) {
 D_ASSERT(row_id >= 0);
- idx_t vector_idx = row_id / STANDARD_VECTOR_SIZE;
- idx_t idx_in_vector = row_id - vector_idx * STANDARD_VECTOR_SIZE;
+ idx_t vector_idx = UnsafeNumericCast<idx_t>(row_id) / STANDARD_VECTOR_SIZE;
+ idx_t idx_in_vector = UnsafeNumericCast<idx_t>(row_id) - vector_idx * STANDARD_VECTOR_SIZE;
 if (current_chunk != vector_idx) {
 Flush();
 current_chunk = vector_idx;
 chunk_row = vector_idx * STANDARD_VECTOR_SIZE;
 }
- rows[count++] = idx_in_vector;
+ rows[count++] = UnsafeNumericCast<row_t>(idx_in_vector);
 }

 void VersionDeleteState::Flush() {
diff --git a/src/storage/table/row_group_collection.cpp b/src/storage/table/row_group_collection.cpp
index 5fae55491020..00daedabcdc5 100644
--- a/src/storage/table/row_group_collection.cpp
+++ b/src/storage/table/row_group_collection.cpp
@@ -95,7 +95,7 @@ void RowGroupCollection::InitializeEmpty() {

 void RowGroupCollection::AppendRowGroup(SegmentLock &l, idx_t start_row) {
 D_ASSERT(start_row >= row_start);
- auto new_row_group = make_uniq<RowGroup>(*this, start_row, 0);
+ auto new_row_group = make_uniq<RowGroup>(*this, start_row, 0U);
 new_row_group->InitializeEmpty(types);
 row_groups->AppendSegment(l, std::move(new_row_group));
 }
@@ -271,13 +271,13 @@ void RowGroupCollection::Fetch(TransactionData transaction, DataChunk &result, c
 {
 idx_t segment_index;
 auto l = row_groups->Lock();
- if (!row_groups->TryGetSegmentIndex(l, row_id, segment_index)) {
+ if (!row_groups->TryGetSegmentIndex(l, UnsafeNumericCast<idx_t>(row_id), segment_index)) {
 // in parallel append scenarios it is possible for the row_id
 continue;
 }
- row_group = row_groups->GetSegmentByIndex(l, segment_index);
+ row_group = row_groups->GetSegmentByIndex(l, UnsafeNumericCast<int64_t>(segment_index));
 }
- if (!row_group->Fetch(transaction, row_id - row_group->start)) {
+ if (!row_group->Fetch(transaction, UnsafeNumericCast<idx_t>(row_id) - row_group->start)) {
 continue;
 }
 row_group->FetchRow(transaction, state, column_ids, row_id, result, count);
@@ -306,7 +306,7 @@ bool RowGroupCollection::IsEmpty(SegmentLock &l) const {
 }

 void RowGroupCollection::InitializeAppend(TransactionData transaction, TableAppendState &state) {
- state.row_start = total_rows;
+ state.row_start = UnsafeNumericCast<row_t>(total_rows.load());
 state.current_row = state.row_start;
 state.total_append_count = 0;

@@ -433,7 +433,7 @@ void RowGroupCollection::RevertAppendInternal(idx_t start_row) {
 // revert from the last segment
 segment_index = segment_count - 1;
 }
- auto &segment = *row_groups->GetSegmentByIndex(l, segment_index);
+ auto &segment = *row_groups->GetSegmentByIndex(l, UnsafeNumericCast<int64_t>(segment_index));

 // remove any segments AFTER this segment: they should be deleted entirely
 row_groups->EraseSegments(l, segment_index);
@@ -468,7 +468,7 @@ idx_t RowGroupCollection::Delete(TransactionData transaction, DataTable &table,
 idx_t pos = 0;
 do {
 idx_t start = pos;
- auto row_group = row_groups->GetSegment(ids[start]);
+ auto row_group = row_groups->GetSegment(UnsafeNumericCast<idx_t>(ids[start]));
 for (pos++; pos < count; pos++) {
 D_ASSERT(ids[pos] >= 0);
 // check if this id still belongs to this row group
@@ -494,10 +494,12 @@ void RowGroupCollection::Update(TransactionData transaction, row_t *ids, const v
 idx_t pos = 0;
 do {
 idx_t start = pos;
- auto row_group = row_groups->GetSegment(ids[pos]);
- row_t base_id =
- row_group->start + ((ids[pos] - row_group->start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE);
- row_t max_id = MinValue<row_t>(base_id + STANDARD_VECTOR_SIZE, row_group->start + row_group->count);
+ auto row_group = row_groups->GetSegment(UnsafeNumericCast<idx_t>(ids[pos]));
+ row_t base_id =
+ UnsafeNumericCast<row_t>(row_group->start + ((UnsafeNumericCast<idx_t>(ids[pos]) - row_group->start) /
+ STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE));
+ auto max_id = MinValue<row_t>(base_id + STANDARD_VECTOR_SIZE,
+ UnsafeNumericCast<row_t>(row_group->start + row_group->count));
 for (pos++; pos < updates.size(); pos++) {
 D_ASSERT(ids[pos] >= 0);
 // check if this id still belongs to this vector in this row group
@@ -544,8 +546,8 @@ void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_
 result.Reset();
 // figure out which row_group to fetch from
 auto row_id = row_ids[r];
- auto row_group = row_groups->GetSegment(row_id);
- auto row_group_vector_idx = (row_id - row_group->start) / STANDARD_VECTOR_SIZE;
+ auto row_group = row_groups->GetSegment(UnsafeNumericCast<idx_t>(row_id));
+ auto row_group_vector_idx = (UnsafeNumericCast<idx_t>(row_id) - row_group->start) / STANDARD_VECTOR_SIZE;
 auto base_row_id = row_group_vector_idx * STANDARD_VECTOR_SIZE + row_group->start;

 // fetch the current vector
@@ -586,7 +588,7 @@ void RowGroupCollection::UpdateColumn(TransactionData transaction, Vector &row_i
 }
 // find the row_group this id belongs to
 auto primary_column_idx = column_path[0];
- auto row_group = row_groups->GetSegment(first_id);
row_group = row_groups->GetSegment(first_id); + auto row_group = row_groups->GetSegment(UnsafeNumericCast(first_id)); row_group->UpdateColumn(transaction, updates, row_ids, column_path); row_group->MergeIntoStatistics(primary_column_idx, stats.GetStats(primary_column_idx).Statistics()); diff --git a/src/storage/table/row_version_manager.cpp b/src/storage/table/row_version_manager.cpp index ead21f89234f..c317f246183b 100644 --- a/src/storage/table/row_version_manager.cpp +++ b/src/storage/table/row_version_manager.cpp @@ -69,7 +69,7 @@ bool RowVersionManager::Fetch(TransactionData transaction, idx_t row) { if (!info) { return true; } - return info->Fetch(transaction, row - vector_index * STANDARD_VECTOR_SIZE); + return info->Fetch(transaction, UnsafeNumericCast(row - vector_index * STANDARD_VECTOR_SIZE)); } void RowVersionManager::AppendVersionInfo(TransactionData transaction, idx_t count, idx_t row_group_start, diff --git a/src/storage/table/struct_column_data.cpp b/src/storage/table/struct_column_data.cpp index 5fca4ddae306..e31867a24125 100644 --- a/src/storage/table/struct_column_data.cpp +++ b/src/storage/table/struct_column_data.cpp @@ -157,7 +157,7 @@ void StructColumnData::RevertAppend(row_t start_row) { for (auto &sub_column : sub_columns) { sub_column->RevertAppend(start_row); } - this->count = start_row - this->start; + this->count = UnsafeNumericCast(start_row) - this->start; } idx_t StructColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { diff --git a/src/storage/table/update_segment.cpp b/src/storage/table/update_segment.cpp index 624d987edab6..317023de8900 100644 --- a/src/storage/table/update_segment.cpp +++ b/src/storage/table/update_segment.cpp @@ -581,7 +581,7 @@ void UpdateSegment::InitializeUpdateInfo(UpdateInfo &info, row_t *ids, const Sel auto idx = sel.get_index(i); auto id = ids[idx]; D_ASSERT(idx_t(id) >= vector_offset && idx_t(id) < vector_offset + STANDARD_VECTOR_SIZE); - info.tuples[i] = NumericCast(id - vector_offset); + info.tuples[i] = NumericCast(NumericCast(id) - vector_offset); }; } @@ -697,7 +697,7 @@ static idx_t MergeLoop(row_t a[], sel_t b[], idx_t acount, idx_t bcount, idx_t a idx_t count = 0; while (aidx < acount && bidx < bcount) { auto a_index = asel.get_index(aidx); - auto a_id = a[a_index] - aoffset; + auto a_id = UnsafeNumericCast(a[a_index]) - aoffset; auto b_id = b[bidx]; if (a_id == b_id) { merge(a_id, a_index, bidx, count); @@ -716,7 +716,7 @@ static idx_t MergeLoop(row_t a[], sel_t b[], idx_t acount, idx_t bcount, idx_t a } for (; aidx < acount; aidx++) { auto a_index = asel.get_index(aidx); - pick_a(a[a_index] - aoffset, a_index, count); + pick_a(UnsafeNumericCast(a[a_index]) - aoffset, a_index, count); count++; } for (; bidx < bcount; bidx++) { @@ -777,7 +777,7 @@ static void MergeUpdateLoopInternal(UpdateInfo *base_info, V *base_table_data, U for (idx_t i = 0; i < count; i++) { auto idx = sel.get_index(i); // we have to merge the info for "ids[i]" - auto update_id = ids[idx] - base_id; + auto update_id = UnsafeNumericCast(ids[idx]) - base_id; while (update_info_offset < update_info->N && update_info->tuples[update_info_offset] < update_id) { // old id comes before the current id: write it @@ -1103,7 +1103,7 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect // get the vector index based on the first id // we assert that all updates must be part of the same vector auto first_id = ids[sel.get_index(0)]; - idx_t vector_index = (first_id - column_data.start) / STANDARD_VECTOR_SIZE; + 
+ idx_t vector_index = (UnsafeNumericCast<idx_t>(first_id) - column_data.start) / STANDARD_VECTOR_SIZE;
 idx_t vector_offset = column_data.start + vector_index * STANDARD_VECTOR_SIZE;

 D_ASSERT(idx_t(first_id) >= column_data.start);
@@ -1116,7 +1116,7 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect
 // there is already a version here, check if there are any conflicts and search for the node that belongs to
 // this transaction in the version chain
 auto base_info = root->info[vector_index]->info.get();
- CheckForConflicts(base_info->next, transaction, ids, sel, count, vector_offset, node);
+ CheckForConflicts(base_info->next, transaction, ids, sel, count, UnsafeNumericCast<row_t>(vector_offset), node);

 // there are no conflicts
 // first, check if this thread has already done any updates
diff --git a/src/transaction/cleanup_state.cpp b/src/transaction/cleanup_state.cpp
index 0ec438c20339..802405483aeb 100644
--- a/src/transaction/cleanup_state.cpp
+++ b/src/transaction/cleanup_state.cpp
@@ -72,12 +72,12 @@ void CleanupState::CleanupDelete(DeleteInfo &info) {
 count = 0;
 if (info.is_consecutive) {
 for (idx_t i = 0; i < info.count; i++) {
- row_numbers[count++] = info.base_row + i;
+ row_numbers[count++] = UnsafeNumericCast<row_t>(info.base_row + i);
 }
 } else {
 auto rows = info.GetRows();
 for (idx_t i = 0; i < info.count; i++) {
- row_numbers[count++] = info.base_row + rows[i];
+ row_numbers[count++] = UnsafeNumericCast<row_t>(info.base_row + rows[i]);
 }
 }
 Flush();
diff --git a/src/transaction/commit_state.cpp b/src/transaction/commit_state.cpp
index 986aac3b17c1..ae78d0d0c840 100644
--- a/src/transaction/commit_state.cpp
+++ b/src/transaction/commit_state.cpp
@@ -206,12 +206,12 @@ void CommitState::WriteDelete(DeleteInfo &info) {
 auto rows = FlatVector::GetData<row_t>(delete_chunk->data[0]);
 if (info.is_consecutive) {
 for (idx_t i = 0; i < info.count; i++) {
- rows[i] = info.base_row + i;
+ rows[i] = UnsafeNumericCast<row_t>(info.base_row + i);
 }
 } else {
 auto delete_rows = info.GetRows();
 for (idx_t i = 0; i < info.count; i++) {
- rows[i] = info.base_row + delete_rows[i];
+ rows[i] = UnsafeNumericCast<row_t>(info.base_row) + delete_rows[i];
 }
 }
 delete_chunk->SetCardinality(info.count);
@@ -245,7 +245,7 @@ void CommitState::WriteUpdate(UpdateInfo &info) {
 auto row_ids = FlatVector::GetData<row_t>(update_chunk->data[1]);
 idx_t start = column_data.start + info.vector_index * STANDARD_VECTOR_SIZE;
 for (idx_t i = 0; i < info.N; i++) {
- row_ids[info.tuples[i]] = start + info.tuples[i];
+ row_ids[info.tuples[i]] = UnsafeNumericCast<row_t>(start + info.tuples[i]);
 }
 if (column_data.type.id() == LogicalTypeId::VALIDITY) {
 // zero-initialize the booleans
diff --git a/src/transaction/duck_transaction_manager.cpp b/src/transaction/duck_transaction_manager.cpp
index bf7014c93411..a3ec6d4a3d9f 100644
--- a/src/transaction/duck_transaction_manager.cpp
+++ b/src/transaction/duck_transaction_manager.cpp
@@ -311,7 +311,7 @@ void DuckTransactionManager::RemoveTransaction(DuckTransaction &transaction) noe
 if (i > 0) {
 // we garbage collected transactions: remove them from the list
 recently_committed_transactions.erase(recently_committed_transactions.begin(),
- recently_committed_transactions.begin() + i);
+ recently_committed_transactions.begin() + UnsafeNumericCast<int64_t>(i));
 }
 // check if we can free the memory of any old transactions
 i = active_transactions.empty() ? 
old_transactions.size() : 0; @@ -326,7 +326,7 @@ void DuckTransactionManager::RemoveTransaction(DuckTransaction &transaction) noe } if (i > 0) { // we garbage collected transactions: remove them from the list - old_transactions.erase(old_transactions.begin(), old_transactions.begin() + i); + old_transactions.erase(old_transactions.begin(), old_transactions.begin() + UnsafeNumericCast(i)); } } From 92141c56cece1981ef1e7bbf58a370250ebec5b1 Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 12 Apr 2024 15:18:13 +0200 Subject: [PATCH 110/201] expand the keyword arguments in the stubs --- tools/pythonpkg/duckdb-stubs/__init__.pyi | 188 +++++++++--------- .../pythonpkg/scripts/connection_methods.json | 2 +- .../scripts/generate_connection_stubs.py | 12 +- .../generate_connection_wrapper_stubs.py | 4 +- 4 files changed, 105 insertions(+), 101 deletions(-) diff --git a/tools/pythonpkg/duckdb-stubs/__init__.pyi b/tools/pythonpkg/duckdb-stubs/__init__.pyi index f0630284282e..ff708ef930d5 100644 --- a/tools/pythonpkg/duckdb-stubs/__init__.pyi +++ b/tools/pythonpkg/duckdb-stubs/__init__.pyi @@ -269,7 +269,7 @@ class DuckDBPyConnection: def unregister_filesystem(self, name: str) -> None: ... def list_filesystems(self) -> list: ... def filesystem_is_registered(self, name: str) -> bool: ... - def create_function(self, name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, **kwargs) -> DuckDBPyConnection: ... + def create_function(self, name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, *, type: Optional[PythonUDFType] = PythonUDFType.NATIVE, null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, side_effects: bool = False) -> DuckDBPyConnection: ... def remove_function(self, name: str) -> DuckDBPyConnection: ... def sqltype(self, type_str: str) -> DuckDBPyType: ... def dtype(self, type_str: str) -> DuckDBPyType: ... @@ -292,10 +292,10 @@ class DuckDBPyConnection: def fetchmany(self, size: int = 1) -> List[Any]: ... def fetchall(self) -> List[Any]: ... def fetchnumpy(self) -> dict: ... - def fetchdf(self, **kwargs) -> pandas.DataFrame: ... - def fetch_df(self, **kwargs) -> pandas.DataFrame: ... - def df(self, **kwargs) -> pandas.DataFrame: ... - def fetch_df_chunk(self, vectors_per_chunk: int = 1, **kwargs) -> pandas.DataFrame: ... + def fetchdf(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... + def fetch_df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... + def df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... + def fetch_df_chunk(self, vectors_per_chunk: int = 1, *, date_as_object: bool = False) -> pandas.DataFrame: ... def pl(self, rows_per_batch: int = 1000000) -> polars.DataFrame: ... def fetch_arrow_table(self, rows_per_batch: int = 1000000) -> pyarrow.lib.Table: ... def arrow(self, rows_per_batch: int = 1000000) -> pyarrow.lib.Table: ... @@ -305,30 +305,30 @@ class DuckDBPyConnection: def begin(self) -> DuckDBPyConnection: ... def commit(self) -> DuckDBPyConnection: ... def rollback(self) -> DuckDBPyConnection: ... - def append(self, table_name: str, df: pandas.DataFrame, **kwargs) -> DuckDBPyConnection: ... + def append(self, table_name: str, df: pandas.DataFrame, *, by_name: bool = False) -> DuckDBPyConnection: ... def register(self, view_name: str, python_object: object) -> DuckDBPyConnection: ... 
def unregister(self, view_name: str) -> DuckDBPyConnection: ... def table(self, table_name: str) -> DuckDBPyRelation: ... def view(self, view_name: str) -> DuckDBPyRelation: ... def values(self, values: List[Any]) -> DuckDBPyRelation: ... def table_function(self, name: str, parameters: object = None) -> DuckDBPyRelation: ... - def read_json(self, name: str, **kwargs) -> DuckDBPyRelation: ... + def read_json(self, name: str, *, columns: Optional[Dict[str,str]] = None, sample_size: Optional[int] = None, maximum_depth: Optional[int] = None, records: Optional[str] = None, format: Optional[str] = None) -> DuckDBPyRelation: ... def extract_statements(self, query: str) -> List[Statement]: ... - def sql(self, query: str, **kwargs) -> DuckDBPyRelation: ... - def query(self, query: str, **kwargs) -> DuckDBPyRelation: ... - def from_query(self, query: str, **kwargs) -> DuckDBPyRelation: ... - def read_csv(self, path_or_buffer: Union[str, StringIO, TextIOBase], **kwargs) -> DuckDBPyRelation: ... - def from_csv_auto(self, path_or_buffer: Union[str, StringIO, TextIOBase], **kwargs) -> DuckDBPyRelation: ... + def sql(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def from_query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... + def read_csv(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, filename: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None) -> DuckDBPyRelation: ... + def from_csv_auto(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, filename: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None) -> DuckDBPyRelation: ... def from_df(self, df: pandas.DataFrame) -> DuckDBPyRelation: ... def from_arrow(self, arrow_object: object) -> DuckDBPyRelation: ... - def from_parquet(self, file_glob: str, binary_as_string: bool = False, **kwargs) -> DuckDBPyRelation: ... - def read_parquet(self, file_glob: str, binary_as_string: bool = False, **kwargs) -> DuckDBPyRelation: ... 
+ def from_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... + def read_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... def from_substrait(self, proto: str) -> DuckDBPyRelation: ... - def get_substrait(self, query: str, **kwargs) -> str: ... - def get_substrait_json(self, query: str, **kwargs) -> str: ... + def get_substrait(self, query: str, *, enable_optimizer: bool = True) -> str: ... + def get_substrait_json(self, query: str, *, enable_optimizer: bool = True) -> str: ... def from_substrait_json(self, json: str) -> DuckDBPyRelation: ... def get_table_names(self, query: str) -> List[str]: ... - def install_extension(self, extension: str, **kwargs) -> None: ... + def install_extension(self, extension: str, *, force_install: bool = False) -> None: ... def load_extension(self, extension: str) -> None: ... # END OF CONNECTION METHODS @@ -585,81 +585,81 @@ def tokenize(query: str) -> List[Any]: ... # Do not edit this section manually, your changes will be overwritten! # START OF CONNECTION WRAPPER -def cursor(**kwargs) -> DuckDBPyConnection: ... -def register_filesystem(filesystem: str, **kwargs) -> None: ... -def unregister_filesystem(name: str, **kwargs) -> None: ... -def list_filesystems(**kwargs) -> list: ... -def filesystem_is_registered(name: str, **kwargs) -> bool: ... -def create_function(name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, **kwargs) -> DuckDBPyConnection: ... -def remove_function(name: str, **kwargs) -> DuckDBPyConnection: ... -def sqltype(type_str: str, **kwargs) -> DuckDBPyType: ... -def dtype(type_str: str, **kwargs) -> DuckDBPyType: ... -def type(type_str: str, **kwargs) -> DuckDBPyType: ... -def array_type(type: DuckDBPyType, size: int, **kwargs) -> DuckDBPyType: ... -def list_type(type: DuckDBPyType, **kwargs) -> DuckDBPyType: ... -def union_type(members: DuckDBPyType, **kwargs) -> DuckDBPyType: ... -def string_type(collation: str = "", **kwargs) -> DuckDBPyType: ... -def enum_type(name: str, type: DuckDBPyType, values: List[Any], **kwargs) -> DuckDBPyType: ... -def decimal_type(width: int, scale: int, **kwargs) -> DuckDBPyType: ... -def struct_type(fields: Union[Dict[str, DuckDBPyType], List[str]], **kwargs) -> DuckDBPyType: ... -def row_type(fields: Union[Dict[str, DuckDBPyType], List[str]], **kwargs) -> DuckDBPyType: ... -def map_type(key: DuckDBPyType, value: DuckDBPyType, **kwargs) -> DuckDBPyType: ... -def duplicate(**kwargs) -> DuckDBPyConnection: ... -def execute(query: object, parameters: object = None, multiple_parameter_sets: bool = False, **kwargs) -> DuckDBPyConnection: ... -def executemany(query: object, parameters: object = None, **kwargs) -> DuckDBPyConnection: ... -def close(**kwargs) -> None: ... -def interrupt(**kwargs) -> None: ... -def fetchone(**kwargs) -> Optional[tuple]: ... -def fetchmany(size: int = 1, **kwargs) -> List[Any]: ... -def fetchall(**kwargs) -> List[Any]: ... -def fetchnumpy(**kwargs) -> dict: ... -def fetchdf(**kwargs) -> pandas.DataFrame: ... -def fetch_df(**kwargs) -> pandas.DataFrame: ... -def df(**kwargs) -> pandas.DataFrame: ... 
-def fetch_df_chunk(vectors_per_chunk: int = 1, **kwargs) -> pandas.DataFrame: ... -def pl(rows_per_batch: int = 1000000, **kwargs) -> polars.DataFrame: ... -def fetch_arrow_table(rows_per_batch: int = 1000000, **kwargs) -> pyarrow.lib.Table: ... -def arrow(rows_per_batch: int = 1000000, **kwargs) -> pyarrow.lib.Table: ... -def fetch_record_batch(rows_per_batch: int = 1000000, **kwargs) -> pyarrow.lib.RecordBatchReader: ... -def torch(**kwargs) -> dict: ... -def tf(**kwargs) -> dict: ... -def begin(**kwargs) -> DuckDBPyConnection: ... -def commit(**kwargs) -> DuckDBPyConnection: ... -def rollback(**kwargs) -> DuckDBPyConnection: ... -def append(table_name: str, df: pandas.DataFrame, **kwargs) -> DuckDBPyConnection: ... -def register(view_name: str, python_object: object, **kwargs) -> DuckDBPyConnection: ... -def unregister(view_name: str, **kwargs) -> DuckDBPyConnection: ... -def table(table_name: str, **kwargs) -> DuckDBPyRelation: ... -def view(view_name: str, **kwargs) -> DuckDBPyRelation: ... -def values(values: List[Any], **kwargs) -> DuckDBPyRelation: ... -def table_function(name: str, parameters: object = None, **kwargs) -> DuckDBPyRelation: ... -def read_json(name: str, **kwargs) -> DuckDBPyRelation: ... -def extract_statements(query: str, **kwargs) -> List[Statement]: ... -def sql(query: str, **kwargs) -> DuckDBPyRelation: ... -def query(query: str, **kwargs) -> DuckDBPyRelation: ... -def from_query(query: str, **kwargs) -> DuckDBPyRelation: ... -def read_csv(path_or_buffer: Union[str, StringIO, TextIOBase], **kwargs) -> DuckDBPyRelation: ... -def from_csv_auto(path_or_buffer: Union[str, StringIO, TextIOBase], **kwargs) -> DuckDBPyRelation: ... -def from_df(df: pandas.DataFrame, **kwargs) -> DuckDBPyRelation: ... -def from_arrow(arrow_object: object, **kwargs) -> DuckDBPyRelation: ... -def from_parquet(file_glob: str, binary_as_string: bool = False, **kwargs) -> DuckDBPyRelation: ... -def read_parquet(file_glob: str, binary_as_string: bool = False, **kwargs) -> DuckDBPyRelation: ... -def from_substrait(proto: str, **kwargs) -> DuckDBPyRelation: ... -def get_substrait(query: str, **kwargs) -> str: ... -def get_substrait_json(query: str, **kwargs) -> str: ... -def from_substrait_json(json: str, **kwargs) -> DuckDBPyRelation: ... -def get_table_names(query: str, **kwargs) -> List[str]: ... -def install_extension(extension: str, **kwargs) -> None: ... -def load_extension(extension: str, **kwargs) -> None: ... -def project(df: pandas.DataFrame, project_expr: str, **kwargs) -> DuckDBPyRelation: ... -def distinct(df: pandas.DataFrame, **kwargs) -> DuckDBPyRelation: ... -def write_csv(df: pandas.DataFrame, *args: Any, **kwargs) -> None: ... -def aggregate(df: pandas.DataFrame, aggr_expr: str, group_expr: str = "", **kwargs) -> DuckDBPyRelation: ... -def alias(df: pandas.DataFrame, alias: str, **kwargs) -> DuckDBPyRelation: ... -def filter(df: pandas.DataFrame, filter_expr: str, **kwargs) -> DuckDBPyRelation: ... -def limit(df: pandas.DataFrame, n: int, offset: int = 0, **kwargs) -> DuckDBPyRelation: ... -def order(df: pandas.DataFrame, order_expr: str, **kwargs) -> DuckDBPyRelation: ... -def query_df(df: pandas.DataFrame, virtual_table_name: str, sql_query: str, **kwargs) -> DuckDBPyRelation: ... -def description(**kwargs) -> Optional[List[Any]]: ... -def rowcount(**kwargs) -> int: ... +def cursor(*, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def register_filesystem(filesystem: str, *, connection: DuckDBPyConnection) -> None: ... 
+def unregister_filesystem(name: str, *, connection: DuckDBPyConnection) -> None: ... +def list_filesystems(*, connection: DuckDBPyConnection) -> list: ... +def filesystem_is_registered(name: str, *, connection: DuckDBPyConnection) -> bool: ... +def create_function(name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, *, type: Optional[PythonUDFType] = PythonUDFType.NATIVE, null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, side_effects: bool = False, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def remove_function(name: str, *, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def sqltype(type_str: str, *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def dtype(type_str: str, *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def type(type_str: str, *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def array_type(type: DuckDBPyType, size: int, *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def list_type(type: DuckDBPyType, *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def union_type(members: DuckDBPyType, *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def string_type(collation: str = "", *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def enum_type(name: str, type: DuckDBPyType, values: List[Any], *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def decimal_type(width: int, scale: int, *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def struct_type(fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def row_type(fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def map_type(key: DuckDBPyType, value: DuckDBPyType, *, connection: DuckDBPyConnection) -> DuckDBPyType: ... +def duplicate(*, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def execute(query: object, parameters: object = None, multiple_parameter_sets: bool = False, *, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def executemany(query: object, parameters: object = None, *, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def close(*, connection: DuckDBPyConnection) -> None: ... +def interrupt(*, connection: DuckDBPyConnection) -> None: ... +def fetchone(*, connection: DuckDBPyConnection) -> Optional[tuple]: ... +def fetchmany(size: int = 1, *, connection: DuckDBPyConnection) -> List[Any]: ... +def fetchall(*, connection: DuckDBPyConnection) -> List[Any]: ... +def fetchnumpy(*, connection: DuckDBPyConnection) -> dict: ... +def fetchdf(*, date_as_object: bool = False, connection: DuckDBPyConnection) -> pandas.DataFrame: ... +def fetch_df(*, date_as_object: bool = False, connection: DuckDBPyConnection) -> pandas.DataFrame: ... +def df(*, date_as_object: bool = False, connection: DuckDBPyConnection) -> pandas.DataFrame: ... +def fetch_df_chunk(vectors_per_chunk: int = 1, *, date_as_object: bool = False, connection: DuckDBPyConnection) -> pandas.DataFrame: ... +def pl(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection) -> polars.DataFrame: ... +def fetch_arrow_table(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection) -> pyarrow.lib.Table: ... +def arrow(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection) -> pyarrow.lib.Table: ... 
+def fetch_record_batch(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection) -> pyarrow.lib.RecordBatchReader: ... +def torch(*, connection: DuckDBPyConnection) -> dict: ... +def tf(*, connection: DuckDBPyConnection) -> dict: ... +def begin(*, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def commit(*, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def rollback(*, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def append(table_name: str, df: pandas.DataFrame, *, by_name: bool = False, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def register(view_name: str, python_object: object, *, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def unregister(view_name: str, *, connection: DuckDBPyConnection) -> DuckDBPyConnection: ... +def table(table_name: str, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def view(view_name: str, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def values(values: List[Any], *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def table_function(name: str, parameters: object = None, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def read_json(name: str, *, columns: Optional[Dict[str,str]] = None, sample_size: Optional[int] = None, maximum_depth: Optional[int] = None, records: Optional[str] = None, format: Optional[str] = None, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def extract_statements(query: str, *, connection: DuckDBPyConnection) -> List[Statement]: ... +def sql(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def query(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def from_query(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def read_csv(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, filename: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def from_csv_auto(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, filename: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... 
+def from_df(df: pandas.DataFrame, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def from_arrow(arrow_object: object, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def from_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def read_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def from_substrait(proto: str, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def get_substrait(query: str, *, enable_optimizer: bool = True, connection: DuckDBPyConnection) -> str: ... +def get_substrait_json(query: str, *, enable_optimizer: bool = True, connection: DuckDBPyConnection) -> str: ... +def from_substrait_json(json: str, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def get_table_names(query: str, *, connection: DuckDBPyConnection) -> List[str]: ... +def install_extension(extension: str, *, force_install: bool = False, connection: DuckDBPyConnection) -> None: ... +def load_extension(extension: str, *, connection: DuckDBPyConnection) -> None: ... +def project(df: pandas.DataFrame, project_expr: str, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def distinct(df: pandas.DataFrame, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def write_csv(df: pandas.DataFrame, *args: Any, connection: DuckDBPyConnection) -> None: ... +def aggregate(df: pandas.DataFrame, aggr_expr: str, group_expr: str = "", *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def alias(df: pandas.DataFrame, alias: str, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def filter(df: pandas.DataFrame, filter_expr: str, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def limit(df: pandas.DataFrame, n: int, offset: int = 0, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def order(df: pandas.DataFrame, order_expr: str, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def query_df(df: pandas.DataFrame, virtual_table_name: str, sql_query: str, *, connection: DuckDBPyConnection) -> DuckDBPyRelation: ... +def description(*, connection: DuckDBPyConnection) -> Optional[List[Any]]: ... +def rowcount(*, connection: DuckDBPyConnection) -> int: ... 
# END OF CONNECTION WRAPPER diff --git a/tools/pythonpkg/scripts/connection_methods.json b/tools/pythonpkg/scripts/connection_methods.json index 3184459c7615..06d51903ccc3 100644 --- a/tools/pythonpkg/scripts/connection_methods.json +++ b/tools/pythonpkg/scripts/connection_methods.json @@ -782,7 +782,7 @@ { "name": "compression", "default": "None", - "type": "str" + "type": "Optional[str]" } ], "return": "DuckDBPyRelation" diff --git a/tools/pythonpkg/scripts/generate_connection_stubs.py b/tools/pythonpkg/scripts/generate_connection_stubs.py index 7343ec3ec889..563ade3daefe 100644 --- a/tools/pythonpkg/scripts/generate_connection_stubs.py +++ b/tools/pythonpkg/scripts/generate_connection_stubs.py @@ -52,13 +52,15 @@ def create_arguments(arguments) -> list: return result def create_definition(name, method) -> str: - definition = f"def {name}(self" + definition = f"def {name}(" + arguments = ['self'] if 'args' in method: - definition += ", " - arguments = create_arguments(method['args']) - definition += ', '.join(arguments) + arguments.extend(create_arguments(method['args'])) if 'kwargs' in method: - definition += ", **kwargs" + if not any(x.startswith('*') for x in arguments): + arguments.append("*") + arguments.extend(create_arguments(method['kwargs'])) + definition += ", ".join(arguments) definition += ")" definition += f" -> {method['return']}: ..." return definition diff --git a/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py b/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py index 7665a91906fd..51d94db3dc2a 100644 --- a/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py +++ b/tools/pythonpkg/scripts/generate_connection_wrapper_stubs.py @@ -74,7 +74,9 @@ def create_definition(name, method) -> str: if 'args' in method: arguments.extend(create_arguments(method['args'])) if 'kwargs' in method: - arguments.append("**kwargs") + if not any(x.startswith('*') for x in arguments): + arguments.append("*") + arguments.extend(create_arguments(method['kwargs'])) definition += ', '.join(arguments) definition += ")" definition += f" -> {method['return']}: ..." 
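One detail worth flagging before the next patch: when the available-disk-space probe fails there, the fallback returns DConstants::INVALID_INDEX - 1 rather than INVALID_INDEX itself. A sketch of the reasoning, under the assumption (implied by the surrounding code, not verified here) that INVALID_INDEX is idx_t(-1) and doubles as the "unset" marker:

#include <cstdint>

using idx_t = uint64_t;

// Assumption: INVALID_INDEX marks "setting not provided", so the largest
// value strictly below it is the natural stand-in for "no cap on the
// temp-directory size".
static constexpr idx_t INVALID_INDEX = static_cast<idx_t>(-1);
static constexpr idx_t EFFECTIVELY_UNLIMITED = INVALID_INDEX - 1;

static_assert(EFFECTIVELY_UNLIMITED < INVALID_INDEX,
              "stays distinguishable from the 'unset' sentinel");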
From c93ab90358f7ded653549391f327ecac41e9fe44 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 12 Apr 2024 15:24:49 +0200
Subject: [PATCH 111/201] handle failing GetAvailableDiskSpace

---
 src/storage/temporary_file_manager.cpp | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/storage/temporary_file_manager.cpp b/src/storage/temporary_file_manager.cpp
index d90641bf4ad7..ffb5a75a8644 100644
--- a/src/storage/temporary_file_manager.cpp
+++ b/src/storage/temporary_file_manager.cpp
@@ -256,7 +256,11 @@ static idx_t GetDefaultMax(const string &path) {
 auto disk_space = FileSystem::GetAvailableDiskSpace(path);
 // Use the available disk space
 // We have made sure that the file exists before we call this, it shouldn't fail
- D_ASSERT(disk_space.IsValid());
+ if (!disk_space.IsValid()) {
+ // But if it does (i.e. because the system call is not implemented)
+ // we don't cap the available swap space
+ return DConstants::INVALID_INDEX - 1;
+ }
 // Only use 90% of the available disk space
 return static_cast<idx_t>(static_cast<double>(disk_space.GetIndex()) * 0.9);
 }

From 3ccc9a15cc1e0a12e61abe4fd64c8e9c41332e74 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hannes=20M=C3=BChleisen?=
Date: Mon, 15 Apr 2024 10:00:43 +0200
Subject: [PATCH 112/201] mc

---
 src/main/extension/extension_load.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/main/extension/extension_load.cpp b/src/main/extension/extension_load.cpp
index 002cf4aefe15..58a1b698ab46 100644
--- a/src/main/extension/extension_load.cpp
+++ b/src/main/extension/extension_load.cpp
@@ -71,7 +71,7 @@ static string PrettyPrintString(const string &s) {
 c == '.') {
 res += c;
 } else {
- uint8_t value = c;
+ auto value = UnsafeNumericCast<uint8_t>(c);
 res += "\\x";
 uint8_t first = value / 16;
 if (first < 10) {
@@ -276,7 +276,7 @@ bool ExtensionHelper::TryInitialLoad(DBConfig &config, FileSystem &fs, const str
 }
 }

- auto number_metadata_fields = 3;
+ idx_t number_metadata_fields = 3;
 D_ASSERT(number_metadata_fields == 3); // Currently hardcoded value
 metadata_field.resize(number_metadata_fields + 1);

From 7a35c390003464f3d24635af21b8210b15ed5d26 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hannes=20M=C3=BChleisen?=
Date: Mon, 15 Apr 2024 10:15:44 +0200
Subject: [PATCH 113/201] fixing platform binary

---
 src/include/duckdb/common/platform.h | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/include/duckdb/common/platform.h b/src/include/duckdb/common/platform.h
index 32babf9c80eb..3166ff9684e2 100644
--- a/src/include/duckdb/common/platform.h
+++ b/src/include/duckdb/common/platform.h
@@ -1,5 +1,10 @@
 #include <string>
-#include "duckdb/common/string_util.hpp"
+
+// duplicated from string_util.h to avoid linking issues
+#ifndef DUCKDB_QUOTE_DEFINE
+#define DUCKDB_QUOTE_DEFINE_IMPL(x) #x
+#define DUCKDB_QUOTE_DEFINE(x) DUCKDB_QUOTE_DEFINE_IMPL(x)
+#endif

 namespace duckdb {
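The DUCKDB_QUOTE_DEFINE pair duplicated above is the standard two-level stringification idiom: the extra level of indirection forces the preprocessor to macro-expand the argument before # turns it into a string literal. A quick illustration (the DUCKDB_PLATFORM value below is a made-up example, not taken from the patch):

#define DUCKDB_QUOTE_DEFINE_IMPL(x) #x
#define DUCKDB_QUOTE_DEFINE(x) DUCKDB_QUOTE_DEFINE_IMPL(x)

#define DUCKDB_PLATFORM linux_amd64
// DUCKDB_QUOTE_DEFINE_IMPL(DUCKDB_PLATFORM) -> "DUCKDB_PLATFORM" (not expanded)
// DUCKDB_QUOTE_DEFINE(DUCKDB_PLATFORM)      -> "linux_amd64"     (expanded first)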
From f26de480f561f089f0955c34453afaa930070d33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hannes=20M=C3=BChleisen?=
Date: Mon, 15 Apr 2024 10:19:40 +0200
Subject: [PATCH 114/201] so much nicer

---
 src/planner/operator/logical_top_n.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/planner/operator/logical_top_n.cpp b/src/planner/operator/logical_top_n.cpp
index adf4019e84b0..e60adc0eaf84 100644
--- a/src/planner/operator/logical_top_n.cpp
+++ b/src/planner/operator/logical_top_n.cpp
@@ -4,8 +4,8 @@ namespace duckdb {

 idx_t LogicalTopN::EstimateCardinality(ClientContext &context) {
 auto child_cardinality = LogicalOperator::EstimateCardinality(context);
- if (limit >= 0 && child_cardinality < idx_t(limit)) {
- return NumericCast<idx_t>(limit);
+ if (child_cardinality < limit) {
+ return limit;
 }
 return child_cardinality;
 }

From 35534478c304644d643390714a5d6566f4dee4ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hannes=20M=C3=BChleisen?=
Date: Mon, 15 Apr 2024 10:33:30 +0200
Subject: [PATCH 115/201] mc2

---
 src/function/table/arrow_conversion.cpp | 2 +-
 src/optimizer/topn_optimizer.cpp | 4 ++--
 src/planner/expression_binder/order_binder.cpp | 3 ++-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/function/table/arrow_conversion.cpp b/src/function/table/arrow_conversion.cpp
index c82b6a64ad4d..78d4cca859f6 100644
--- a/src/function/table/arrow_conversion.cpp
+++ b/src/function/table/arrow_conversion.cpp
@@ -985,7 +985,7 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca
 for (idx_t row_idx = 0; row_idx < size; row_idx++) {
 auto tag = NumericCast<uint8_t>(type_ids[row_idx]);

- auto out_of_range = tag < 0 || tag >= array.n_children;
+ auto out_of_range = tag >= array.n_children;
 if (out_of_range) {
 throw InvalidInputException("Arrow union tag out of range: %d", tag);
 }
diff --git a/src/optimizer/topn_optimizer.cpp b/src/optimizer/topn_optimizer.cpp
index 5e83f0a05f79..7c227cc70d5a 100644
--- a/src/optimizer/topn_optimizer.cpp
+++ b/src/optimizer/topn_optimizer.cpp
@@ -29,10 +29,10 @@ unique_ptr<LogicalOperator> TopN::Optimize(unique_ptr<LogicalOperator> op) {
 if (CanOptimize(*op)) {
 auto &limit = op->Cast<LogicalLimit>();
 auto &order_by = (op->children[0])->Cast<LogicalOrder>();
- auto limit_val = NumericCast<idx_t>(limit.limit_val.GetConstantValue());
+ auto limit_val = limit.limit_val.GetConstantValue();
 idx_t offset_val = 0;
 if (limit.offset_val.Type() == LimitNodeType::CONSTANT_VALUE) {
- offset_val = NumericCast<idx_t>(limit.offset_val.GetConstantValue());
+ offset_val = limit.offset_val.GetConstantValue();
 }
 auto topn = make_uniq<LogicalTopN>(std::move(order_by.orders), limit_val, offset_val);
 topn->AddChild(std::move(order_by.children[0]));
diff --git a/src/planner/expression_binder/order_binder.cpp b/src/planner/expression_binder/order_binder.cpp
index c5a63de8ce84..2924a3a50170 100644
--- a/src/planner/expression_binder/order_binder.cpp
+++ b/src/planner/expression_binder/order_binder.cpp
@@ -108,7 +108,8 @@ unique_ptr<Expression> OrderBinder::Bind(unique_ptr<ParsedExpression> expr) {
 auto &collation = expr->Cast<CollateExpression>();
 if (collation.child->expression_class == ExpressionClass::CONSTANT) {
 auto &constant = collation.child->Cast<ConstantExpression>();
- auto index = NumericCast<idx_t>(constant.value.GetValue<int64_t>()) - 1;
+ D_ASSERT(constant.value.GetValue<idx_t>() > 0);
+ auto index = constant.value.GetValue<idx_t>() - 1;
 child_list_t<Value> values;
 values.push_back(make_pair("index", Value::UBIGINT(index)));
 values.push_back(make_pair("collation", Value(std::move(collation.collation))));

From 161b646796618a54c562096046ec1a53b67a2482 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hannes=20M=C3=BChleisen?=
Date: Mon, 15 Apr 2024 10:50:18 +0200
Subject: [PATCH 116/201] arrow tests

---
 test/api/capi/test_capi_arrow.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/api/capi/test_capi_arrow.cpp b/test/api/capi/test_capi_arrow.cpp
index a8bdd2c9ae33..34ac028d56f5 100644
--- a/test/api/capi/test_capi_arrow.cpp
+++ b/test/api/capi/test_capi_arrow.cpp
@@ -237,7 +237,7 @@ TEST_CASE("Test arrow in C API", "[capi][arrow]") {
 auto data_chunk = &data_chunks[i];
 data_chunk->Initialize(allocator, logical_types, STANDARD_VECTOR_SIZE);
 data_chunk->SetCardinality(STANDARD_VECTOR_SIZE);
- for (int row = 0; row < STANDARD_VECTOR_SIZE; row++) {
+ for (idx_t row = 0; row < STANDARD_VECTOR_SIZE; row++) {
 data_chunk->SetValue(0, row, duckdb::Value(i));
 }
 appender.Append(*data_chunk, 0, data_chunk->size(), data_chunk->size());
From 84a3e4a16f3c86139c0e650f8310da07ed11054a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hannes=20M=C3=BChleisen?=
Date: Mon, 15 Apr 2024 10:57:27 +0200
Subject: [PATCH 117/201] mc test runner

---
 test/sqlite/sqllogic_test_runner.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/sqlite/sqllogic_test_runner.cpp b/test/sqlite/sqllogic_test_runner.cpp
index 6c7f32d38859..628a1fc4922f 100644
--- a/test/sqlite/sqllogic_test_runner.cpp
+++ b/test/sqlite/sqllogic_test_runner.cpp
@@ -293,7 +293,7 @@ RequireResult SQLLogicTestRunner::CheckRequire(SQLLogicParser &parser, const vec
 parser.Fail("require vector_size requires a parameter");
 }
 // require a specific vector size
- auto required_vector_size = std::stoi(params[1]);
+ auto required_vector_size = NumericCast<idx_t>(std::stoi(params[1]));
 if (STANDARD_VECTOR_SIZE < required_vector_size) {
 // vector size is too low for this test: skip it
 return RequireResult::MISSING;
@@ -306,7 +306,7 @@ RequireResult SQLLogicTestRunner::CheckRequire(SQLLogicParser &parser, const vec
 parser.Fail("require exact_vector_size requires a parameter");
 }
 // require an exact vector size
- auto required_vector_size = std::stoi(params[1]);
+ auto required_vector_size = NumericCast<idx_t>(std::stoi(params[1]));
 if (STANDARD_VECTOR_SIZE != required_vector_size) {
 // vector size does not match the required vector size: skip it
 return RequireResult::MISSING;
@@ -319,7 +319,7 @@ RequireResult SQLLogicTestRunner::CheckRequire(SQLLogicParser &parser, const vec
 parser.Fail("require block_size requires a parameter");
 }
 // require a specific block size
- auto required_block_size = std::stoi(params[1]);
+ auto required_block_size = NumericCast<idx_t>(std::stoi(params[1]));
 if (Storage::BLOCK_ALLOC_SIZE != required_block_size) {
 // block size does not match the required block size: skip it
 return RequireResult::MISSING;

From 6dce9deff00d8ae88c168e143d2df2a72b236e28 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hannes=20M=C3=BChleisen?=
Date: Mon, 15 Apr 2024 14:13:34 +0200
Subject: [PATCH 118/201] fighting with FormatSigned

---
 src/include/duckdb/common/numeric_utils.hpp | 3 +++
 src/include/duckdb/common/types/cast_helpers.hpp | 14 ++++++--------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/include/duckdb/common/numeric_utils.hpp b/src/include/duckdb/common/numeric_utils.hpp
index 5b111c896fee..17ca6be0fe27 100644
--- a/src/include/duckdb/common/numeric_utils.hpp
+++ b/src/include/duckdb/common/numeric_utils.hpp
@@ -68,6 +68,9 @@ static void ThrowNumericCastError(FROM in, TO minval, TO maxval) {

 template <class TO, class FROM>
 TO NumericCast(FROM val) {
+ if (std::is_same<TO, FROM>::value) {
+ return static_cast<TO>(val);
+ }
 // some dance around signed-unsigned integer comparison below
 auto minval = NumericLimits<TO>::Minimum();
 auto maxval = NumericLimits<TO>::Maximum();
diff --git a/src/include/duckdb/common/types/cast_helpers.hpp b/src/include/duckdb/common/types/cast_helpers.hpp
index 7dd91f3a5f8a..fb312ace0d64 100644
--- a/src/include/duckdb/common/types/cast_helpers.hpp
+++ b/src/include/duckdb/common/types/cast_helpers.hpp
@@ -61,17 +61,15 @@ class NumericHelper {

 template <class T>
 static string_t FormatSigned(T value, Vector &vector) {
- auto is_negative = (value < 0);
- auto unsigned_value = static_cast<typename MakeUnsigned<T>::type>(AbsValue(value));
- auto length = UnsignedLength(unsigned_value);
- if (is_negative) {
- length++;
- }
- auto result = StringVector::EmptyString(vector, UnsafeNumericCast<idx_t>(length));
+ typedef typename MakeUnsigned<T>::type UNSIGNED;
+ int8_t sign = -(value < 0);
+ UNSIGNED unsigned_value = UNSIGNED(value ^ T(sign)) + UNSIGNED(AbsValue(sign));
+ auto length = UnsafeNumericCast<idx_t>(UnsignedLength(unsigned_value) + AbsValue(sign));
+ string_t result = StringVector::EmptyString(vector, length);
 auto dataptr = result.GetDataWriteable();
 auto endptr = dataptr + length;
 endptr = FormatUnsigned(unsigned_value, endptr);
- if (is_negative) {
+ if (sign) {
 *--endptr = '-';
 }
 result.Finalize();
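For readers puzzling over the rewrite above: sign is 0 for non-negative values and -1 (all ones in two's complement) otherwise, so value ^ T(sign) is a conditional bitwise NOT and adding AbsValue(sign) completes the negation in the unsigned domain. That computes the absolute value without a branch and without the undefined behaviour of negating the minimum value. Worked through for int8_t as an illustration:

// value = -128 (the int8_t minimum), T = int8_t, UNSIGNED = uint8_t:
//   sign = -(value < 0)              -> -1 (bit pattern 0xFF)
//   value ^ T(sign)                  -> ~value = 127
//   UNSIGNED(127) + UNSIGNED(1)      -> 128, i.e. |value| without overflow
//   length = UnsignedLength(128) + 1 -> 3 digits + 1 for '-' = 4 ("-128")
// For value >= 0, sign is 0, so both the XOR and the addition are no-ops.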
From 9f5ec8d07cebe3839423cb6f416fabcafd0edfd9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hannes=20M=C3=BChleisen?=
Date: Mon, 15 Apr 2024 14:48:44 +0200
Subject: [PATCH 119/201] duplicated file

---
 src/optimizer/CMakeLists.txt | 10 +-
 src/optimizer/move_constants.cpp | 165 ------------------------------
 src/optimizer/rule/CMakeLists.txt | 4 +-
 3 files changed, 7 insertions(+), 172 deletions(-)
 delete mode 100644 src/optimizer/move_constants.cpp

diff --git a/src/optimizer/CMakeLists.txt b/src/optimizer/CMakeLists.txt
index 53f57e1c168d..bcdcb90d13f7 100644
--- a/src/optimizer/CMakeLists.txt
+++ b/src/optimizer/CMakeLists.txt
@@ -10,24 +10,24 @@ add_library_unity(
 duckdb_optimizer
 OBJECT
 column_binding_replacer.cpp
+ column_lifetime_analyzer.cpp
 common_aggregate_optimizer.cpp
 compressed_materialization.cpp
 cse_optimizer.cpp
 deliminator.cpp
- unnest_rewriter.cpp
- column_lifetime_analyzer.cpp
 expression_heuristics.cpp
+ expression_rewriter.cpp
 filter_combiner.cpp
- filter_pushdown.cpp
 filter_pullup.cpp
+ filter_pushdown.cpp
 in_clause_rewriter.cpp
 optimizer.cpp
- expression_rewriter.cpp
 regex_range_filter.cpp
 remove_duplicate_groups.cpp
 remove_unused_columns.cpp
 statistics_propagator.cpp
- topn_optimizer.cpp)
+ topn_optimizer.cpp
+ unnest_rewriter.cpp)

 set(ALL_OBJECT_FILES
 ${ALL_OBJECT_FILES} $<TARGET_OBJECTS:duckdb_optimizer>
 PARENT_SCOPE)
diff --git a/src/optimizer/move_constants.cpp b/src/optimizer/move_constants.cpp
deleted file mode 100644
index 636265ff9131..000000000000
--- a/src/optimizer/move_constants.cpp
+++ /dev/null
@@ -1,165 +0,0 @@
-#include "duckdb/optimizer/rule/move_constants.hpp"
-
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/value_operations/value_operations.hpp"
-#include "duckdb/planner/expression/bound_comparison_expression.hpp"
-#include "duckdb/planner/expression/bound_constant_expression.hpp"
-#include "duckdb/planner/expression/bound_function_expression.hpp"
-#include "duckdb/optimizer/expression_rewriter.hpp"
-
-namespace duckdb {
-
-MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewriter) {
- auto op = make_uniq<ComparisonExpressionMatcher>();
- op->matchers.push_back(make_uniq<ConstantExpressionMatcher>());
- op->policy = SetMatcher::Policy::UNORDERED;
-
- auto arithmetic = make_uniq<FunctionExpressionMatcher>();
- // we handle multiplication, addition and subtraction because those are "easy"
- // integer division makes the division case difficult
- // e.g. [x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules
[x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules - arithmetic->function = make_uniq(unordered_set {"+", "-", "*"}); - // we match only on integral numeric types - arithmetic->type = make_uniq(); - auto child_constant_matcher = make_uniq(); - auto child_expression_matcher = make_uniq(); - child_constant_matcher->type = make_uniq(); - child_expression_matcher->type = make_uniq(); - arithmetic->matchers.push_back(std::move(child_constant_matcher)); - arithmetic->matchers.push_back(std::move(child_expression_matcher)); - arithmetic->policy = SetMatcher::Policy::SOME; - op->matchers.push_back(std::move(arithmetic)); - root = std::move(op); -} - -unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector> &bindings, - bool &changes_made, bool is_root) { - auto &comparison = bindings[0].get().Cast(); - auto &outer_constant = bindings[1].get().Cast(); - auto &arithmetic = bindings[2].get().Cast(); - auto &inner_constant = bindings[3].get().Cast(); - D_ASSERT(arithmetic.return_type.IsIntegral()); - D_ASSERT(arithmetic.children[0]->return_type.IsIntegral()); - if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) { - return make_uniq(Value(comparison.return_type)); - } - auto &constant_type = outer_constant.return_type; - hugeint_t outer_value = IntegralValue::Get(outer_constant.value); - hugeint_t inner_value = IntegralValue::Get(inner_constant.value); - - idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 1 : 0; - auto &op_type = arithmetic.function.name; - if (op_type == "+") { - // [x + 1 COMP 10] OR [1 + x COMP 10] - // order does not matter in addition: - // simply change right side to 10-1 (outer_constant - inner_constant) - if (!Hugeint::TrySubtractInPlace(outer_value, inner_value)) { - return nullptr; - } - auto result_value = Value::HUGEINT(outer_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - if (comparison.type != ExpressionType::COMPARE_EQUAL) { - return nullptr; - } - // if the cast is not possible then the comparison is not possible - // for example, if we have x + 5 = 3, where x is an unsigned number, we will get x = -2 - // since this is not possible we can remove the entire branch here - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - } else if (op_type == "-") { - // [x - 1 COMP 10] O R [1 - x COMP 10] - // order matters in subtraction: - if (arithmetic_child_index == 0) { - // [x - 1 COMP 10] - // change right side to 10+1 (outer_constant + inner_constant) - if (!Hugeint::TryAddInPlace(outer_value, inner_value)) { - return nullptr; - } - auto result_value = Value::HUGEINT(outer_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - // if the cast is not possible then an equality comparison is not possible - if (comparison.type != ExpressionType::COMPARE_EQUAL) { - return nullptr; - } - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - } else { - // [1 - x COMP 10] - // change right side to 1-10=-9 - if (!Hugeint::TrySubtractInPlace(inner_value, outer_value)) { - return nullptr; - } - auto result_value = Value::HUGEINT(inner_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - // if the cast is not possible then an equality comparison is not possible - if (comparison.type != 
ExpressionType::COMPARE_EQUAL) { - return nullptr; - } - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - // in this case, we should also flip the comparison - // e.g. if we have [4 - x < 2] then we should have [x > 2] - comparison.type = FlipComparisonExpression(comparison.type); - } - } else { - D_ASSERT(op_type == "*"); - // [x * 2 COMP 10] OR [2 * x COMP 10] - // order does not matter in multiplication: - // change right side to 10/2 (outer_constant / inner_constant) - // but ONLY if outer_constant is cleanly divisible by the inner_constant - if (inner_value == 0) { - // x * 0, the result is either 0 or NULL - // we let the arithmetic_simplification rule take care of simplifying this first - return nullptr; - } - // check out of range for HUGEINT or not cleanly divisible - // HUGEINT is not cleanly divisible when outer_value == minimum and inner value == -1. (modulo overflow) - if ((outer_value == NumericLimits::Minimum() && inner_value == -1) || - outer_value % inner_value != 0) { - bool is_equality = comparison.type == ExpressionType::COMPARE_EQUAL; - bool is_inequality = comparison.type == ExpressionType::COMPARE_NOTEQUAL; - if (is_equality || is_inequality) { - // we know the values are not equal - // the result will be either FALSE or NULL (if COMPARE_EQUAL) - // or TRUE or NULL (if COMPARE_NOTEQUAL) - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(is_inequality)); - } else { - // not cleanly divisible and we are doing > >= < <=, skip the simplification for now - return nullptr; - } - } - if (inner_value < 0) { - // multiply by negative value, need to flip expression - comparison.type = FlipComparisonExpression(comparison.type); - } - // else divide the RHS by the LHS - // we need to do a range check on the cast even though we do a division - // because e.g. 
-128 / -1 = 128, which is out of range - auto result_value = Value::HUGEINT(outer_value / inner_value); - if (!result_value.DefaultTryCastAs(constant_type)) { - return ExpressionRewriter::ConstantOrNull(std::move(arithmetic.children[arithmetic_child_index]), - Value::BOOLEAN(false)); - } - outer_constant.value = std::move(result_value); - } - // replace left side with x - // first extract x from the arithmetic expression - auto arithmetic_child = std::move(arithmetic.children[arithmetic_child_index]); - // then place in the comparison - if (comparison.left.get() == &outer_constant) { - comparison.right = std::move(arithmetic_child); - } else { - comparison.left = std::move(arithmetic_child); - } - changes_made = true; - return nullptr; -} - -} // namespace duckdb diff --git a/src/optimizer/rule/CMakeLists.txt b/src/optimizer/rule/CMakeLists.txt index f6e957fcad07..1ebbb4a0708f 100644 --- a/src/optimizer/rule/CMakeLists.txt +++ b/src/optimizer/rule/CMakeLists.txt @@ -11,9 +11,9 @@ add_library_unity( empty_needle_removal.cpp enum_comparison.cpp equal_or_null_simplification.cpp - move_constants.cpp - like_optimizations.cpp in_clause_simplification_rule.cpp + like_optimizations.cpp + move_constants.cpp ordered_aggregate_optimizer.cpp regex_optimizations.cpp) set(ALL_OBJECT_FILES From d16dff09afe68ee7af8f94a433ec62b028cc4dd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Mon, 15 Apr 2024 16:14:04 +0200 Subject: [PATCH 120/201] orr array_slice --- src/core_functions/scalar/list/array_slice.cpp | 18 ++++++++---------- src/parallel/task_scheduler.cpp | 1 + 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/core_functions/scalar/list/array_slice.cpp b/src/core_functions/scalar/list/array_slice.cpp index a80bf1c7d890..c04e037b4a7c 100644 --- a/src/core_functions/scalar/list/array_slice.cpp +++ b/src/core_functions/scalar/list/array_slice.cpp @@ -47,17 +47,15 @@ static idx_t CalculateSliceLength(idx_t begin, idx_t end, INDEX_TYPE step, bool if (step == 0 && svalid) { throw InvalidInputException("Slice step cannot be zero"); } - auto step_unsigned = UnsafeNumericCast(step); // we called abs() on this above. - if (step == 1) { return NumericCast(end - begin); - } else if (step_unsigned >= (end - begin)) { + } else if (static_cast(step) >= (end - begin)) { return 1; } - if ((end - begin) % step_unsigned != 0) { - return (end - begin) / step_unsigned + 1; + if ((end - begin) % UnsafeNumericCast(step) != 0) { + return (end - begin) / UnsafeNumericCast(step) + 1; } - return (end - begin) / step_unsigned; + return (end - begin) / UnsafeNumericCast(step); } template @@ -147,14 +145,14 @@ list_entry_t SliceValueWithSteps(Vector &result, SelectionVector &sel, list_entr return input; } input.length = CalculateSliceLength(UnsafeNumericCast(begin), UnsafeNumericCast(end), step, true); - int64_t child_idx = UnsafeNumericCast(input.offset) + begin; + idx_t child_idx = input.offset + UnsafeNumericCast(begin); if (step < 0) { - child_idx = UnsafeNumericCast(input.offset) + end - 1; + child_idx = input.offset + UnsafeNumericCast(end) - 1; } input.offset = sel_idx; for (idx_t i = 0; i < input.length; i++) { - sel.set_index(sel_idx, UnsafeNumericCast(child_idx)); - child_idx = UnsafeNumericCast(child_idx) + step; + sel.set_index(sel_idx, child_idx); + child_idx += static_cast(step); // intentional overflow?? 
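+		// a note on the flagged line above: the wrap-around is intentional and
+		// well-defined. For a negative 'step', static_cast<idx_t>(step) wraps
+		// modulo 2^64, so the unsigned addition is equivalent to subtracting
+		// AbsValue(step) -- the same result the previous signed arithmetic gave.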
sel_idx++; } return input; diff --git a/src/parallel/task_scheduler.cpp b/src/parallel/task_scheduler.cpp index b0f8997ee82a..ea7dc0c6d963 100644 --- a/src/parallel/task_scheduler.cpp +++ b/src/parallel/task_scheduler.cpp @@ -2,6 +2,7 @@ #include "duckdb/common/chrono.hpp" #include "duckdb/common/exception.hpp" +#include "duckdb/common/numeric_utils.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/database.hpp" From 5168f7047516b4711e25fb6bd9ea22daf0d819c8 Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 15 Apr 2024 16:32:39 +0200 Subject: [PATCH 121/201] clang tidy --- src/include/duckdb/common/shared_ptr.ipp | 17 +++++++++-------- src/include/duckdb/common/weak_ptr.ipp | 6 ++++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index 840f101588d4..6704724c21b1 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -32,7 +32,7 @@ private: friend class shared_ptr; template - friend shared_ptr shared_ptr_cast(shared_ptr src); + friend shared_ptr shared_ptr_cast(shared_ptr src); // NOLINT: invalid case style private: original internal; @@ -41,7 +41,7 @@ public: // Constructors shared_ptr() : internal() { } - shared_ptr(std::nullptr_t) : internal(nullptr) { + shared_ptr(std::nullptr_t) : internal(nullptr) { // NOLINT: not marked as explicit } // From raw pointer of type U convertible to T @@ -95,7 +95,7 @@ public: typename std::enable_if::value && std::is_convertible::pointer, T *>::value, int>::type = 0> - shared_ptr(unique_ptr &&other) : internal(std::move(other)) { + shared_ptr(unique_ptr &&other) : internal(std::move(other)) { // NOLINT: not marked as explicit __enable_weak_this(internal.get(), internal.get()); } @@ -161,15 +161,15 @@ public: internal.reset(ptr, deleter); } - void swap(shared_ptr &r) noexcept { + void swap(shared_ptr &r) noexcept { // NOLINT: invalid case style internal.swap(r.internal); } - T *get() const { + T *get() const { // NOLINT: invalid case style return internal.get(); } - long use_count() const { + long use_count() const { // NOLINT: invalid case style return internal.use_count(); } @@ -236,7 +236,8 @@ private: template *>::value, int>::type = 0> - void __enable_weak_this(const enable_shared_from_this *object, _OrigPtr *ptr) noexcept { + void __enable_weak_this(const enable_shared_from_this *object, + _OrigPtr *ptr) noexcept { // NOLINT: invalid case style typedef typename std::remove_cv::type NonConstU; if (object && object->__weak_this_.expired()) { // __weak_this__ is the mutable variable returned by 'shared_from_this' @@ -245,7 +246,7 @@ private: } } - void __enable_weak_this(...) noexcept { + void __enable_weak_this(...) 
noexcept { // NOLINT: invalid case style } }; diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index fff31e251e04..aef42e1f9e7e 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -18,23 +18,25 @@ public: weak_ptr() : internal() { } + // NOLINTBEGIN template weak_ptr(shared_ptr const &ptr, typename std::enable_if::value, int>::type = 0) noexcept : internal(ptr.internal) { } - weak_ptr(weak_ptr const &other) noexcept : internal(other.internal) { // NOLINT: not marked as explicit + weak_ptr(weak_ptr const &other) noexcept : internal(other.internal) { } template weak_ptr(weak_ptr const &ptr, typename std::enable_if::value, int>::type = 0) noexcept : internal(ptr.internal) { } - weak_ptr(weak_ptr &&ptr) noexcept : internal(std::move(ptr.internal)) { // NOLINT: not marked as explicit + weak_ptr(weak_ptr &&ptr) noexcept : internal(std::move(ptr.internal)) { } template weak_ptr(weak_ptr &&ptr, typename std::enable_if::value, int>::type = 0) noexcept : internal(std::move(ptr.internal)) { } + // NOLINTEND // Destructor ~weak_ptr() = default; From 4f110e8ee38e6ed68f9a215672ce69842bd18b36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Mon, 15 Apr 2024 17:01:44 +0200 Subject: [PATCH 122/201] missing header for windows --- src/parallel/task_scheduler.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/parallel/task_scheduler.cpp b/src/parallel/task_scheduler.cpp index ea7dc0c6d963..d327f983213d 100644 --- a/src/parallel/task_scheduler.cpp +++ b/src/parallel/task_scheduler.cpp @@ -16,6 +16,8 @@ #include #endif +#include // ssize_t + namespace duckdb { struct SchedulerThread { From 0ebc97f53df465f1a3f679012b263cecee5fc429 Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 15 Apr 2024 17:33:33 +0200 Subject: [PATCH 123/201] tidy --- src/include/duckdb/common/shared_ptr.ipp | 4 ++-- src/include/duckdb/common/weak_ptr.ipp | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index 6704724c21b1..b6e8ab8b40e9 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -144,19 +144,19 @@ public: reset() { // NOLINT: invalid case style internal.reset(); } + template #ifdef DUCKDB_CLANG_TIDY // This is necessary to tell clang-tidy that it reinitializes the variable after a move [[clang::reinitializes]] #endif - template void reset(U *ptr) { // NOLINT: invalid case style internal.reset(ptr); } + template #ifdef DUCKDB_CLANG_TIDY // This is necessary to tell clang-tidy that it reinitializes the variable after a move [[clang::reinitializes]] #endif - template void reset(U *ptr, DELETER deleter) { // NOLINT: invalid case style internal.reset(ptr, deleter); } diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index aef42e1f9e7e..2f1b9c1b506f 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -53,6 +53,10 @@ public: } // Modifiers +#ifdef DUCKDB_CLANG_TIDY + // This is necessary to tell clang-tidy that it reinitializes the variable after a move + [[clang::reinitializes]] +#endif void reset() { // NOLINT: invalid case style internal.reset(); } From dece8f16c8d17f3a021aae73bc3507e2a5107cda Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 15 Apr 2024 20:19:55 +0200 Subject: [PATCH 124/201] format --- src/include/duckdb/common/shared_ptr.ipp | 6 ++++-- 
src/include/duckdb/common/weak_ptr.ipp | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index b6e8ab8b40e9..b7a76e395ed3 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -149,7 +149,8 @@ public: // This is necessary to tell clang-tidy that it reinitializes the variable after a move [[clang::reinitializes]] #endif - void reset(U *ptr) { // NOLINT: invalid case style + void + reset(U *ptr) { // NOLINT: invalid case style internal.reset(ptr); } template @@ -157,7 +158,8 @@ public: // This is necessary to tell clang-tidy that it reinitializes the variable after a move [[clang::reinitializes]] #endif - void reset(U *ptr, DELETER deleter) { // NOLINT: invalid case style + void + reset(U *ptr, DELETER deleter) { // NOLINT: invalid case style internal.reset(ptr, deleter); } diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index 2f1b9c1b506f..84e0d747d25f 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -57,7 +57,8 @@ public: // This is necessary to tell clang-tidy that it reinitializes the variable after a move [[clang::reinitializes]] #endif - void reset() { // NOLINT: invalid case style + void + reset() { // NOLINT: invalid case style internal.reset(); } From 6a76c352523801bd9b07846ac43b8c39a87829a6 Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 15 Apr 2024 21:57:52 +0200 Subject: [PATCH 125/201] tidy checks --- src/include/duckdb/common/shared_ptr.ipp | 16 ++++++++++------ src/include/duckdb/common/weak_ptr.ipp | 3 +++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index b7a76e395ed3..e9e080aacbcc 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -104,6 +104,9 @@ public: // Assign from shared_ptr copy shared_ptr &operator=(const shared_ptr &other) noexcept { + if (this == &other) { + return *this; + } // Create a new shared_ptr using the copy constructor, then swap out the ownership to *this shared_ptr(other).swap(*this); return *this; @@ -235,16 +238,17 @@ public: private: // This overload is used when the class inherits from 'enable_shared_from_this' - template *>::value, + template *>::value, int>::type = 0> - void __enable_weak_this(const enable_shared_from_this *object, - _OrigPtr *ptr) noexcept { // NOLINT: invalid case style - typedef typename std::remove_cv::type NonConstU; + void __enable_weak_this(const enable_shared_from_this *object, // NOLINT: invalid case style + V *ptr) noexcept { + typedef typename std::remove_cv::type non_const_u_t; if (object && object->__weak_this_.expired()) { // __weak_this__ is the mutable variable returned by 'shared_from_this' // it is initialized here - object->__weak_this_ = shared_ptr(*this, const_cast(static_cast(ptr))); + auto non_const = const_cast(static_cast(ptr)); // NOLINT: const cast + object->__weak_this_ = shared_ptr(*this, non_const); } } diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index 84e0d747d25f..40688ded2ea5 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -42,6 +42,9 @@ public: // Assignment operators weak_ptr &operator=(const weak_ptr &other) { + if (this == &other) { + return *this; + } internal = other.internal; return *this; } 
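A note on the two clang-tidy patches above: '[[clang::reinitializes]]' is consumed by clang-tidy's bugprone-use-after-move check. Marking a method with it declares that the call puts the object back into a known state, so later uses of a moved-from value are no longer reported. Below is a minimal standalone sketch of the pattern -- illustrative only, not DuckDB code; the DUCKDB_CLANG_TIDY guard mirrors the patches above.

#include <memory>
#include <utility>

struct Wrapper {
	std::shared_ptr<int> internal;

#ifdef DUCKDB_CLANG_TIDY
	// tells bugprone-use-after-move that reset() revalidates the object
	[[clang::reinitializes]]
#endif
	void reset() {
		internal.reset();
	}
};

int main() {
	Wrapper a;
	a.internal = std::make_shared<int>(1);
	Wrapper b = std::move(a); // 'a' is in a moved-from state here
	(void)b;
	a.reset();                // reinitializes 'a'; the use below is not flagged
	return a.internal ? 1 : 0;
}
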
From 1d6e334b76b5f63d94f816a880ff97c6a3183bcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 16 Apr 2024 09:04:08 +0200 Subject: [PATCH 126/201] ci fixes --- src/include/duckdb/common/vector.hpp | 2 +- src/parallel/task_scheduler.cpp | 2 +- src/storage/compression/bitpacking.cpp | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/include/duckdb/common/vector.hpp b/src/include/duckdb/common/vector.hpp index 95f76fe4d94e..c767b76bab3b 100644 --- a/src/include/duckdb/common/vector.hpp +++ b/src/include/duckdb/common/vector.hpp @@ -101,7 +101,7 @@ class vector : public std::vector> { // NOL return get(original::size() - 1); } - void erase_at(idx_t idx) { + void erase_at(idx_t idx) { // NOLINT: not using camelcase on purpose here if (MemorySafety::ENABLED && idx > original::size()) { throw InternalException("Can't remove offset %d from vector of size %d", idx, original::size()); } diff --git a/src/parallel/task_scheduler.cpp b/src/parallel/task_scheduler.cpp index d327f983213d..db2dd6087d5f 100644 --- a/src/parallel/task_scheduler.cpp +++ b/src/parallel/task_scheduler.cpp @@ -16,7 +16,7 @@ #include #endif -#include // ssize_t +#include // ssize_t namespace duckdb { diff --git a/src/storage/compression/bitpacking.cpp b/src/storage/compression/bitpacking.cpp index 03472f4cd5d3..53605cd540a8 100644 --- a/src/storage/compression/bitpacking.cpp +++ b/src/storage/compression/bitpacking.cpp @@ -808,6 +808,7 @@ void BitpackingScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t T multiplier; auto success = TryCast::Operation(scan_state.current_group_offset + i, multiplier); D_ASSERT(success); + (void)success; target_ptr[i] = (multiplier * scan_state.current_constant) + scan_state.current_frame_of_reference; } From a376b015ead517d89e58ef9afabeef9c26bdfac0 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 16 Apr 2024 09:57:59 +0200 Subject: [PATCH 127/201] add more 'reinitializes' attributes to silence erroneous 'use after move' clang tidy errors --- src/include/duckdb/common/shared_ptr.ipp | 12 +++++++++--- src/include/duckdb/common/weak_ptr.ipp | 6 ++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index e9e080aacbcc..460ae98f9f3b 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -73,8 +73,14 @@ public: } // Move constructor, share ownership with ref template ::value, int>::type = 0> +#ifdef DUCKDB_CLANG_TIDY + [[clang::reinitializes]] +#endif shared_ptr(shared_ptr &&ref) noexcept : internal(std::move(ref.internal)) { // NOLINT: not marked as explicit } +#ifdef DUCKDB_CLANG_TIDY + [[clang::reinitializes]] +#endif shared_ptr(shared_ptr &&other) : internal(std::move(other.internal)) { // NOLINT: not marked as explicit } @@ -95,6 +101,9 @@ public: typename std::enable_if::value && std::is_convertible::pointer, T *>::value, int>::type = 0> +#ifdef DUCKDB_CLANG_TIDY + [[clang::reinitializes]] +#endif shared_ptr(unique_ptr &&other) : internal(std::move(other)) { // NOLINT: not marked as explicit __enable_weak_this(internal.get(), internal.get()); } @@ -140,7 +149,6 @@ public: } #ifdef DUCKDB_CLANG_TIDY - // This is necessary to tell clang-tidy that it reinitializes the variable after a move [[clang::reinitializes]] #endif void @@ -149,7 +157,6 @@ public: } template #ifdef DUCKDB_CLANG_TIDY - // This is necessary to tell clang-tidy that it reinitializes the variable after a move 
[[clang::reinitializes]] #endif void @@ -158,7 +165,6 @@ public: } template #ifdef DUCKDB_CLANG_TIDY - // This is necessary to tell clang-tidy that it reinitializes the variable after a move [[clang::reinitializes]] #endif void diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index 40688ded2ea5..a714eb0e67b0 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -30,9 +30,15 @@ public: weak_ptr(weak_ptr const &ptr, typename std::enable_if::value, int>::type = 0) noexcept : internal(ptr.internal) { } +#ifdef DUCKDB_CLANG_TIDY + [[clang::reinitializes]] +#endif weak_ptr(weak_ptr &&ptr) noexcept : internal(std::move(ptr.internal)) { } template +#ifdef DUCKDB_CLANG_TIDY + [[clang::reinitializes]] +#endif weak_ptr(weak_ptr &&ptr, typename std::enable_if::value, int>::type = 0) noexcept : internal(std::move(ptr.internal)) { } From a04e5f34da9477097b5cd1595f7b6c54acfde4d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 16 Apr 2024 10:06:15 +0200 Subject: [PATCH 128/201] moar fixes for ci --- src/include/duckdb/common/types/cast_helpers.hpp | 6 +++--- src/parallel/task_scheduler.cpp | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/include/duckdb/common/types/cast_helpers.hpp b/src/include/duckdb/common/types/cast_helpers.hpp index fb312ace0d64..369ae2df1533 100644 --- a/src/include/duckdb/common/types/cast_helpers.hpp +++ b/src/include/duckdb/common/types/cast_helpers.hpp @@ -61,10 +61,10 @@ class NumericHelper { template static string_t FormatSigned(T value, Vector &vector) { - typedef typename MakeUnsigned::type UNSIGNED; + typedef typename MakeUnsigned::type unsigned_t; int8_t sign = -(value < 0); - UNSIGNED unsigned_value = UNSIGNED(value ^ T(sign)) + UNSIGNED(AbsValue(sign)); - auto length = UnsafeNumericCast(UnsignedLength(unsigned_value) + AbsValue(sign)); + unsigned_t unsigned_value = UNSIGNED(value ^ T(sign)) + unsigned_t(AbsValue(sign)); + auto length = UnsafeNumericCast(UnsignedLength(unsigned_value) + AbsValue(sign)); string_t result = StringVector::EmptyString(vector, length); auto dataptr = result.GetDataWriteable(); auto endptr = dataptr + length; diff --git a/src/parallel/task_scheduler.cpp b/src/parallel/task_scheduler.cpp index db2dd6087d5f..137cf4c2ea73 100644 --- a/src/parallel/task_scheduler.cpp +++ b/src/parallel/task_scheduler.cpp @@ -16,8 +16,6 @@ #include #endif -#include // ssize_t - namespace duckdb { struct SchedulerThread { @@ -263,6 +261,7 @@ void TaskScheduler::SetAllocatorFlushTreshold(idx_t threshold) { void TaskScheduler::Signal(idx_t n) { #ifndef DUCKDB_NO_THREADS + typedef std::make_signed::type ssize_t; queue->semaphore.signal(NumericCast(n)); #endif } From 446386b9a6382cacc5154a125ef4633c5001b598 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 16 Apr 2024 10:07:42 +0200 Subject: [PATCH 129/201] format --- src/include/duckdb/common/shared_ptr.ipp | 9 ++++++--- src/include/duckdb/common/weak_ptr.ipp | 18 ++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.ipp b/src/include/duckdb/common/shared_ptr.ipp index 460ae98f9f3b..d046dc1412f8 100644 --- a/src/include/duckdb/common/shared_ptr.ipp +++ b/src/include/duckdb/common/shared_ptr.ipp @@ -76,12 +76,14 @@ public: #ifdef DUCKDB_CLANG_TIDY [[clang::reinitializes]] #endif - shared_ptr(shared_ptr &&ref) noexcept : internal(std::move(ref.internal)) { // NOLINT: not marked as explicit + shared_ptr(shared_ptr 
&&ref) noexcept // NOLINT: not marked as explicit + : internal(std::move(ref.internal)) { } #ifdef DUCKDB_CLANG_TIDY [[clang::reinitializes]] #endif - shared_ptr(shared_ptr &&other) : internal(std::move(other.internal)) { // NOLINT: not marked as explicit + shared_ptr(shared_ptr &&other) // NOLINT: not marked as explicit + : internal(std::move(other.internal)) { } // Construct from std::shared_ptr @@ -104,7 +106,8 @@ public: #ifdef DUCKDB_CLANG_TIDY [[clang::reinitializes]] #endif - shared_ptr(unique_ptr &&other) : internal(std::move(other)) { // NOLINT: not marked as explicit + shared_ptr(unique_ptr &&other) // NOLINT: not marked as explicit + : internal(std::move(other)) { __enable_weak_this(internal.get(), internal.get()); } diff --git a/src/include/duckdb/common/weak_ptr.ipp b/src/include/duckdb/common/weak_ptr.ipp index a714eb0e67b0..076fde953258 100644 --- a/src/include/duckdb/common/weak_ptr.ipp +++ b/src/include/duckdb/common/weak_ptr.ipp @@ -19,27 +19,25 @@ public: } // NOLINTBEGIN - template - weak_ptr(shared_ptr const &ptr, - typename std::enable_if::value, int>::type = 0) noexcept - : internal(ptr.internal) { + template ::value, int>::type = 0> + weak_ptr(shared_ptr const &ptr) noexcept : internal(ptr.internal) { } weak_ptr(weak_ptr const &other) noexcept : internal(other.internal) { } - template - weak_ptr(weak_ptr const &ptr, typename std::enable_if::value, int>::type = 0) noexcept - : internal(ptr.internal) { + template ::value, int>::type = 0> + weak_ptr(weak_ptr const &ptr) noexcept : internal(ptr.internal) { } #ifdef DUCKDB_CLANG_TIDY [[clang::reinitializes]] #endif - weak_ptr(weak_ptr &&ptr) noexcept : internal(std::move(ptr.internal)) { + weak_ptr(weak_ptr &&ptr) noexcept + : internal(std::move(ptr.internal)) { } - template + template ::value, int>::type = 0> #ifdef DUCKDB_CLANG_TIDY [[clang::reinitializes]] #endif - weak_ptr(weak_ptr &&ptr, typename std::enable_if::value, int>::type = 0) noexcept + weak_ptr(weak_ptr &&ptr) noexcept : internal(std::move(ptr.internal)) { } // NOLINTEND From 10db9ca14093c68b8c1ea8c895a67f8058ad9d57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 16 Apr 2024 10:10:48 +0200 Subject: [PATCH 130/201] orr --- src/include/duckdb/common/types/cast_helpers.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/duckdb/common/types/cast_helpers.hpp b/src/include/duckdb/common/types/cast_helpers.hpp index 369ae2df1533..2ff2da7594b2 100644 --- a/src/include/duckdb/common/types/cast_helpers.hpp +++ b/src/include/duckdb/common/types/cast_helpers.hpp @@ -63,7 +63,7 @@ class NumericHelper { static string_t FormatSigned(T value, Vector &vector) { typedef typename MakeUnsigned::type unsigned_t; int8_t sign = -(value < 0); - unsigned_t unsigned_value = UNSIGNED(value ^ T(sign)) + unsigned_t(AbsValue(sign)); + unsigned_t unsigned_value = unsigned_t(value ^ T(sign)) + unsigned_t(AbsValue(sign)); auto length = UnsafeNumericCast(UnsignedLength(unsigned_value) + AbsValue(sign)); string_t result = StringVector::EmptyString(vector, length); auto dataptr = result.GetDataWriteable(); From 5d1aa4833580dc60945bf0aa6b5bf2a18a9fbe3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 16 Apr 2024 11:49:37 +0200 Subject: [PATCH 131/201] moar ci 42 --- src/common/box_renderer.cpp | 2 +- src/parallel/task_scheduler.cpp | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/common/box_renderer.cpp b/src/common/box_renderer.cpp index 629f6349fe69..684e47ce3ad3 100644 --- 
a/src/common/box_renderer.cpp +++ b/src/common/box_renderer.cpp @@ -399,7 +399,7 @@ vector BoxRenderer::ComputeRenderWidths(const vector &names, cons // e.g. if we have 10 columns, we remove #5, then #4, then #6, then #3, then #7, etc int64_t offset = 0; while (total_length > max_width) { - idx_t c = column_count / 2 + NumericCast(offset); + auto c = NumericCast(NumericCast(column_count) / 2 + offset); total_length -= widths[c] + 3; pruned_columns.insert(c); if (offset >= 0) { diff --git a/src/parallel/task_scheduler.cpp b/src/parallel/task_scheduler.cpp index 137cf4c2ea73..d11b144bf3e1 100644 --- a/src/parallel/task_scheduler.cpp +++ b/src/parallel/task_scheduler.cpp @@ -109,7 +109,11 @@ TaskScheduler::TaskScheduler(DatabaseInstance &db) TaskScheduler::~TaskScheduler() { #ifndef DUCKDB_NO_THREADS - RelaunchThreadsInternal(0); + try { + RelaunchThreadsInternal(0); + } catch (...) { + // nothing we can do in the destructor if this fails + } #endif } From e70a2aa5fce4c9a55282e17033624ba8bb059aeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Tue, 16 Apr 2024 15:27:29 +0200 Subject: [PATCH 132/201] fixing merge conflict, thanks @pdet --- src/execution/operator/csv_scanner/util/csv_reader_options.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/execution/operator/csv_scanner/util/csv_reader_options.cpp b/src/execution/operator/csv_scanner/util/csv_reader_options.cpp index 6fa228033e44..e1beb5ef7d21 100644 --- a/src/execution/operator/csv_scanner/util/csv_reader_options.cpp +++ b/src/execution/operator/csv_scanner/util/csv_reader_options.cpp @@ -180,8 +180,6 @@ void CSVReaderOptions::SetReadOption(const string &loption, const Value &value, SetSkipRows(ParseInteger(value, loption)); } else if (loption == "max_line_size" || loption == "maximum_line_size") { maximum_line_size = NumericCast(ParseInteger(value, loption)); - } else if (loption == "force_not_null") { - force_not_null = ParseColumnList(value, expected_names, loption); } else if (loption == "date_format" || loption == "dateformat") { string format = ParseString(value, loption); SetDateFormat(LogicalTypeId::DATE, format, true); From 44a0e68ce4d6110e930591c356e03b6c7f48b827 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 16 Apr 2024 20:53:08 +0200 Subject: [PATCH 133/201] use ExpressionExecutor, enables overriding behavior through the Catalog (such as ICU) --- src/function/table/copy_csv.cpp | 82 +++++++++++++++++-- .../duckdb/function/table/read_csv.hpp | 2 + 2 files changed, 76 insertions(+), 8 deletions(-) diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp index 8d1d50b3558a..bf8ff78d1d2f 100644 --- a/src/function/table/copy_csv.cpp +++ b/src/function/table/copy_csv.cpp @@ -110,6 +110,65 @@ static unique_ptr WriteCSVBind(ClientContext &context, CopyFunctio } bind_data->Finalize(); + auto &options = csv_data->options; + auto &formats = options.write_date_format; + + bool has_dateformat = !formats[LogicalTypeId::DATE].Empty(); + bool has_timestampformat = !formats[LogicalTypeId::TIMESTAMP].Empty(); + + // Create a binder + auto binder = Binder::CreateBinder(context); + + // Create a Binding, used by the ExpressionBinder to turn our columns into BoundReferenceExpressions + auto &bind_context = binder->bind_context; + auto table_index = binder->GenerateTableIndex(); + bind_context.AddGenericBinding(table_index, "copy_csv", names, sql_types); + + // Create the ParsedExpressions (cast, strftime, etc..) 
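+	// Three cases are built below: DATE columns with a user-supplied dateformat
+	// become strftime(column, format); TIMESTAMP/TIMESTAMP_TZ columns with a
+	// timestampformat likewise; every other column becomes CAST(column AS VARCHAR).
+	// Binding these as ordinary expressions is what lets the catalog override the
+	// implementation (e.g. ICU's TIMESTAMP_TZ-aware strftime).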
+ vector> unbound_expressions; + for (idx_t i = 0; i < sql_types.size(); i++) { + auto &type = sql_types[i]; + auto &name = names[i]; + + bool is_timestamp = type.id() == LogicalTypeId::TIMESTAMP || type.id() == LogicalTypeId::TIMESTAMP_TZ; + if (has_dateformat && type.id() == LogicalTypeId::DATE) { + // strftime(, 'format') + vector> children; + children.push_back(make_uniq(name)); + // TODO: set from user-provided format + children.push_back(make_uniq("%m/%d/%Y, %-I:%-M %p")); + auto func = make_uniq("strftime", std::move(children)); + unbound_expressions.push_back(std::move(expr)); + } else if (has_timestampformat && is_timestamp) { + // strftime(, 'format') + vector> children; + children.push_back(make_uniq(name)); + // TODO: set from user-provided format + children.push_back(make_uniq("%Y-%m-%dT%H:%M:%S.%fZ")); + auto func = make_uniq("strftime", std::move(children)); + unbound_expressions.push_back(std::move(expr)); + } else { + // CAST AS VARCHAR + auto column = make_uniq(name); + auto expr = make_uniq(LogicalType::VARCHAR, std::move(column)); + unbound_expressions.push_back(std::move(expr)); + } + } + + // Create an ExpressionBinder, bind the Expressions + vector> expressions; + ExpressionBinder expression_binder(*binder, context); + expression_binder.target_type = LogicalType::VARCHAR; + for (auto &expr : unbound_expressions) { + expressions.push_back(expression_binder.Bind(expr)); + } + + bind_data->cast_expressions = std::move(expressions); + + // Move these into the WriteCSVData + // In 'WriteCSVInitializeLocal' we'll create an ExpressionExecutor, fed our expressions + // In 'WriteCSVChunkInternal' we use this expression executor to convert our input columns to VARCHAR + bind_data->requires_quotes = make_unsafe_uniq_array(256); memset(bind_data->requires_quotes.get(), 0, sizeof(bool) * 256); bind_data->requires_quotes['\n'] = true; @@ -262,6 +321,14 @@ static void WriteQuotedString(WriteStream &writer, WriteCSVData &csv_data, const // Sink //===--------------------------------------------------------------------===// struct LocalWriteCSVData : public LocalFunctionData { +public: + LocalWriteCSVData(ClientContext &context, vector> &expressions) + : executor(context, expressions) { + } + +public: + //! Used to execute the expressions that transform input -> string + ExpressionExecutor executor; //! The thread-local buffer to write data into MemoryStream stream; //! 
A chunk with VARCHAR columns to cast intermediates into @@ -314,7 +381,7 @@ struct GlobalWriteCSVData : public GlobalFunctionData { static unique_ptr WriteCSVInitializeLocal(ExecutionContext &context, FunctionData &bind_data) { auto &csv_data = bind_data.Cast(); - auto local_data = make_uniq(); + auto local_data = make_uniq(context.client, csv_data.cast_expressions); // create the chunk with VARCHAR types vector types; @@ -362,6 +429,7 @@ static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_dat MemoryStream &writer, DataChunk &input, bool &written_anything) { auto &csv_data = bind_data.Cast(); auto &options = csv_data.options; + auto &formats = options.write_date_format; // first cast the columns of the chunk to varchar cast_chunk.Reset(); @@ -370,17 +438,15 @@ static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_dat if (csv_data.sql_types[col_idx].id() == LogicalTypeId::VARCHAR) { // VARCHAR, just reinterpret (cannot reference, because LogicalTypeId::VARCHAR is used by the JSON type too) cast_chunk.data[col_idx].Reinterpret(input.data[col_idx]); - } else if (!csv_data.options.write_date_format[LogicalTypeId::DATE].Empty() && - csv_data.sql_types[col_idx].id() == LogicalTypeId::DATE) { + } else if (!formats[LogicalTypeId::DATE].Empty() && csv_data.sql_types[col_idx].id() == LogicalTypeId::DATE) { // use the date format to cast the chunk - csv_data.options.write_date_format[LogicalTypeId::DATE].ConvertDateVector( - input.data[col_idx], cast_chunk.data[col_idx], input.size()); - } else if (!csv_data.options.write_date_format[LogicalTypeId::TIMESTAMP].Empty() && + formats[LogicalTypeId::DATE].ConvertDateVector(input.data[col_idx], cast_chunk.data[col_idx], input.size()); + } else if (!formats[LogicalTypeId::TIMESTAMP].Empty() && (csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP || csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP_TZ)) { // use the timestamp format to cast the chunk - csv_data.options.write_date_format[LogicalTypeId::TIMESTAMP].ConvertTimestampVector( - input.data[col_idx], cast_chunk.data[col_idx], input.size()); + formats[LogicalTypeId::TIMESTAMP].ConvertTimestampVector(input.data[col_idx], cast_chunk.data[col_idx], + input.size()); } else { // non varchar column, perform the cast VectorOperations::Cast(context, input.data[col_idx], cast_chunk.data[col_idx], input.size()); diff --git a/src/include/duckdb/function/table/read_csv.hpp b/src/include/duckdb/function/table/read_csv.hpp index aeb5050214fe..272fbbf68550 100644 --- a/src/include/duckdb/function/table/read_csv.hpp +++ b/src/include/duckdb/function/table/read_csv.hpp @@ -56,6 +56,8 @@ struct WriteCSVData : public BaseCSVData { idx_t flush_size = 4096ULL * 8ULL; //! For each byte whether or not the CSV file requires quotes when containing the byte unsafe_unique_array requires_quotes; + //! 
Expressions used to convert the input into strings + vector> cast_expressions; }; struct ColumnInfo { From b07a2f4b50d6ba9b3efa113f39fba1e2abdb4753 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 16 Apr 2024 21:38:58 +0200 Subject: [PATCH 134/201] reworked execution --- src/execution/column_binding_resolver.cpp | 4 + src/function/table/copy_csv.cpp | 103 +++++++++--------- .../execution/column_binding_resolver.hpp | 2 + 3 files changed, 60 insertions(+), 49 deletions(-) diff --git a/src/execution/column_binding_resolver.cpp b/src/execution/column_binding_resolver.cpp index 568381503d78..ed3b2c8b2f95 100644 --- a/src/execution/column_binding_resolver.cpp +++ b/src/execution/column_binding_resolver.cpp @@ -15,6 +15,10 @@ namespace duckdb { ColumnBindingResolver::ColumnBindingResolver(bool verify_only) : verify_only(verify_only) { } +void ColumnBindingResolver::SetBindings(vector &&bindings_p) { + bindings = std::move(bindings_p); +} + void ColumnBindingResolver::VisitOperator(LogicalOperator &op) { switch (op.type) { case LogicalOperatorType::LOGICAL_ASOF_JOIN: diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp index bf8ff78d1d2f..73adcd53c52f 100644 --- a/src/function/table/copy_csv.cpp +++ b/src/function/table/copy_csv.cpp @@ -12,6 +12,11 @@ #include "duckdb/function/scalar/string_functions.hpp" #include "duckdb/function/table/read_csv.hpp" #include "duckdb/parser/parsed_data/copy_info.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/execution/column_binding_resolver.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" #include namespace duckdb { @@ -93,24 +98,10 @@ string TransformNewLine(string new_line) { ; } -static unique_ptr WriteCSVBind(ClientContext &context, CopyFunctionBindInput &input, - const vector &names, const vector &sql_types) { - auto bind_data = make_uniq(input.info.file_path, sql_types, names); - - // check all the options in the copy info - for (auto &option : input.info.options) { - auto loption = StringUtil::Lower(option.first); - auto &set = option.second; - bind_data->options.SetWriteOption(loption, ConvertVectorToValue(set)); - } - // verify the parsed options - if (bind_data->options.force_quote.empty()) { - // no FORCE_QUOTE specified: initialize to false - bind_data->options.force_quote.resize(names.size(), false); - } - bind_data->Finalize(); - - auto &options = csv_data->options; +static vector> CreateCastExpressions(WriteCSVData &bind_data, ClientContext &context, + const vector &names, + const vector &sql_types) { + auto &options = bind_data.options; auto &formats = options.write_date_format; bool has_dateformat = !formats[LogicalTypeId::DATE].Empty(); @@ -137,20 +128,20 @@ static unique_ptr WriteCSVBind(ClientContext &context, CopyFunctio children.push_back(make_uniq(name)); // TODO: set from user-provided format children.push_back(make_uniq("%m/%d/%Y, %-I:%-M %p")); - auto func = make_uniq("strftime", std::move(children)); - unbound_expressions.push_back(std::move(expr)); + auto func = make_uniq_base("strftime", std::move(children)); + unbound_expressions.push_back(std::move(func)); } else if (has_timestampformat && is_timestamp) { // strftime(, 'format') vector> children; children.push_back(make_uniq(name)); // TODO: set from user-provided format children.push_back(make_uniq("%Y-%m-%dT%H:%M:%S.%fZ")); - auto func = make_uniq("strftime", std::move(children)); - 
unbound_expressions.push_back(std::move(expr)); + auto func = make_uniq_base("strftime", std::move(children)); + unbound_expressions.push_back(std::move(func)); } else { // CAST AS VARCHAR auto column = make_uniq(name); - auto expr = make_uniq(LogicalType::VARCHAR, std::move(column)); + auto expr = make_uniq_base(LogicalType::VARCHAR, std::move(column)); unbound_expressions.push_back(std::move(expr)); } } @@ -163,11 +154,38 @@ static unique_ptr WriteCSVBind(ClientContext &context, CopyFunctio expressions.push_back(expression_binder.Bind(expr)); } - bind_data->cast_expressions = std::move(expressions); + ColumnBindingResolver resolver; + vector bindings; + for (idx_t i = 0; i < sql_types.size(); i++) { + bindings.push_back(ColumnBinding(table_index, i)); + } + resolver.SetBindings(std::move(bindings)); - // Move these into the WriteCSVData - // In 'WriteCSVInitializeLocal' we'll create an ExpressionExecutor, fed our expressions - // In 'WriteCSVChunkInternal' we use this expression executor to convert our input columns to VARCHAR + for (auto &expr : expressions) { + resolver.VisitExpression(&expr); + } + return expressions; +} + +static unique_ptr WriteCSVBind(ClientContext &context, CopyFunctionBindInput &input, + const vector &names, const vector &sql_types) { + auto bind_data = make_uniq(input.info.file_path, sql_types, names); + + // check all the options in the copy info + for (auto &option : input.info.options) { + auto loption = StringUtil::Lower(option.first); + auto &set = option.second; + bind_data->options.SetWriteOption(loption, ConvertVectorToValue(set)); + } + // verify the parsed options + if (bind_data->options.force_quote.empty()) { + // no FORCE_QUOTE specified: initialize to false + bind_data->options.force_quote.resize(names.size(), false); + } + bind_data->Finalize(); + + auto expressions = CreateCastExpressions(*bind_data, context, names, sql_types); + bind_data->cast_expressions = std::move(expressions); bind_data->requires_quotes = make_unsafe_uniq_array(256); memset(bind_data->requires_quotes.get(), 0, sizeof(bool) * 256); @@ -426,32 +444,16 @@ idx_t WriteCSVFileSize(GlobalFunctionData &gstate) { } static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_data, DataChunk &cast_chunk, - MemoryStream &writer, DataChunk &input, bool &written_anything) { + MemoryStream &writer, DataChunk &input, bool &written_anything, + ExpressionExecutor &executor) { auto &csv_data = bind_data.Cast(); auto &options = csv_data.options; - auto &formats = options.write_date_format; // first cast the columns of the chunk to varchar cast_chunk.Reset(); cast_chunk.SetCardinality(input); - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - if (csv_data.sql_types[col_idx].id() == LogicalTypeId::VARCHAR) { - // VARCHAR, just reinterpret (cannot reference, because LogicalTypeId::VARCHAR is used by the JSON type too) - cast_chunk.data[col_idx].Reinterpret(input.data[col_idx]); - } else if (!formats[LogicalTypeId::DATE].Empty() && csv_data.sql_types[col_idx].id() == LogicalTypeId::DATE) { - // use the date format to cast the chunk - formats[LogicalTypeId::DATE].ConvertDateVector(input.data[col_idx], cast_chunk.data[col_idx], input.size()); - } else if (!formats[LogicalTypeId::TIMESTAMP].Empty() && - (csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP || - csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP_TZ)) { - // use the timestamp format to cast the chunk - formats[LogicalTypeId::TIMESTAMP].ConvertTimestampVector(input.data[col_idx], 
cast_chunk.data[col_idx], - input.size()); - } else { - // non varchar column, perform the cast - VectorOperations::Cast(context, input.data[col_idx], cast_chunk.data[col_idx], input.size()); - } - } + + executor.Execute(input, cast_chunk); cast_chunk.Flatten(); // now loop over the vectors and output the values @@ -492,7 +494,7 @@ static void WriteCSVSink(ExecutionContext &context, FunctionData &bind_data, Glo // write data into the local buffer WriteCSVChunkInternal(context.client, bind_data, local_data.cast_chunk, local_data.stream, input, - local_data.written_anything); + local_data.written_anything, local_data.executor); // check if we should flush what we have currently written auto &writer = local_data.stream; @@ -570,11 +572,14 @@ unique_ptr WriteCSVPrepareBatch(ClientContext &context, Funct DataChunk cast_chunk; cast_chunk.Initialize(Allocator::Get(context), types); + auto expressions = CreateCastExpressions(csv_data, context, csv_data.options.name_list, types); + ExpressionExecutor executor(context, expressions); + // write CSV chunks to the batch data bool written_anything = false; auto batch = make_uniq(); for (auto &chunk : collection->Chunks()) { - WriteCSVChunkInternal(context, bind_data, cast_chunk, batch->stream, chunk, written_anything); + WriteCSVChunkInternal(context, bind_data, cast_chunk, batch->stream, chunk, written_anything, executor); } return std::move(batch); } diff --git a/src/include/duckdb/execution/column_binding_resolver.hpp b/src/include/duckdb/execution/column_binding_resolver.hpp index f98aeb2cfef4..de98ea2166a4 100644 --- a/src/include/duckdb/execution/column_binding_resolver.hpp +++ b/src/include/duckdb/execution/column_binding_resolver.hpp @@ -23,6 +23,8 @@ class ColumnBindingResolver : public LogicalOperatorVisitor { void VisitOperator(LogicalOperator &op) override; static void Verify(LogicalOperator &op); + //! 
Manually set bindings + void SetBindings(vector &&bindings); protected: vector bindings; From e35b725951212b4f71f39d5dfe6bd1126a1802b4 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 16 Apr 2024 21:50:15 +0200 Subject: [PATCH 135/201] update hardcoded timestamp (placeholder) --- src/function/table/copy_csv.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp index 73adcd53c52f..32eb32231bfe 100644 --- a/src/function/table/copy_csv.cpp +++ b/src/function/table/copy_csv.cpp @@ -135,7 +135,7 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d vector> children; children.push_back(make_uniq(name)); // TODO: set from user-provided format - children.push_back(make_uniq("%Y-%m-%dT%H:%M:%S.%fZ")); + children.push_back(make_uniq("%x %X.%g%z")); auto func = make_uniq_base("strftime", std::move(children)); unbound_expressions.push_back(std::move(func)); } else { From 6c643f7f4cad3cc91799a7e4ac57c541941df34f Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 16 Apr 2024 22:32:11 +0200 Subject: [PATCH 136/201] save Values for the format, fix an issue in the PrepareBatch --- .../operator/csv_scanner/util/csv_reader_options.cpp | 2 +- src/function/table/copy_csv.cpp | 11 ++++++----- .../operator/csv_scanner/csv_reader_options.hpp | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/execution/operator/csv_scanner/util/csv_reader_options.cpp b/src/execution/operator/csv_scanner/util/csv_reader_options.cpp index c4a24ffbc054..c39d7890d2ec 100644 --- a/src/execution/operator/csv_scanner/util/csv_reader_options.cpp +++ b/src/execution/operator/csv_scanner/util/csv_reader_options.cpp @@ -148,7 +148,7 @@ void CSVReaderOptions::SetDateFormat(LogicalTypeId type, const string &format, b error = StrTimeFormat::ParseFormatSpecifier(format, strpformat); dialect_options.date_format[type].Set(strpformat); } else { - error = StrTimeFormat::ParseFormatSpecifier(format, write_date_format[type]); + write_date_format[type] = Value(format); } if (!error.empty()) { throw InvalidInputException("Could not parse DATEFORMAT: %s", error.c_str()); diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp index 32eb32231bfe..dd9567a0807f 100644 --- a/src/function/table/copy_csv.cpp +++ b/src/function/table/copy_csv.cpp @@ -104,8 +104,8 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d auto &options = bind_data.options; auto &formats = options.write_date_format; - bool has_dateformat = !formats[LogicalTypeId::DATE].Empty(); - bool has_timestampformat = !formats[LogicalTypeId::TIMESTAMP].Empty(); + bool has_dateformat = !formats[LogicalTypeId::DATE].IsNull(); + bool has_timestampformat = !formats[LogicalTypeId::TIMESTAMP].IsNull(); // Create a binder auto binder = Binder::CreateBinder(context); @@ -127,7 +127,7 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d vector> children; children.push_back(make_uniq(name)); // TODO: set from user-provided format - children.push_back(make_uniq("%m/%d/%Y, %-I:%-M %p")); + children.push_back(make_uniq(formats[LogicalTypeId::DATE])); auto func = make_uniq_base("strftime", std::move(children)); unbound_expressions.push_back(std::move(func)); } else if (has_timestampformat && is_timestamp) { @@ -135,7 +135,7 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d vector> children; children.push_back(make_uniq(name)); // TODO: set from user-provided format - children.push_back(make_uniq("%x %X.%g%z")); + 
children.push_back(make_uniq(formats[LogicalTypeId::TIMESTAMP])); auto func = make_uniq_base("strftime", std::move(children)); unbound_expressions.push_back(std::move(func)); } else { @@ -572,7 +572,8 @@ unique_ptr WriteCSVPrepareBatch(ClientContext &context, Funct DataChunk cast_chunk; cast_chunk.Initialize(Allocator::Get(context), types); - auto expressions = CreateCastExpressions(csv_data, context, csv_data.options.name_list, types); + auto &original_types = collection->Types(); + auto expressions = CreateCastExpressions(csv_data, context, csv_data.options.name_list, original_types); ExpressionExecutor executor(context, expressions); // write CSV chunks to the batch data diff --git a/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp b/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp index d929fb721aaf..65ae58f0e0e5 100644 --- a/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp +++ b/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp @@ -123,7 +123,7 @@ struct CSVReaderOptions { //! The date format to use (if any is specified) map date_format = {{LogicalTypeId::DATE, {}}, {LogicalTypeId::TIMESTAMP, {}}}; //! The date format to use for writing (if any is specified) - map write_date_format = {{LogicalTypeId::DATE, {}}, {LogicalTypeId::TIMESTAMP, {}}}; + map write_date_format = {{LogicalTypeId::DATE, Value()}, {LogicalTypeId::TIMESTAMP, Value()}}; //! Whether or not a type format is specified map has_format = {{LogicalTypeId::DATE, false}, {LogicalTypeId::TIMESTAMP, false}}; From 1b1d7f62358f90f1aa3d24c8bb12932427d18d00 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 16 Apr 2024 22:36:16 +0200 Subject: [PATCH 137/201] add test for the newly enabled functionality, using a strftime overload added by ICU --- test/sql/copy/csv/test_csv_timestamp_tz.test | 1 - .../copy/csv/test_csv_timestamp_tz_icu.test | 24 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 test/sql/copy/csv/test_csv_timestamp_tz_icu.test diff --git a/test/sql/copy/csv/test_csv_timestamp_tz.test b/test/sql/copy/csv/test_csv_timestamp_tz.test index 6bc16ad0df7e..ca6bc3e6381a 100644 --- a/test/sql/copy/csv/test_csv_timestamp_tz.test +++ b/test/sql/copy/csv/test_csv_timestamp_tz.test @@ -15,4 +15,3 @@ query II select * from read_csv_auto('__TEST_DIR__/timestamps.csv'); ---- Tuesday Tuesday - diff --git a/test/sql/copy/csv/test_csv_timestamp_tz_icu.test b/test/sql/copy/csv/test_csv_timestamp_tz_icu.test new file mode 100644 index 000000000000..f24d9f799147 --- /dev/null +++ b/test/sql/copy/csv/test_csv_timestamp_tz_icu.test @@ -0,0 +1,24 @@ +# name: test/sql/copy/csv/test_csv_timestamp_tz_icu.test +# description: Test CSV with timestamp_tz and timestampformat +# group: [csv] + +statement ok +pragma enable_verification + +require icu + +statement ok +SET Calendar = 'gregorian'; + +statement ok +SET TimeZone = 'America/Los_Angeles'; + +statement ok +COPY ( + SELECT make_timestamptz(1713193669561000) AS t +) TO '__TEST_DIR__/timestamp-format.csv' (FORMAT CSV, timestampformat '%x %X.%g%z'); + +query I +select * from read_csv('__TEST_DIR__/timestamp-format.csv', all_varchar=true) +---- +2024-04-15 08:07:49.561-07 From a7f8ca2bf812a076a1203e1765ae89e38549c244 Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Wed, 17 Apr 2024 06:08:35 +0200 Subject: [PATCH 138/201] CMake port from duckdb-wasm: Add thread setting and LINKED_LIBS option --- CMakeLists.txt | 31 ++++++++++++++++++++++--------- 1 file changed, 22 
insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f199d3518cf..058197542d96 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,6 +83,13 @@ if (EXTENSION_STATIC_BUILD AND "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") endif() option(DISABLE_UNITY "Disable unity builds." FALSE) +option(USE_WASM_THREADS "Should threads be used" FALSE) +if (${USE_WASM_THREADS}) + set(WASM_THREAD_FLAGS + -pthread + -sSHARED_MEMORY=1 + ) +endif() option(FORCE_COLORED_OUTPUT "Always produce ANSI-colored output (GNU/Clang only)." FALSE) @@ -810,10 +817,11 @@ function(build_loadable_extension_directory NAME OUTPUT_DIRECTORY EXTENSION_VERS COMMAND ${CMAKE_COMMAND} -E copy $ $.lib ) # Compile the library into the actual wasm file + string(TOUPPER ${NAME} EXTENSION_NAME_UPPERCASE) add_custom_command( TARGET ${TARGET_NAME} POST_BUILD - COMMAND emcc $.lib -o $ -sSIDE_MODULE=1 -O3 + COMMAND emcc $.lib -o $ -O3 -sSIDE_MODULE=2 -sEXPORTED_FUNCTIONS="_${NAME}_init" ${WASM_THREAD_FLAGS} ${DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_LINKED_LIBS} ) endif() add_custom_command( @@ -846,7 +854,7 @@ function(build_static_extension NAME PARAMETERS) endfunction() # Internal extension register function -function(register_extension NAME DONT_LINK DONT_BUILD LOAD_TESTS PATH INCLUDE_PATH TEST_PATH) +function(register_extension NAME DONT_LINK DONT_BUILD LOAD_TESTS PATH INCLUDE_PATH TEST_PATH LINKED_LIBS) string(TOLOWER ${NAME} EXTENSION_NAME_LOWERCASE) string(TOUPPER ${NAME} EXTENSION_NAME_UPPERCASE) @@ -868,6 +876,8 @@ function(register_extension NAME DONT_LINK DONT_BUILD LOAD_TESTS PATH INCLUDE_PA endif() endif() + set(DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_LINKED_LIBS "${LINKED_LIBS}" PARENT_SCOPE) + # Allows explicitly disabling extensions that may be specified in other configurations if (NOT ${DONT_BUILD} AND NOT ${EXTENSION_TESTS_ONLY}) set(DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_SHOULD_BUILD TRUE PARENT_SCOPE) @@ -893,7 +903,7 @@ function(register_extension NAME DONT_LINK DONT_BUILD LOAD_TESTS PATH INCLUDE_PA endfunction() # Downloads the external extension repo at the specified commit and calls register_extension -macro(register_external_extension NAME URL COMMIT DONT_LINK DONT_BUILD LOAD_TESTS PATH INCLUDE_PATH TEST_PATH APPLY_PATCHES SUBMODULES) +macro(register_external_extension NAME URL COMMIT DONT_LINK DONT_BUILD LOAD_TESTS PATH INCLUDE_PATH TEST_PATH APPLY_PATCHES LINKED_LIBS SUBMODULES) include(FetchContent) if (${APPLY_PATCHES}) set(PATCH_COMMAND python3 ${CMAKE_SOURCE_DIR}/scripts/apply_extension_patches.py ${CMAKE_SOURCE_DIR}/.github/patches/extensions/${NAME}/) @@ -921,6 +931,8 @@ macro(register_external_extension NAME URL COMMIT DONT_LINK DONT_BUILD LOAD_TEST endif() message(STATUS "Load extension '${NAME}' from ${URL} @ ${GIT_SHORT_COMMIT}") + + string(TOUPPER ${NAME} EXTENSION_NAME_UPPERCASE) set(DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_EXT_VERSION "${GIT_SHORT_COMMIT}" PARENT_SCOPE) if ("${INCLUDE_PATH}" STREQUAL "") @@ -935,13 +947,13 @@ macro(register_external_extension NAME URL COMMIT DONT_LINK DONT_BUILD LOAD_TEST set(TEST_FULL_PATH "${${NAME}_extension_fc_SOURCE_DIR}/${TEST_PATH}") endif() - register_extension(${NAME} ${DONT_LINK} ${DONT_BUILD} ${LOAD_TESTS} ${${NAME}_extension_fc_SOURCE_DIR}/${PATH} "${INCLUDE_FULL_PATH}" "${TEST_FULL_PATH}") + register_extension(${NAME} ${DONT_LINK} ${DONT_BUILD} ${LOAD_TESTS} ${${NAME}_extension_fc_SOURCE_DIR}/${PATH} "${INCLUDE_FULL_PATH}" "${TEST_FULL_PATH}" "${LINKED_LIBS}") endmacro() function(duckdb_extension_load NAME) 
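+  # Note: LINKED_LIBS is forwarded through register_extension() into
+  # DUCKDB_EXTENSION_<NAME>_LINKED_LIBS, which the Wasm branch of
+  # build_loadable_extension_directory() appends to the emcc link line.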
# Parameter parsing set(options DONT_LINK DONT_BUILD LOAD_TESTS APPLY_PATCHES) - set(oneValueArgs SOURCE_DIR INCLUDE_DIR TEST_DIR GIT_URL GIT_TAG SUBMODULES EXTENSION_VERSION) + set(oneValueArgs SOURCE_DIR INCLUDE_DIR TEST_DIR GIT_URL GIT_TAG SUBMODULES EXTENSION_VERSION LINKED_LIBS) cmake_parse_arguments(duckdb_extension_load "${options}" "${oneValueArgs}" "" ${ARGN}) string(TOLOWER ${NAME} EXTENSION_NAME_LOWERCASE) @@ -960,12 +972,12 @@ function(duckdb_extension_load NAME) # Remote Git extension if (${duckdb_extension_load_DONT_BUILD}) - register_extension(${NAME} "${duckdb_extension_load_DONT_LINK}" "${duckdb_extension_load_DONT_BUILD}" "" "" "" "") + register_extension(${NAME} "${duckdb_extension_load_DONT_LINK}" "${duckdb_extension_load_DONT_BUILD}" "" "" "" "" "") elseif (NOT "${duckdb_extension_load_GIT_URL}" STREQUAL "") if ("${duckdb_extension_load_GIT_TAG}" STREQUAL "") error("Git URL specified but no valid GIT_TAG was found for ${NAME} extension") endif() - register_external_extension(${NAME} "${duckdb_extension_load_GIT_URL}" "${duckdb_extension_load_GIT_TAG}" "${duckdb_extension_load_DONT_LINK}" "${duckdb_extension_load_DONT_BUILD}" "${duckdb_extension_load_LOAD_TESTS}" "${duckdb_extension_load_SOURCE_DIR}" "${duckdb_extension_load_INCLUDE_DIR}" "${duckdb_extension_load_TEST_DIR}" "${duckdb_extension_load_APPLY_PATCHES}" "${duckdb_extension_load_SUBMODULES}") + register_external_extension(${NAME} "${duckdb_extension_load_GIT_URL}" "${duckdb_extension_load_GIT_TAG}" "${duckdb_extension_load_DONT_LINK}" "${duckdb_extension_load_DONT_BUILD}" "${duckdb_extension_load_LOAD_TESTS}" "${duckdb_extension_load_SOURCE_DIR}" "${duckdb_extension_load_INCLUDE_DIR}" "${duckdb_extension_load_TEST_DIR}" "${duckdb_extension_load_APPLY_PATCHES}" "${duckdb_extension_load_LINKED_LIBS}" "${duckdb_extension_load_SUBMODULES}") if (NOT "${duckdb_extension_load_EXTENSION_VERSION}" STREQUAL "") set(DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_EXT_VERSION "${duckdb_extension_load_EXTENSION_VERSION}" PARENT_SCOPE) endif() @@ -1003,7 +1015,7 @@ function(duckdb_extension_load NAME) set(DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_EXT_VERSION "" PARENT_SCOPE) endif() endif() - register_extension(${NAME} "${duckdb_extension_load_DONT_LINK}" "${duckdb_extension_load_DONT_BUILD}" "${duckdb_extension_load_LOAD_TESTS}" "${duckdb_extension_load_SOURCE_DIR}" "${INCLUDE_PATH_DEFAULT}" "${TEST_PATH_DEFAULT}") + register_extension(${NAME} "${duckdb_extension_load_DONT_LINK}" "${duckdb_extension_load_DONT_BUILD}" "${duckdb_extension_load_LOAD_TESTS}" "${duckdb_extension_load_SOURCE_DIR}" "${INCLUDE_PATH_DEFAULT}" "${TEST_PATH_DEFAULT}" "${duckdb_extension_load_LINKED_LIBS}") elseif(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/extension_external/${NAME}) # Local extension, default path message(STATUS "Load extension '${NAME}' from '${CMAKE_CURRENT_SOURCE_DIR}/extension_external'") @@ -1012,7 +1024,7 @@ function(duckdb_extension_load NAME) else() # Local extension, default path message(STATUS "Load extension '${NAME}' from '${CMAKE_CURRENT_SOURCE_DIR}/extensions'") - register_extension(${NAME} ${duckdb_extension_load_DONT_LINK} "${duckdb_extension_load_DONT_BUILD}" "${duckdb_extension_load_LOAD_TESTS}" "${CMAKE_CURRENT_SOURCE_DIR}/extension/${NAME}" "${CMAKE_CURRENT_SOURCE_DIR}/extension/${NAME}/include" "${CMAKE_CURRENT_SOURCE_DIR}/extension/${NAME}/test/sql") + register_extension(${NAME} ${duckdb_extension_load_DONT_LINK} "${duckdb_extension_load_DONT_BUILD}" "${duckdb_extension_load_LOAD_TESTS}" 
"${CMAKE_CURRENT_SOURCE_DIR}/extension/${NAME}" "${CMAKE_CURRENT_SOURCE_DIR}/extension/${NAME}/include" "${CMAKE_CURRENT_SOURCE_DIR}/extension/${NAME}/test/sql" "${duckdb_extension_load_LINKED_LIBS}") set(DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_EXT_VERSION "${DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_EXT_VERSION}" PARENT_SCOPE) endif() @@ -1024,6 +1036,7 @@ function(duckdb_extension_load NAME) set(DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_PATH ${DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_PATH} PARENT_SCOPE) set(DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_INCLUDE_PATH ${DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_INCLUDE_PATH} PARENT_SCOPE) set(DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_TEST_PATH ${DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_TEST_PATH} PARENT_SCOPE) + set(DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_LINKED_LIBS ${DUCKDB_EXTENSION_${EXTENSION_NAME_UPPERCASE}_LINKED_LIBS} PARENT_SCOPE) endfunction() if(${EXPORT_DLL_SYMBOLS}) From d4d56b877027e11f29afb5277f8c7af2c1c43197 Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Wed, 17 Apr 2024 06:10:48 +0200 Subject: [PATCH 139/201] Port duckdb-wasm changes to Makefile --- Makefile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index a6fd0dcd3ea0..911beb115bff 100644 --- a/Makefile +++ b/Makefile @@ -294,17 +294,17 @@ release: ${EXTENSION_CONFIG_STEP} wasm_mvp: ${EXTENSION_CONFIG_STEP} mkdir -p ./build/wasm_mvp && \ - emcmake cmake $(GENERATOR) ${COMMON_CMAKE_VARS} -DWASM_LOADABLE_EXTENSIONS=1 -DBUILD_EXTENSIONS_ONLY=1 -Bbuild/wasm_mvp -DCMAKE_CXX_FLAGS="-DDUCKDB_CUSTOM_PLATFORM=wasm_mvp" && \ + emcmake cmake $(GENERATOR) -DWASM_LOADABLE_EXTENSIONS=1 -DBUILD_EXTENSIONS_ONLY=1 -Bbuild/wasm_mvp -DCMAKE_CXX_FLAGS="-DDUCKDB_CUSTOM_PLATFORM=wasm_mvp" -DDUCKDB_EXPLICIT_PLATFORM="wasm_mvp" ${COMMON_CMAKE_VARS} ${TOOLCHAIN_FLAGS} && \ emmake make -j8 -Cbuild/wasm_mvp wasm_eh: ${EXTENSION_CONFIG_STEP} mkdir -p ./build/wasm_eh && \ - emcmake cmake $(GENERATOR) ${COMMON_CMAKE_VARS} -DWASM_LOADABLE_EXTENSIONS=1 -DBUILD_EXTENSIONS_ONLY=1 -Bbuild/wasm_eh -DCMAKE_CXX_FLAGS="-fwasm-exceptions -DWEBDB_FAST_EXCEPTIONS=1 -DDUCKDB_CUSTOM_PLATFORM=wasm_eh" && \ + emcmake cmake $(GENERATOR) -DWASM_LOADABLE_EXTENSIONS=1 -DBUILD_EXTENSIONS_ONLY=1 -Bbuild/wasm_eh -DCMAKE_CXX_FLAGS="-fwasm-exceptions -DWEBDB_FAST_EXCEPTIONS=1 -DDUCKDB_CUSTOM_PLATFORM=wasm_eh" -DDUCKDB_EXPLICIT_PLATFORM="wasm_eh" ${COMMON_CMAKE_VARS} ${TOOLCHAIN_FLAGS} && \ emmake make -j8 -Cbuild/wasm_eh wasm_threads: ${EXTENSION_CONFIG_STEP} mkdir -p ./build/wasm_threads && \ - emcmake cmake $(GENERATOR) ${COMMON_CMAKE_VARS} -DWASM_LOADABLE_EXTENSIONS=1 -DBUILD_EXTENSIONS_ONLY=1 -Bbuild/wasm_threads -DCMAKE_CXX_FLAGS="-fwasm-exceptions -DWEBDB_FAST_EXCEPTIONS=1 -DWITH_WASM_THREADS=1 -DWITH_WASM_SIMD=1 -DWITH_WASM_BULK_MEMORY=1 -DDUCKDB_CUSTOM_PLATFORM=wasm_threads" && \ + emcmake cmake $(GENERATOR) -DWASM_LOADABLE_EXTENSIONS=1 -DBUILD_EXTENSIONS_ONLY=1 -Bbuild/wasm_threads -DCMAKE_CXX_FLAGS="-fwasm-exceptions -DWEBDB_FAST_EXCEPTIONS=1 -DWITH_WASM_THREADS=1 -DWITH_WASM_SIMD=1 -DWITH_WASM_BULK_MEMORY=1 -DDUCKDB_CUSTOM_PLATFORM=wasm_threads -pthread" -DDUCKDB_EXPLICIT_PLATFORM="wasm_threads" ${COMMON_CMAKE_VARS} -DUSE_WASM_THREADS=1 -DCMAKE_C_FLAGS="-pthread" ${TOOLCHAIN_FLAGS} && \ emmake make -j8 -Cbuild/wasm_threads cldebug: ${EXTENSION_CONFIG_STEP} From cee3b7387d73b5fa737931a1069998c9ca8cf2e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Wed, 17 Apr 2024 09:42:46 +0200 Subject: [PATCH 140/201] 
re-fixing bitpacking --- src/storage/compression/bitpacking.cpp | 4 ++-- src/transaction/duck_transaction_manager.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/storage/compression/bitpacking.cpp b/src/storage/compression/bitpacking.cpp index b41f9779ca2f..d5d06dac8f45 100644 --- a/src/storage/compression/bitpacking.cpp +++ b/src/storage/compression/bitpacking.cpp @@ -516,9 +516,9 @@ struct BitpackingCompressState : public CompressionState { // Compact the segment by moving the metadata next to the data. - idx_t unaligned_offset = data_ptr - base_ptr; + idx_t unaligned_offset = NumericCast<idx_t>(data_ptr - base_ptr); idx_t metadata_offset = AlignValue(unaligned_offset); - idx_t metadata_size = base_ptr + Storage::BLOCK_SIZE - metadata_ptr; + idx_t metadata_size = NumericCast<idx_t>(base_ptr + Storage::BLOCK_SIZE - metadata_ptr); idx_t total_segment_size = metadata_offset + metadata_size; // Asserting things are still sane here diff --git a/src/transaction/duck_transaction_manager.cpp b/src/transaction/duck_transaction_manager.cpp index a3ec6d4a3d9f..f53ab599670f 100644 --- a/src/transaction/duck_transaction_manager.cpp +++ b/src/transaction/duck_transaction_manager.cpp @@ -311,7 +311,7 @@ void DuckTransactionManager::RemoveTransaction(DuckTransaction &transaction) noe if (i > 0) { // we garbage collected transactions: remove them from the list recently_committed_transactions.erase(recently_committed_transactions.begin(), - recently_committed_transactions.begin() + UnsafeNumericCast<int64_t>(i)); + recently_committed_transactions.begin() + static_cast<int64_t>(i)); } // check if we can free the memory of any old transactions i = active_transactions.empty() ? old_transactions.size() : 0; @@ -326,7 +326,7 @@ void DuckTransactionManager::RemoveTransaction(DuckTransaction &transaction) noe } if (i > 0) { // we garbage collected transactions: remove them from the list - old_transactions.erase(old_transactions.begin(), old_transactions.begin() + UnsafeNumericCast<int64_t>(i)); + old_transactions.erase(old_transactions.begin(), old_transactions.begin() + static_cast<int64_t>(i)); } } From 8964b151fd37313401dbcdb18f8a64a15fee1762 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Wed, 17 Apr 2024 09:57:56 +0200 Subject: [PATCH 141/201] missed one in column_data --- src/storage/table/column_data.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/storage/table/column_data.cpp b/src/storage/table/column_data.cpp index 5347ec69c440..cc68b4625294 100644 --- a/src/storage/table/column_data.cpp +++ b/src/storage/table/column_data.cpp @@ -171,7 +171,7 @@ void ColumnData::FetchUpdateRow(TransactionData transaction, row_t row_id, Vecto if (!updates) { return; } - updates->FetchRow(transaction, row_id, result, result_idx); + updates->FetchRow(transaction, NumericCast<idx_t>(row_id), result, result_idx); } void ColumnData::UpdateInternal(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, From defaaa1578aa45dacc2f86db1cbca30054e329f4 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Apr 2024 10:05:04 +0200 Subject: [PATCH 142/201] add a reset to make clang-tidy happy --- src/storage/table/row_group.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/storage/table/row_group.cpp b/src/storage/table/row_group.cpp index 24c05e50463d..dba0a13ae1d5 100644 --- a/src/storage/table/row_group.cpp +++ b/src/storage/table/row_group.cpp @@ -273,6 +273,7 @@ unique_ptr<RowGroup> RowGroup::AlterType(RowGroupCollection &new_collection, con if (i == changed_idx) { // this is
the altered column: use the new column row_group->columns.push_back(std::move(column_data)); + column_data.reset(); } else { // this column was not altered: use the data directly row_group->columns.push_back(cols[i]); From 8df6f0502c572b7b3ef177dceff12d1a2486555e Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Apr 2024 10:30:20 +0200 Subject: [PATCH 143/201] add patch to azure --- .github/config/out_of_tree_extensions.cmake | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/config/out_of_tree_extensions.cmake b/.github/config/out_of_tree_extensions.cmake index 50bdda371c96..7b489f8c14ff 100644 --- a/.github/config/out_of_tree_extensions.cmake +++ b/.github/config/out_of_tree_extensions.cmake @@ -41,6 +41,7 @@ if (NOT MINGW) LOAD_TESTS GIT_URL https://github.com/duckdb/duckdb_azure GIT_TAG 09623777a366572bfb8fa53e47acdf72133a360e + APPLY_PATCHES ) endif() From adce37dd2004c0f3d96c75e91bb7da3723461e03 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Apr 2024 10:35:12 +0200 Subject: [PATCH 144/201] this test should error, no resolution for TIMESTAMP_TZ is available --- test/sql/copy/csv/test_csv_timestamp_tz.test | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/sql/copy/csv/test_csv_timestamp_tz.test b/test/sql/copy/csv/test_csv_timestamp_tz.test index ca6bc3e6381a..36728603706c 100644 --- a/test/sql/copy/csv/test_csv_timestamp_tz.test +++ b/test/sql/copy/csv/test_csv_timestamp_tz.test @@ -5,13 +5,9 @@ statement ok pragma enable_verification -statement ok +statement error copy ( select '2021-05-25 04:55:03.382494 UTC'::timestamp as ts, '2021-05-25 04:55:03.382494 UTC'::timestamptz as tstz ) to '__TEST_DIR__/timestamps.csv' ( timestampformat '%A'); - - -query II -select * from read_csv_auto('__TEST_DIR__/timestamps.csv'); ---- -Tuesday Tuesday +No function matches the given name and argument types From fec25c383143c213acca7f7c56eb7b48045dd4de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Wed, 17 Apr 2024 12:39:58 +0200 Subject: [PATCH 145/201] good to go? 
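Among other cleanups, this patch splits duckdb::vector's bounds-checked erase_at into a checked wrapper around a new unchecked unsafe_erase_at, so call sites that have already validated the index (such as RemoveTransaction below) can skip the redundant check. A minimal sketch of the intended contract, assuming only duckdb::vector's public API:

    duckdb::vector<int> v;
    v.push_back(1);
    v.push_back(2);
    v.push_back(3);
    v.erase_at(1);        // checked: throws InternalException on an out-of-range
                          // index when memory safety is enabled
    v.unsafe_erase_at(0); // unchecked: the caller guarantees the index is valid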
--- src/common/serializer/buffered_file_writer.cpp | 4 ++-- src/include/duckdb/common/vector.hpp | 6 +++++- src/storage/buffer/buffer_pool.cpp | 2 -- src/transaction/duck_transaction_manager.cpp | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/common/serializer/buffered_file_writer.cpp b/src/common/serializer/buffered_file_writer.cpp index 62d237e63ea3..be4f51fc3c7a 100644 --- a/src/common/serializer/buffered_file_writer.cpp +++ b/src/common/serializer/buffered_file_writer.cpp @@ -37,8 +37,8 @@ void BufferedFileWriter::WriteData(const_data_ptr_t buffer, idx_t write_size) { Flush(); // Flush buffer before writing every things else } idx_t remaining_to_write = write_size - to_copy; - fs.Write(*handle, const_cast(buffer + to_copy), - UnsafeNumericCast(remaining_to_write)); // NOLINT: wrong API in Write + fs.Write(*handle, const_cast(buffer + to_copy), // NOLINT: wrong API in Write + UnsafeNumericCast(remaining_to_write)); total_written += remaining_to_write; } else { // first copy anything we can from the buffer diff --git a/src/include/duckdb/common/vector.hpp b/src/include/duckdb/common/vector.hpp index c767b76bab3b..676adac20788 100644 --- a/src/include/duckdb/common/vector.hpp +++ b/src/include/duckdb/common/vector.hpp @@ -101,11 +101,15 @@ class vector : public std::vector> { // NOL return get(original::size() - 1); } + void unsafe_erase_at(idx_t idx) { // NOLINT: not using camelcase on purpose here + original::erase(original::begin() + static_cast(idx)); + } + void erase_at(idx_t idx) { // NOLINT: not using camelcase on purpose here if (MemorySafety::ENABLED && idx > original::size()) { throw InternalException("Can't remove offset %d from vector of size %d", idx, original::size()); } - original::erase(original::begin() + static_cast(idx)); + unsafe_erase_at(idx); } }; diff --git a/src/storage/buffer/buffer_pool.cpp b/src/storage/buffer/buffer_pool.cpp index a7741b268a4d..d94b87d20cfa 100644 --- a/src/storage/buffer/buffer_pool.cpp +++ b/src/storage/buffer/buffer_pool.cpp @@ -72,9 +72,7 @@ void BufferPool::UpdateUsedMemory(MemoryTag tag, int64_t size) { memory_usage_per_tag[uint8_t(tag)] -= UnsafeNumericCast(-size); } else { current_memory += UnsafeNumericCast(size); - ; memory_usage_per_tag[uint8_t(tag)] += UnsafeNumericCast(size); - ; } } diff --git a/src/transaction/duck_transaction_manager.cpp b/src/transaction/duck_transaction_manager.cpp index f53ab599670f..42e099cdf39e 100644 --- a/src/transaction/duck_transaction_manager.cpp +++ b/src/transaction/duck_transaction_manager.cpp @@ -277,7 +277,7 @@ void DuckTransactionManager::RemoveTransaction(DuckTransaction &transaction) noe } } // remove the transaction from the set of currently active transactions - active_transactions.erase_at(t_index); + active_transactions.unsafe_erase_at(t_index); // traverse the recently_committed transactions to see if we can remove any idx_t i = 0; for (; i < recently_committed_transactions.size(); i++) { From 835e114257d1098da19adec4e91f16dbb3daf93f Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Apr 2024 12:48:24 +0200 Subject: [PATCH 146/201] missing constant expression --- src/function/table/copy_csv.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp index dd9567a0807f..c581b6443f56 100644 --- a/src/function/table/copy_csv.cpp +++ b/src/function/table/copy_csv.cpp @@ -15,6 +15,7 @@ #include "duckdb/parser/expression/cast_expression.hpp" #include "duckdb/parser/expression/function_expression.hpp" #include 
"duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/execution/column_binding_resolver.hpp" #include "duckdb/planner/operator/logical_dummy_scan.hpp" #include From d6399e7792a2ca1a5534559d34a9abfcb230354c Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Apr 2024 13:03:24 +0200 Subject: [PATCH 147/201] always create a Copy of a SQLStatement when ALTERNATIVE_VERIFY is set --- src/main/client_context.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp index c3aa1917a66e..cd517fc91f78 100644 --- a/src/main/client_context.cpp +++ b/src/main/client_context.cpp @@ -762,6 +762,11 @@ void ClientContext::SetActiveResult(ClientContextLock &lock, BaseQueryResult &re unique_ptr ClientContext::PendingStatementOrPreparedStatementInternal( ClientContextLock &lock, const string &query, unique_ptr statement, shared_ptr &prepared, const PendingQueryParameters ¶meters) { +#ifdef DUCKDB_ALTERNATIVE_VERIFY + if (statement && statement->type != StatementType::LOGICAL_PLAN_STATEMENT) { + statement = statement->Copy(); + } +#endif // check if we are on AutoCommit. In this case we should start a transaction. if (statement && config.AnyVerification()) { // query verification is enabled From d8c083e21c1a1e7a13b5b9f32cac2dcbef053d05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Wed, 17 Apr 2024 14:45:53 +0200 Subject: [PATCH 148/201] void --- src/storage/compression/bitpacking.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/storage/compression/bitpacking.cpp b/src/storage/compression/bitpacking.cpp index d5d06dac8f45..cb10b9c92817 100644 --- a/src/storage/compression/bitpacking.cpp +++ b/src/storage/compression/bitpacking.cpp @@ -902,6 +902,7 @@ void BitpackingFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t r if (scan_state.current_group.mode == BitpackingMode::CONSTANT_DELTA) { T multiplier; auto cast = TryCast::Operation(scan_state.current_group_offset, multiplier); + (void)cast; D_ASSERT(cast); #ifdef DEBUG // overflow check From a8c752062096ff161f42fdae7cb6cbd2750d9a5b Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Apr 2024 15:26:01 +0200 Subject: [PATCH 149/201] dont use a ColumnBindingResolver, create BoundReferenceExpressions directly --- src/execution/column_binding_resolver.cpp | 4 ---- src/function/table/copy_csv.cpp | 17 +++-------------- .../execution/column_binding_resolver.hpp | 2 -- 3 files changed, 3 insertions(+), 20 deletions(-) diff --git a/src/execution/column_binding_resolver.cpp b/src/execution/column_binding_resolver.cpp index ed3b2c8b2f95..568381503d78 100644 --- a/src/execution/column_binding_resolver.cpp +++ b/src/execution/column_binding_resolver.cpp @@ -15,10 +15,6 @@ namespace duckdb { ColumnBindingResolver::ColumnBindingResolver(bool verify_only) : verify_only(verify_only) { } -void ColumnBindingResolver::SetBindings(vector &&bindings_p) { - bindings = std::move(bindings_p); -} - void ColumnBindingResolver::VisitOperator(LogicalOperator &op) { switch (op.type) { case LogicalOperatorType::LOGICAL_ASOF_JOIN: diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp index c581b6443f56..f8916f62d71b 100644 --- a/src/function/table/copy_csv.cpp +++ b/src/function/table/copy_csv.cpp @@ -111,7 +111,6 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d // Create a binder auto binder = Binder::CreateBinder(context); - // Create a Binding, used by the ExpressionBinder to 
turn our columns into BoundReferenceExpressions auto &bind_context = binder->bind_context; auto table_index = binder->GenerateTableIndex(); bind_context.AddGenericBinding(table_index, "copy_csv", names, sql_types); @@ -126,7 +125,7 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d if (has_dateformat && type.id() == LogicalTypeId::DATE) { // strftime(, 'format') vector> children; - children.push_back(make_uniq(name)); + children.push_back(make_uniq(make_uniq(name, type, i))); // TODO: set from user-provided format children.push_back(make_uniq(formats[LogicalTypeId::DATE])); auto func = make_uniq_base("strftime", std::move(children)); @@ -134,14 +133,14 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d } else if (has_timestampformat && is_timestamp) { // strftime(, 'format') vector> children; - children.push_back(make_uniq(name)); + children.push_back(make_uniq(make_uniq(name, type, i))); // TODO: set from user-provided format children.push_back(make_uniq(formats[LogicalTypeId::TIMESTAMP])); auto func = make_uniq_base("strftime", std::move(children)); unbound_expressions.push_back(std::move(func)); } else { // CAST AS VARCHAR - auto column = make_uniq(name); + auto column = make_uniq(make_uniq(name, type, i)); auto expr = make_uniq_base(LogicalType::VARCHAR, std::move(column)); unbound_expressions.push_back(std::move(expr)); } @@ -155,16 +154,6 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d expressions.push_back(expression_binder.Bind(expr)); } - ColumnBindingResolver resolver; - vector bindings; - for (idx_t i = 0; i < sql_types.size(); i++) { - bindings.push_back(ColumnBinding(table_index, i)); - } - resolver.SetBindings(std::move(bindings)); - - for (auto &expr : expressions) { - resolver.VisitExpression(&expr); - } return expressions; } diff --git a/src/include/duckdb/execution/column_binding_resolver.hpp b/src/include/duckdb/execution/column_binding_resolver.hpp index de98ea2166a4..f98aeb2cfef4 100644 --- a/src/include/duckdb/execution/column_binding_resolver.hpp +++ b/src/include/duckdb/execution/column_binding_resolver.hpp @@ -23,8 +23,6 @@ class ColumnBindingResolver : public LogicalOperatorVisitor { void VisitOperator(LogicalOperator &op) override; static void Verify(LogicalOperator &op); - //! 
Manually set bindings - void SetBindings(vector &&bindings); protected: vector bindings; From 23c3553c8103bf61da63aa891959dd11a9caa0ff Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Apr 2024 15:26:56 +0200 Subject: [PATCH 150/201] remove resolved TODOs --- src/function/table/copy_csv.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp index f8916f62d71b..8f6999dd9262 100644 --- a/src/function/table/copy_csv.cpp +++ b/src/function/table/copy_csv.cpp @@ -126,7 +126,6 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d // strftime(, 'format') vector> children; children.push_back(make_uniq(make_uniq(name, type, i))); - // TODO: set from user-provided format children.push_back(make_uniq(formats[LogicalTypeId::DATE])); auto func = make_uniq_base("strftime", std::move(children)); unbound_expressions.push_back(std::move(func)); @@ -134,7 +133,6 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d // strftime(, 'format') vector> children; children.push_back(make_uniq(make_uniq(name, type, i))); - // TODO: set from user-provided format children.push_back(make_uniq(formats[LogicalTypeId::TIMESTAMP])); auto func = make_uniq_base("strftime", std::move(children)); unbound_expressions.push_back(std::move(func)); From 651513da98a7523f3d7cd772b4c5a1ff5019d8f1 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Apr 2024 20:23:31 +0200 Subject: [PATCH 151/201] missing headers --- src/function/table/copy_csv.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/function/table/copy_csv.cpp b/src/function/table/copy_csv.cpp index 8f6999dd9262..eb5205b113e9 100644 --- a/src/function/table/copy_csv.cpp +++ b/src/function/table/copy_csv.cpp @@ -16,6 +16,8 @@ #include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/expression/columnref_expression.hpp" #include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/execution/column_binding_resolver.hpp" #include "duckdb/planner/operator/logical_dummy_scan.hpp" #include From 30afb3b8dfbed512395cefd11ac0871a4d48a5a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Wed, 17 Apr 2024 21:34:59 +0200 Subject: [PATCH 152/201] first part wconversion --- src/CMakeLists.txt | 2 +- src/common/operator/cast_operators.cpp | 4 ++-- src/common/progress_bar/progress_bar.cpp | 2 +- src/common/types.cpp | 2 +- src/core_functions/aggregate/holistic/quantile.cpp | 7 ++++--- src/core_functions/scalar/math/numeric.cpp | 4 ++-- src/core_functions/scalar/random/setseed.cpp | 2 +- src/execution/aggregate_hashtable.cpp | 4 ++-- src/execution/join_hashtable.cpp | 3 ++- src/execution/operator/join/physical_hash_join.cpp | 2 +- src/execution/radix_partitioned_hashtable.cpp | 8 +++++--- src/include/duckdb/common/operator/numeric_cast.hpp | 4 ++-- src/main/config.cpp | 2 +- src/optimizer/join_order/estimated_properties.cpp | 4 ++-- src/optimizer/join_order/relation_manager.cpp | 6 ++++-- src/optimizer/join_order/relation_statistics_helper.cpp | 6 +++--- src/parallel/executor.cpp | 4 +++- src/storage/temporary_memory_manager.cpp | 6 +++--- 18 files changed, 40 insertions(+), 32 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4c69853abd3c..21f47240848a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,7 +24,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR 
"${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") set(EXIT_TIME_DESTRUCTORS_WARNING TRUE) set(CMAKE_CXX_FLAGS_DEBUG - "${CMAKE_CXX_FLAGS_DEBUG} -Wexit-time-destructors -Wimplicit-int-conversion -Wshorten-64-to-32 -Wnarrowing -Wsign-conversion -Wsign-compare" + "${CMAKE_CXX_FLAGS_DEBUG} -Wexit-time-destructors -Wimplicit-int-conversion -Wshorten-64-to-32 -Wnarrowing -Wsign-conversion -Wsign-compare -Wconversion" ) endif() diff --git a/src/common/operator/cast_operators.cpp b/src/common/operator/cast_operators.cpp index 769ff78cdaf7..843acc70527a 100644 --- a/src/common/operator/cast_operators.cpp +++ b/src/common/operator/cast_operators.cpp @@ -1957,7 +1957,7 @@ struct DecimalCastOperation { for (idx_t i = 0; i < state.excessive_decimals; i++) { auto mod = state.result % 10; round_up = NEGATIVE ? mod <= -5 : mod >= 5; - state.result /= 10.0; + state.result /= static_cast(10.0); } //! Only round up when exponents are involved if (state.exponent_type == T::ExponentType::POSITIVE && round_up) { @@ -2486,7 +2486,7 @@ bool DoubleToDecimalCast(SRC input, DST &result, CastParameters ¶meters, uin HandleCastError::AssignError(error, parameters); return false; } - result = Cast::Operation(value); + result = Cast::Operation(UnsafeNumericCast(value)); return true; } diff --git a/src/common/progress_bar/progress_bar.cpp b/src/common/progress_bar/progress_bar.cpp index 9c6a75fb6f86..720f5499fe89 100644 --- a/src/common/progress_bar/progress_bar.cpp +++ b/src/common/progress_bar/progress_bar.cpp @@ -121,7 +121,7 @@ void ProgressBar::Update(bool final) { if (final) { FinishProgressBarPrint(); } else { - PrintProgress(query_progress.percentage); + PrintProgress(NumericCast(query_progress.percentage.load())); } #endif } diff --git a/src/common/types.cpp b/src/common/types.cpp index 862f86e5f900..b54c47e78b0d 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -1110,7 +1110,7 @@ bool ApproxEqual(float ldecimal, float rdecimal) { if (!Value::FloatIsFinite(ldecimal) || !Value::FloatIsFinite(rdecimal)) { return ldecimal == rdecimal; } - float epsilon = std::fabs(rdecimal) * 0.01 + 0.00000001; + auto epsilon = UnsafeNumericCast(std::fabs(rdecimal) * 0.01 + 0.00000001); return std::fabs(ldecimal - rdecimal) <= epsilon; } diff --git a/src/core_functions/aggregate/holistic/quantile.cpp b/src/core_functions/aggregate/holistic/quantile.cpp index 0b3c67fe01a1..84446a6cee44 100644 --- a/src/core_functions/aggregate/holistic/quantile.cpp +++ b/src/core_functions/aggregate/holistic/quantile.cpp @@ -157,7 +157,7 @@ struct CastInterpolation { template static inline TARGET_TYPE Interpolate(const TARGET_TYPE &lo, const double d, const TARGET_TYPE &hi) { const auto delta = hi - lo; - return lo + delta * d; + return UnsafeNumericCast(lo + delta * d); } }; @@ -295,7 +295,8 @@ bool operator==(const QuantileValue &x, const QuantileValue &y) { template struct Interpolator { Interpolator(const QuantileValue &q, const idx_t n_p, const bool desc_p) - : desc(desc_p), RN((double)(n_p - 1) * q.dbl), FRN(floor(RN)), CRN(ceil(RN)), begin(0), end(n_p) { + : desc(desc_p), RN((double)(n_p - 1) * q.dbl), FRN(UnsafeNumericCast(floor(RN))), + CRN(UnsafeNumericCast(ceil(RN))), begin(0), end(n_p) { } template > @@ -365,7 +366,7 @@ struct Interpolator { } default: const auto scaled_q = (double)(n * q.dbl); - floored = floor(n - scaled_q); + floored = UnsafeNumericCast(floor(n - scaled_q)); break; } diff --git a/src/core_functions/scalar/math/numeric.cpp b/src/core_functions/scalar/math/numeric.cpp index 711f92608aec..4a6055a91413 100644 
--- a/src/core_functions/scalar/math/numeric.cpp +++ b/src/core_functions/scalar/math/numeric.cpp @@ -516,7 +516,7 @@ struct RoundOperatorPrecision { return input; } } - return rounded_value; + return UnsafeNumericCast(rounded_value); } }; @@ -527,7 +527,7 @@ struct RoundOperator { if (std::isinf(rounded_value) || std::isnan(rounded_value)) { return input; } - return rounded_value; + return UnsafeNumericCast(rounded_value); } }; diff --git a/src/core_functions/scalar/random/setseed.cpp b/src/core_functions/scalar/random/setseed.cpp index f2db16e6c5a6..32965cf18ded 100644 --- a/src/core_functions/scalar/random/setseed.cpp +++ b/src/core_functions/scalar/random/setseed.cpp @@ -39,7 +39,7 @@ static void SetSeedFunction(DataChunk &args, ExpressionState &state, Vector &res if (input_seeds[i] < -1.0 || input_seeds[i] > 1.0 || Value::IsNan(input_seeds[i])) { throw InvalidInputException("SETSEED accepts seed values between -1.0 and 1.0, inclusive"); } - uint32_t norm_seed = (input_seeds[i] + 1.0) * half_max; + auto norm_seed = NumericCast((input_seeds[i] + 1.0) * half_max); random_engine.SetSeed(norm_seed); } diff --git a/src/execution/aggregate_hashtable.cpp b/src/execution/aggregate_hashtable.cpp index c8ab1ce6b965..6027e8205286 100644 --- a/src/execution/aggregate_hashtable.cpp +++ b/src/execution/aggregate_hashtable.cpp @@ -122,7 +122,7 @@ idx_t GroupedAggregateHashTable::InitialCapacity() { idx_t GroupedAggregateHashTable::GetCapacityForCount(idx_t count) { count = MaxValue(InitialCapacity(), count); - return NextPowerOfTwo(count * LOAD_FACTOR); + return NextPowerOfTwo(NumericCast(static_cast(count) * LOAD_FACTOR)); } idx_t GroupedAggregateHashTable::Capacity() const { @@ -130,7 +130,7 @@ idx_t GroupedAggregateHashTable::Capacity() const { } idx_t GroupedAggregateHashTable::ResizeThreshold() const { - return Capacity() / LOAD_FACTOR; + return NumericCast(static_cast(Capacity()) / LOAD_FACTOR); } idx_t GroupedAggregateHashTable::ApplyBitMask(hash_t hash) const { diff --git a/src/execution/join_hashtable.cpp b/src/execution/join_hashtable.cpp index 53adf3a85ef8..c1f273915e6c 100644 --- a/src/execution/join_hashtable.cpp +++ b/src/execution/join_hashtable.cpp @@ -940,7 +940,8 @@ void JoinHashTable::SetRepartitionRadixBits(vector> &l auto new_estimated_size = double(max_partition_size) / partition_multiplier; auto new_estimated_count = double(max_partition_count) / partition_multiplier; - auto new_estimated_ht_size = new_estimated_size + PointerTableSize(new_estimated_count); + auto new_estimated_ht_size = + new_estimated_size + static_cast(PointerTableSize(NumericCast(new_estimated_count))); if (new_estimated_ht_size <= double(max_ht_size) / 4) { // Aim for an estimated partition size of max_ht_size / 4 diff --git a/src/execution/operator/join/physical_hash_join.cpp b/src/execution/operator/join/physical_hash_join.cpp index 8632c44bda13..9cd7e34a07f7 100644 --- a/src/execution/operator/join/physical_hash_join.cpp +++ b/src/execution/operator/join/physical_hash_join.cpp @@ -409,7 +409,7 @@ class HashJoinRepartitionEvent : public BasePipelineEvent { total_size += sink_collection.SizeInBytes(); total_count += sink_collection.Count(); } - auto total_blocks = (double(total_size) + Storage::BLOCK_SIZE - 1) / Storage::BLOCK_SIZE; + auto total_blocks = NumericCast((double(total_size) + Storage::BLOCK_SIZE - 1) / Storage::BLOCK_SIZE); auto count_per_block = total_count / total_blocks; auto blocks_per_vector = MaxValue(STANDARD_VECTOR_SIZE / count_per_block, 2); diff --git 
a/src/execution/radix_partitioned_hashtable.cpp b/src/execution/radix_partitioned_hashtable.cpp index d2a174dc546e..595c1c4c2c83 100644 --- a/src/execution/radix_partitioned_hashtable.cpp +++ b/src/execution/radix_partitioned_hashtable.cpp @@ -197,7 +197,8 @@ RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context_p, const R count_before_combining(0), max_partition_size(0) { auto tuples_per_block = Storage::BLOCK_ALLOC_SIZE / radix_ht.GetLayout().GetRowWidth(); - idx_t ht_count = config.sink_capacity / GroupedAggregateHashTable::LOAD_FACTOR; + idx_t ht_count = + NumericCast(static_cast(config.sink_capacity) / GroupedAggregateHashTable::LOAD_FACTOR); auto num_partitions = RadixPartitioning::NumberOfPartitions(config.GetRadixBits()); auto count_per_partition = ht_count / num_partitions; auto blocks_per_partition = (count_per_partition + tuples_per_block) / tuples_per_block + 1; @@ -305,7 +306,8 @@ idx_t RadixHTConfig::SinkCapacity(ClientContext &context) { // Divide cache per active thread by entry size, round up to next power of two, to get capacity const auto size_per_entry = sizeof(aggr_ht_entry_t) * GroupedAggregateHashTable::LOAD_FACTOR; - const auto capacity = NextPowerOfTwo(cache_per_active_thread / size_per_entry); + const auto capacity = + NextPowerOfTwo(NumericCast(static_cast(cache_per_active_thread) / size_per_entry)); // Capacity must be at least the minimum capacity return MaxValue(capacity, GroupedAggregateHashTable::InitialCapacity()); @@ -718,7 +720,7 @@ void RadixHTLocalSourceState::Finalize(RadixHTGlobalSinkState &sink, RadixHTGlob // However, we will limit the initial capacity so we don't do a huge over-allocation const auto n_threads = NumericCast(TaskScheduler::GetScheduler(gstate.context).NumberOfThreads()); const auto memory_limit = BufferManager::GetBufferManager(gstate.context).GetMaxMemory(); - const idx_t thread_limit = 0.6 * memory_limit / n_threads; + const idx_t thread_limit = NumericCast(0.6 * memory_limit / n_threads); const idx_t size_per_entry = partition.data->SizeInBytes() / MaxValue(partition.data->Count(), 1) + idx_t(GroupedAggregateHashTable::LOAD_FACTOR * sizeof(aggr_ht_entry_t)); diff --git a/src/include/duckdb/common/operator/numeric_cast.hpp b/src/include/duckdb/common/operator/numeric_cast.hpp index 26603a987ce2..b6d3b6742f80 100644 --- a/src/include/duckdb/common/operator/numeric_cast.hpp +++ b/src/include/duckdb/common/operator/numeric_cast.hpp @@ -75,7 +75,7 @@ bool TryCastWithOverflowCheckFloat(SRC value, T &result, SRC min, SRC max) { return false; } // PG FLOAT => INT casts use statistical rounding. 
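// (A worked example of that rounding, assuming the default FE_TONEAREST mode:
// std::nearbyint rounds ties to the nearest even integer, so
// std::nearbyint(2.5) == 2.0 while std::nearbyint(3.5) == 4.0, matching
// PostgreSQL's behaviour for FLOAT => INT casts.)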
- result = std::nearbyint(value); + result = UnsafeNumericCast(std::nearbyint(value)); return true; } @@ -182,7 +182,7 @@ bool TryCastWithOverflowCheck(double input, float &result) { return true; } auto res = float(input); - if (!Value::FloatIsFinite(input)) { + if (!Value::DoubleIsFinite(input)) { return false; } result = res; diff --git a/src/main/config.cpp b/src/main/config.cpp index b8ad98651f11..56df902575a6 100644 --- a/src/main/config.cpp +++ b/src/main/config.cpp @@ -423,7 +423,7 @@ idx_t DBConfig::ParseMemoryLimit(const string &arg) { throw ParserException("Unknown unit for memory_limit: %s (expected: KB, MB, GB, TB for 1000^i units or KiB, " "MiB, GiB, TiB for 1024^i unites)"); } - return (idx_t)multiplier * limit; + return NumericCast(multiplier * limit); } // Right now we only really care about access mode when comparing DBConfigs diff --git a/src/optimizer/join_order/estimated_properties.cpp b/src/optimizer/join_order/estimated_properties.cpp index d3841a1bb3fb..9e907331abc4 100644 --- a/src/optimizer/join_order/estimated_properties.cpp +++ b/src/optimizer/join_order/estimated_properties.cpp @@ -11,7 +11,7 @@ double EstimatedProperties::GetCardinality() const { template <> idx_t EstimatedProperties::GetCardinality() const { auto max_idx_t = NumericLimits::Maximum() - 10000; - return MinValue(cardinality, max_idx_t); + return MinValue(NumericCast(cardinality), max_idx_t); } template <> @@ -22,7 +22,7 @@ double EstimatedProperties::GetCost() const { template <> idx_t EstimatedProperties::GetCost() const { auto max_idx_t = NumericLimits::Maximum() - 10000; - return MinValue(cost, max_idx_t); + return MinValue(NumericCast(cost), max_idx_t); } void EstimatedProperties::SetCardinality(double new_card) { diff --git a/src/optimizer/join_order/relation_manager.cpp b/src/optimizer/join_order/relation_manager.cpp index 3a0fbf0e920b..3df21b77e7fe 100644 --- a/src/optimizer/join_order/relation_manager.cpp +++ b/src/optimizer/join_order/relation_manager.cpp @@ -183,7 +183,8 @@ bool RelationManager::ExtractJoinRelations(LogicalOperator &input_op, auto &aggr = op->Cast(); auto operator_stats = RelationStatisticsHelper::ExtractAggregationStats(aggr, child_stats); if (!datasource_filters.empty()) { - operator_stats.cardinality *= RelationStatisticsHelper::DEFAULT_SELECTIVITY; + operator_stats.cardinality = NumericCast(static_cast(operator_stats.cardinality) * + RelationStatisticsHelper::DEFAULT_SELECTIVITY); } AddAggregateOrWindowRelation(input_op, parent, operator_stats, op->type); return true; @@ -196,7 +197,8 @@ bool RelationManager::ExtractJoinRelations(LogicalOperator &input_op, auto &window = op->Cast(); auto operator_stats = RelationStatisticsHelper::ExtractWindowStats(window, child_stats); if (!datasource_filters.empty()) { - operator_stats.cardinality *= RelationStatisticsHelper::DEFAULT_SELECTIVITY; + operator_stats.cardinality = NumericCast(static_cast(operator_stats.cardinality) * + RelationStatisticsHelper::DEFAULT_SELECTIVITY); } AddAggregateOrWindowRelation(input_op, parent, operator_stats, op->type); return true; diff --git a/src/optimizer/join_order/relation_statistics_helper.cpp b/src/optimizer/join_order/relation_statistics_helper.cpp index 94f5ddeb8b47..79af3bd8ca33 100644 --- a/src/optimizer/join_order/relation_statistics_helper.cpp +++ b/src/optimizer/join_order/relation_statistics_helper.cpp @@ -121,8 +121,8 @@ RelationStats RelationStatisticsHelper::ExtractGetStats(LogicalGet &get, ClientC // and there are other table filters (i.e cost > 50), use default selectivity. 
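// (Worked example, assuming DEFAULT_SELECTIVITY is 0.2: a base table of 1000
// rows carrying only non-equality filters is estimated below at
// MaxValue(1000 * 0.2, 1) = 200 rows after filtering.)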
bool has_equality_filter = (cardinality_after_filters != base_table_cardinality); if (!has_equality_filter && !get.table_filters.filters.empty()) { - cardinality_after_filters = - MaxValue(base_table_cardinality * RelationStatisticsHelper::DEFAULT_SELECTIVITY, 1); + cardinality_after_filters = MaxValue( + NumericCast(base_table_cardinality * RelationStatisticsHelper::DEFAULT_SELECTIVITY), 1U); } if (base_table_cardinality == 0) { cardinality_after_filters = 0; @@ -345,7 +345,7 @@ RelationStats RelationStatisticsHelper::ExtractAggregationStats(LogicalAggregate // most likely we are running on parquet files. Therefore we divide by 2. new_card = (double)child_stats.cardinality / 2; } - stats.cardinality = new_card; + stats.cardinality = NumericCast(new_card); stats.column_names = child_stats.column_names; stats.stats_initialized = true; auto num_child_columns = aggr.GetColumnBindings().size(); diff --git a/src/parallel/executor.cpp b/src/parallel/executor.cpp index 41e710284c0e..585d093eca37 100644 --- a/src/parallel/executor.cpp +++ b/src/parallel/executor.cpp @@ -639,7 +639,9 @@ bool Executor::GetPipelinesProgress(double ¤t_progress, uint64_t ¤t_ for (size_t i = 0; i < progress.size(); i++) { progress[i] = MaxValue(0.0, MinValue(100.0, progress[i])); - current_cardinality += double(progress[i]) * double(cardinality[i]) / double(100); + current_cardinality = NumericCast(static_cast( + current_cardinality + + static_cast(progress[i]) * static_cast(cardinality[i]) / static_cast(100))); current_progress += progress[i] * double(cardinality[i]) / double(total_cardinality); D_ASSERT(current_cardinality <= total_cardinality); } diff --git a/src/storage/temporary_memory_manager.cpp b/src/storage/temporary_memory_manager.cpp index ba046d30fb17..2564e11cfb0e 100644 --- a/src/storage/temporary_memory_manager.cpp +++ b/src/storage/temporary_memory_manager.cpp @@ -45,7 +45,7 @@ void TemporaryMemoryManager::UpdateConfiguration(ClientContext &context) { auto &buffer_manager = BufferManager::GetBufferManager(context); auto &task_scheduler = TaskScheduler::GetScheduler(context); - memory_limit = MAXIMUM_MEMORY_LIMIT_RATIO * double(buffer_manager.GetMaxMemory()); + memory_limit = NumericCast(MAXIMUM_MEMORY_LIMIT_RATIO * static_cast(buffer_manager.GetMaxMemory())); has_temporary_directory = buffer_manager.HasTemporaryDirectory(); num_threads = NumericCast(task_scheduler.NumberOfThreads()); query_max_memory = buffer_manager.GetQueryMaxMemory(); @@ -92,14 +92,14 @@ void TemporaryMemoryManager::UpdateState(ClientContext &context, TemporaryMemory // 3. MAXIMUM_FREE_MEMORY_RATIO * free memory auto upper_bound = MinValue(temporary_memory_state.remaining_size, query_max_memory); auto free_memory = memory_limit - (reservation - temporary_memory_state.reservation); - upper_bound = MinValue(upper_bound, MAXIMUM_FREE_MEMORY_RATIO * free_memory); + upper_bound = MinValue(upper_bound, NumericCast(MAXIMUM_FREE_MEMORY_RATIO * free_memory)); if (remaining_size > memory_limit) { // We're processing more data than fits in memory, so we must further limit memory usage. // The upper bound for the reservation of this state is now also the minimum of: // 3. 
The ratio of the remaining size of this state and the total remaining size * memory limit auto ratio_of_remaining = double(temporary_memory_state.remaining_size) / double(remaining_size); - upper_bound = MinValue(upper_bound, ratio_of_remaining * memory_limit); + upper_bound = MinValue(upper_bound, NumericCast(ratio_of_remaining * memory_limit)); } SetReservation(temporary_memory_state, MaxValue(lower_bound, upper_bound)); From fbeeb5ea4970ce36e9e55ea65b33f3dd5f82d1e5 Mon Sep 17 00:00:00 2001 From: Mark Raasveldt Date: Wed, 17 Apr 2024 21:49:40 +0200 Subject: [PATCH 153/201] Minor improvements to statement reduction --- extension/sqlsmith/statement_simplifier.cpp | 2 ++ scripts/reduce_sql.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/extension/sqlsmith/statement_simplifier.cpp b/extension/sqlsmith/statement_simplifier.cpp index ff31a6fc1255..66662f4cc024 100644 --- a/extension/sqlsmith/statement_simplifier.cpp +++ b/extension/sqlsmith/statement_simplifier.cpp @@ -132,6 +132,8 @@ void StatementSimplifier::Simplify(unique_ptr &ref) { SimplifyOptionalExpression(cp.condition); SimplifyReplace(ref, cp.left); SimplifyReplace(ref, cp.right); + SimplifyEnum(cp.type, JoinType::INNER); + SimplifyEnum(cp.ref_type, JoinRefType::REGULAR); break; } case TableReferenceType::EXPRESSION_LIST: { diff --git a/scripts/reduce_sql.py b/scripts/reduce_sql.py index bb5c62460686..ab2cd70e96c4 100644 --- a/scripts/reduce_sql.py +++ b/scripts/reduce_sql.py @@ -258,7 +258,7 @@ def reduce_query_log(queries, shell, max_time_seconds=300): print(expected_error) print("===================================================") - final_query = reduce(sql_query, data_load, shell, expected_error, args.max_time) + final_query = reduce(sql_query, data_load, shell, expected_error, int(args.max_time)) print("Found final reduced query") print("===================================================") print(final_query) From c6a0b1a787ae100471e529c826efe7660401b6de Mon Sep 17 00:00:00 2001 From: Mark Raasveldt Date: Wed, 17 Apr 2024 21:50:17 +0200 Subject: [PATCH 154/201] Bump Julia to v0.10.2 --- tools/juliapkg/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/juliapkg/Project.toml b/tools/juliapkg/Project.toml index da97b7535576..6914a8f38ee5 100644 --- a/tools/juliapkg/Project.toml +++ b/tools/juliapkg/Project.toml @@ -1,7 +1,7 @@ name = "DuckDB" uuid = "d2f5444f-75bc-4fdf-ac35-56f514c445e1" authors = ["Mark Raasveldt "] -version = "0.10.1" +version = "0.10.2" [deps] DBInterface = "a10d1c49-ce27-4219-8d33-6db1a4562965" @@ -14,7 +14,7 @@ WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [compat] DBInterface = "2.5" -DuckDB_jll = "0.10.1" +DuckDB_jll = "0.10.2" FixedPointDecimals = "0.4, 0.5" Tables = "1.7" WeakRefStrings = "1.4" From ce6d2cdc04cdbb1bcb2da8eb5ad31d1d97653ac4 Mon Sep 17 00:00:00 2001 From: Richard Wesley <13156216+hawkfish@users.noreply.github.com> Date: Wed, 17 Apr 2024 14:31:50 -0700 Subject: [PATCH 155/201] Internal #1848: Window Progress PhysicalWindow is a source but didn't implement GetProgress fixes: duckdblabs/duckdb-internal#1848 --- .../operator/aggregate/physical_window.cpp | 15 ++++++++++++++- .../operator/aggregate/physical_window.hpp | 2 ++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/execution/operator/aggregate/physical_window.cpp b/src/execution/operator/aggregate/physical_window.cpp index bcfe0a56bd3b..5cb0045d9e60 100644 --- a/src/execution/operator/aggregate/physical_window.cpp +++ 
b/src/execution/operator/aggregate/physical_window.cpp @@ -205,6 +205,8 @@ class WindowGlobalSourceState : public GlobalSourceState { mutable mutex built_lock; //! The number of unfinished tasks atomic tasks_remaining; + //! The number of rows returned + atomic returned; public: idx_t MaxThreads() override { @@ -217,7 +219,7 @@ class WindowGlobalSourceState : public GlobalSourceState { }; WindowGlobalSourceState::WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p) - : context(context_p), gsink(gsink_p), next_build(0), tasks_remaining(0) { + : context(context_p), gsink(gsink_p), next_build(0), tasks_remaining(0), returned(0) { auto &hash_groups = gsink.global_partition->hash_groups; auto &gpart = gsink.global_partition; @@ -681,6 +683,15 @@ OrderPreservationType PhysicalWindow::SourceOrder() const { return SupportsBatchIndex() ? OrderPreservationType::FIXED_ORDER : OrderPreservationType::NO_ORDER; } +double PhysicalWindow::GetProgress(ClientContext &context, GlobalSourceState &gsource_p) const { + auto &gsource = gsource_p.Cast(); + const auto returned = gsource.returned.load(); + + auto &gsink = gsource.gsink; + const auto count = gsink.global_partition->count.load(); + return count ? (returned / double(count)) : -1; +} + idx_t PhysicalWindow::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate_p, LocalSourceState &lstate_p) const { auto &lstate = lstate_p.Cast(); @@ -689,6 +700,7 @@ idx_t PhysicalWindow::GetBatchIndex(ExecutionContext &context, DataChunk &chunk, SourceResultType PhysicalWindow::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { + auto &gsource = input.global_state.Cast(); auto &lsource = input.local_state.Cast(); while (chunk.size() == 0) { // Move to the next bin if we are done. @@ -699,6 +711,7 @@ SourceResultType PhysicalWindow::GetData(ExecutionContext &context, DataChunk &c } lsource.Scan(chunk); + gsource.returned += chunk.size(); } return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; diff --git a/src/include/duckdb/execution/operator/aggregate/physical_window.hpp b/src/include/duckdb/execution/operator/aggregate/physical_window.hpp index aad04b562776..a554a46bc741 100644 --- a/src/include/duckdb/execution/operator/aggregate/physical_window.hpp +++ b/src/include/duckdb/execution/operator/aggregate/physical_window.hpp @@ -50,6 +50,8 @@ class PhysicalWindow : public PhysicalOperator { bool SupportsBatchIndex() const override; OrderPreservationType SourceOrder() const override; + double GetProgress(ClientContext &context, GlobalSourceState &gstate_p) const override; + public: // Sink interface SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; From c5249bbe91df35b0189b7d8d24af1907d610d2f2 Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 10:55:05 +0200 Subject: [PATCH 156/201] Add CI check on capability to build duckdb in docker --- .github/workflows/DockerTests.yml | 59 +++++++++++++++++++++++++++++++ scripts/test_docker_images.sh | 4 +-- 2 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/DockerTests.yml diff --git a/.github/workflows/DockerTests.yml b/.github/workflows/DockerTests.yml new file mode 100644 index 000000000000..be2e8579f43b --- /dev/null +++ b/.github/workflows/DockerTests.yml @@ -0,0 +1,59 @@ +name: Docker tests +on: + workflow_call: + inputs: + override_git_describe: + type: string + git_ref: + type: string + skip_tests: + type: string + workflow_dispatch: + inputs: + override_git_describe: + type: string + git_ref: + type: string + skip_tests: + type: string + repository_dispatch: + push: + branches: + - '**' + - '!main' + - '!feature' + paths-ignore: + - '**' + - '!.github/workflows/DockerTests.yml' + - '!scripts/test_docker_images.sh' + pull_request: + types: [opened, reopened, ready_for_review] + paths-ignore: + - '**' + - '!.github/workflows/DockerTests.yml' + - '!scripts/test_docker_images.sh' + +concurrency: + group: docker-${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/main' || github.sha }}-${{ inputs.override_git_describe }} + cancel-in-progress: true + +env: + GH_TOKEN: ${{ secrets.GH_TOKEN }} + OVERRIDE_GIT_DESCRIBE: ${{ inputs.override_git_describe }} + +jobs: + linux-x64-docker: + # Builds binaries for linux_amd64_gcc4 + name: Docker tests on Linux (x64) + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + ref: ${{ inputs.git_ref }} + + - name: Build + shell: bash + run: | + ./scripts/test_docker_images.sh diff --git a/scripts/test_docker_images.sh b/scripts/test_docker_images.sh index 904d4f90b529..c2217a8b6488 100755 --- a/scripts/test_docker_images.sh +++ b/scripts/test_docker_images.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash make clean -docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja && GEN=ninja make" 2>&1 -echo "alpine:latest completed" +docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja python3 && GEN=ninja make && make clean" 2>&1 +docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb ubuntu:20.04 <<< "apt-get update && export DEBIAN_FRONTEND=noninteractive && apt-get install g++ git make cmake ninja-build python3 -y && GEN=ninja make && make clean" 2>&1 From ee45dcc7ba69490f49d7c902dfc54a6b554fe6a0 Mon Sep 17 00:00:00 2001 From: 
Carlo Piovesan Date: Thu, 18 Apr 2024 11:13:58 +0200 Subject: [PATCH 157/201] Move from linux/falloc.h to fcntl.h + _GNU_SOURCE --- src/common/local_file_system.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/common/local_file_system.cpp b/src/common/local_file_system.cpp index 50b39a32d63a..744bce887cb1 100644 --- a/src/common/local_file_system.cpp +++ b/src/common/local_file_system.cpp @@ -40,7 +40,11 @@ extern "C" WINBASEAPI BOOL WINAPI GetPhysicallyInstalledSystemMemory(PULONGLONG) #endif #if defined(__linux__) -#include +// See https://man7.org/linux/man-pages/man2/fallocate.2.html +#ifndef _GNU_SOURCE +#define _GNU_SOURCE /* See feature_test_macros(7) */ +#endif +#include #include // See e.g.: // https://opensource.apple.com/source/CarbonHeaders/CarbonHeaders-18.1/TargetConditionals.h.auto.html From fbec0e894527a72f760ebce9c60dabc18b2f7490 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20M=C3=BChleisen?= Date: Thu, 18 Apr 2024 11:44:13 +0200 Subject: [PATCH 158/201] first pass done --- .../duckdb/storage/compression/alp/algorithm/alp.hpp | 6 +++--- .../duckdb/storage/compression/alp/alp_constants.hpp | 9 +++++---- src/include/duckdb/storage/compression/alp/alp_utils.hpp | 2 +- .../duckdb/storage/compression/alprd/algorithm/alprd.hpp | 3 ++- .../duckdb/storage/compression/alprd/alprd_analyze.hpp | 8 +++++--- src/storage/compression/dictionary_compression.cpp | 4 ++-- src/storage/compression/fsst.cpp | 2 +- src/storage/statistics/distinct_statistics.cpp | 2 +- 8 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp b/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp index a8c7e91f3e66..d0bf989755b2 100644 --- a/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp +++ b/src/include/duckdb/storage/compression/alp/algorithm/alp.hpp @@ -107,10 +107,10 @@ struct AlpCompression { */ static int64_t NumberToInt64(T n) { if (IsImpossibleToEncode(n)) { - return AlpConstants::ENCODING_UPPER_LIMIT; + return NumericCast(AlpConstants::ENCODING_UPPER_LIMIT); } n = n + AlpTypedConstants::MAGIC_NUMBER - AlpTypedConstants::MAGIC_NUMBER; - return static_cast(n); + return NumericCast(n); } /* @@ -185,7 +185,7 @@ struct AlpCompression { // Evaluate factor/exponent compression size (we optimize for FOR) uint64_t delta = (static_cast(max_encoded_value) - static_cast(min_encoded_value)); - estimated_bits_per_value = std::ceil(std::log2(delta + 1)); + estimated_bits_per_value = NumericCast(std::ceil(std::log2(delta + 1))); estimated_compression_size += n_values * estimated_bits_per_value; estimated_compression_size += exceptions_count * (EXACT_TYPE_BITSIZE + (AlpConstants::EXCEPTION_POSITION_SIZE * 8)); diff --git a/src/include/duckdb/storage/compression/alp/alp_constants.hpp b/src/include/duckdb/storage/compression/alp/alp_constants.hpp index 55353dda1fb6..9a7a36f9136f 100644 --- a/src/include/duckdb/storage/compression/alp/alp_constants.hpp +++ b/src/include/duckdb/storage/compression/alp/alp_constants.hpp @@ -70,11 +70,12 @@ struct AlpTypedConstants { static constexpr float MAGIC_NUMBER = 12582912.0; //! 
2^22 + 2^23 static constexpr uint8_t MAX_EXPONENT = 10; - static constexpr const float EXP_ARR[] = {1.0, 10.0, 100.0, 1000.0, 10000.0, 100000.0, - 1000000.0, 10000000.0, 100000000.0, 1000000000.0, 10000000000.0}; + static constexpr const float EXP_ARR[] = {1.0F, 10.0F, 100.0F, 1000.0F, + 10000.0F, 100000.0F, 1000000.0F, 10000000.0F, + 100000000.0F, 1000000000.0F, 10000000000.0F}; - static constexpr float FRAC_ARR[] = {1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001, - 0.000001, 0.0000001, 0.00000001, 0.000000001, 0.0000000001}; + static constexpr float FRAC_ARR[] = {1.0F, 0.1F, 0.01F, 0.001F, 0.0001F, 0.00001F, + 0.000001F, 0.0000001F, 0.00000001F, 0.000000001F, 0.0000000001F}; }; template <> diff --git a/src/include/duckdb/storage/compression/alp/alp_utils.hpp b/src/include/duckdb/storage/compression/alp/alp_utils.hpp index b5e49a6f3027..75292b829d9b 100644 --- a/src/include/duckdb/storage/compression/alp/alp_utils.hpp +++ b/src/include/duckdb/storage/compression/alp/alp_utils.hpp @@ -42,7 +42,7 @@ class AlpUtils { //! We sample equidistant values within a vector; to do this we jump a fixed number of values uint32_t n_sampled_increments = MaxValue( 1, UnsafeNumericCast(std::ceil((double)n_lookup_values / AlpConstants::SAMPLES_PER_VECTOR))); - uint32_t n_sampled_values = std::ceil((double)n_lookup_values / n_sampled_increments); + uint32_t n_sampled_values = NumericCast(std::ceil((double)n_lookup_values / n_sampled_increments)); D_ASSERT(n_sampled_values < AlpConstants::ALP_VECTOR_SIZE); AlpSamplingParameters sampling_params = {n_lookup_values, n_sampled_increments, n_sampled_values}; diff --git a/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp b/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp index 66d8262aebc9..ce99e1a0148a 100644 --- a/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp +++ b/src/include/duckdb/storage/compression/alprd/algorithm/alprd.hpp @@ -105,7 +105,8 @@ struct AlpRDCompression { // The left parts bit width after compression is determined by how many elements are in the dictionary uint64_t actual_dictionary_size = MinValue(AlpRDConstants::MAX_DICTIONARY_SIZE, left_parts_sorted_repetitions.size()); - uint8_t left_bit_width = MaxValue(1, std::ceil(std::log2(actual_dictionary_size))); + uint8_t left_bit_width = + MaxValue(1, NumericCast(std::ceil(std::log2(actual_dictionary_size)))); if (PERSIST_DICT) { for (idx_t dict_idx = 0; dict_idx < actual_dictionary_size; dict_idx++) { diff --git a/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp b/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp index e88fdae61b34..e37d873ac525 100644 --- a/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp +++ b/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp @@ -126,13 +126,15 @@ idx_t AlpRDFinalAnalyze(AnalyzeState &state) { //! 
diff --git a/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp b/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp
index e88fdae61b34..e37d873ac525 100644
--- a/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp
+++ b/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp
@@ -126,13 +126,15 @@ idx_t AlpRDFinalAnalyze(AnalyzeState &state) {
 	//! Overhead per vector: Pointer to data + Exceptions count
 	double per_vector_overhead = AlpRDConstants::METADATA_POINTER_SIZE + AlpRDConstants::EXCEPTIONS_COUNT_SIZE;
 
-	uint32_t n_vectors = std::ceil((double)analyze_state.total_values_count / AlpRDConstants::ALP_VECTOR_SIZE);
+	uint32_t n_vectors =
+	    NumericCast<uint32_t>(std::ceil((double)analyze_state.total_values_count / AlpRDConstants::ALP_VECTOR_SIZE));
 
 	auto estimated_size = (estimed_compressed_bytes * factor_of_sampling) + (n_vectors * per_vector_overhead);
-	uint32_t estimated_n_blocks = std::ceil(estimated_size / (Storage::BLOCK_SIZE - per_segment_overhead));
+	uint32_t estimated_n_blocks =
+	    NumericCast<uint32_t>(std::ceil(estimated_size / (Storage::BLOCK_SIZE - per_segment_overhead)));
 
 	auto final_analyze_size = estimated_size + (estimated_n_blocks * per_segment_overhead);
-	return final_analyze_size;
+	return NumericCast<idx_t>(final_analyze_size);
 }
 
 } // namespace duckdb

diff --git a/src/storage/compression/dictionary_compression.cpp b/src/storage/compression/dictionary_compression.cpp
index 79ccc64471b9..58529e72aa19 100644
--- a/src/storage/compression/dictionary_compression.cpp
+++ b/src/storage/compression/dictionary_compression.cpp
@@ -86,7 +86,7 @@ typedef struct {
 } dictionary_compression_header_t;
 
 struct DictionaryCompressionStorage {
-	static constexpr float MINIMUM_COMPRESSION_RATIO = 1.2;
+	static constexpr float MINIMUM_COMPRESSION_RATIO = 1.2F;
 
 	static constexpr uint16_t DICTIONARY_HEADER_SIZE = sizeof(dictionary_compression_header_t);
 	static constexpr size_t COMPACTION_FLUSH_LIMIT = (size_t)Storage::BLOCK_SIZE / 5 * 4;
@@ -402,7 +402,7 @@ idx_t DictionaryCompressionStorage::StringFinalAnalyze(AnalyzeState &state_p) {
 	auto req_space =
 	    RequiredSpace(state.current_tuple_count, state.current_unique_count, state.current_dict_size, width);
 
-	return MINIMUM_COMPRESSION_RATIO * (state.segment_count * Storage::BLOCK_SIZE + req_space);
+	return NumericCast<idx_t>(MINIMUM_COMPRESSION_RATIO * (state.segment_count * Storage::BLOCK_SIZE + req_space));
 }
 
 //===--------------------------------------------------------------------===//

diff --git a/src/storage/compression/fsst.cpp b/src/storage/compression/fsst.cpp
index 02474963f31d..fcccab4af95a 100644
--- a/src/storage/compression/fsst.cpp
+++ b/src/storage/compression/fsst.cpp
@@ -191,7 +191,7 @@ idx_t FSSTStorage::StringFinalAnalyze(AnalyzeState &state_p) {
 
 	auto estimated_size = estimated_base_size + symtable_size;
 
-	return estimated_size * MINIMUM_COMPRESSION_RATIO;
+	return NumericCast<idx_t>(estimated_size * MINIMUM_COMPRESSION_RATIO);
 }
 
 //===--------------------------------------------------------------------===//

diff --git a/src/storage/statistics/distinct_statistics.cpp b/src/storage/statistics/distinct_statistics.cpp
index 7d4e6e92a346..69f533071733 100644
--- a/src/storage/statistics/distinct_statistics.cpp
+++ b/src/storage/statistics/distinct_statistics.cpp
@@ -64,7 +64,7 @@ idx_t DistinctStatistics::GetCount() const {
 	double u1 = pow(u / s, 2) * u;
 
 	// Estimate total uniques using Good Turing Estimation
-	idx_t estimate = u + u1 / s * (n - s);
+	idx_t estimate = NumericCast<idx_t>(u + u1 / s * (n - s));
 	return MinValue<idx_t>(estimate, total_count);
 }
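The Good-Turing line changed above extrapolates distinct counts from a sample: with u uniques seen
in s sampled rows out of n total rows, it adds (u/s)^2 * u extra uniques per sampled row for the
remaining n - s rows. A worked sketch of the same arithmetic (hypothetical helper, not DuckDB API):

	// Good-Turing-style estimate as in DistinctStatistics::GetCount() above.
	static double EstimateDistinct(double u, double s, double n) {
		double u1 = (u / s) * (u / s) * u; // stand-in for the "rare uniques" mass
		return u + u1 / s * (n - s);       // extrapolate to the unsampled rows
	}
	// EstimateDistinct(500, 1000, 10000) == 500 + 125.0 / 1000 * 9000 == 1625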
From fb4c967a64ad5fe5084a87550b8ee96323802c3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hannes=20M=C3=BChleisen?=
Date: Thu, 18 Apr 2024 12:07:46 +0200
Subject: [PATCH 159/201] whee

---
 src/common/operator/cast_operators.cpp              | 2 +-
 src/common/types.cpp                                | 2 +-
 src/include/duckdb/common/operator/numeric_cast.hpp | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/common/operator/cast_operators.cpp b/src/common/operator/cast_operators.cpp
index 843acc70527a..02e30431969c 100644
--- a/src/common/operator/cast_operators.cpp
+++ b/src/common/operator/cast_operators.cpp
@@ -2486,7 +2486,7 @@ bool DoubleToDecimalCast(SRC input, DST &result, CastParameters &parameters, uin
 		HandleCastError::AssignError(error, parameters);
 		return false;
 	}
-	result = Cast::Operation(UnsafeNumericCast(value));
+	result = Cast::Operation(static_cast(value));
 	return true;
 }

diff --git a/src/common/types.cpp b/src/common/types.cpp
index b54c47e78b0d..45bb5edab003 100644
--- a/src/common/types.cpp
+++ b/src/common/types.cpp
@@ -1110,7 +1110,7 @@ bool ApproxEqual(float ldecimal, float rdecimal) {
 	if (!Value::FloatIsFinite(ldecimal) || !Value::FloatIsFinite(rdecimal)) {
 		return ldecimal == rdecimal;
 	}
-	auto epsilon = UnsafeNumericCast<float>(std::fabs(rdecimal) * 0.01 + 0.00000001);
+	float epsilon = static_cast<float>(std::fabs(rdecimal) * 0.01 + 0.00000001);
 	return std::fabs(ldecimal - rdecimal) <= epsilon;
 }

diff --git a/src/include/duckdb/common/operator/numeric_cast.hpp b/src/include/duckdb/common/operator/numeric_cast.hpp
index b6d3b6742f80..f48839c4175e 100644
--- a/src/include/duckdb/common/operator/numeric_cast.hpp
+++ b/src/include/duckdb/common/operator/numeric_cast.hpp
@@ -75,7 +75,7 @@ bool TryCastWithOverflowCheckFloat(SRC value, T &result, SRC min, SRC max) {
 		return false;
 	}
 	// PG FLOAT => INT casts use statistical rounding.
-	result = UnsafeNumericCast<T>(std::nearbyint(value));
+	result = static_cast<T>(std::nearbyint(value));
 	return true;
 }
 
@@ -182,7 +182,7 @@ bool TryCastWithOverflowCheck(double input, float &result) {
 		return true;
 	}
 	auto res = float(input);
-	if (!Value::DoubleIsFinite(input)) {
+	if (!Value::FloatIsFinite(res)) {
 		return false;
 	}
 	result = res;

From e3ae8be9d3a520b44ffad76f4c5f8b88b87e9d63 Mon Sep 17 00:00:00 2001
From: Mark Raasveldt
Date: Thu, 18 Apr 2024 13:00:56 +0200
Subject: [PATCH 160/201] Remove bound_defaults from BoundCreateTableInfo

---
 src/catalog/catalog_entry/duck_table_entry.cpp |  6 +++---
 src/include/duckdb/planner/binder.hpp          |  2 ++
 .../parsed_data/bound_create_table_info.hpp    |  2 --
 .../binder/statement/bind_create_table.cpp     | 15 +++++++++++----
 4 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/src/catalog/catalog_entry/duck_table_entry.cpp b/src/catalog/catalog_entry/duck_table_entry.cpp
index eb7be8d7a451..90c4de3ca0b7 100644
--- a/src/catalog/catalog_entry/duck_table_entry.cpp
+++ b/src/catalog/catalog_entry/duck_table_entry.cpp
@@ -342,9 +342,9 @@ unique_ptr<CatalogEntry> DuckTableEntry::AddColumn(ClientContext &context, AddCo
 	create_info->columns.AddColumn(std::move(col));
 
 	auto binder = Binder::CreateBinder(context);
-	auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema);
-	auto new_storage =
-	    make_shared<DataTable>(context, *storage, info.new_column, *bound_create_info->bound_defaults.back());
+	vector<unique_ptr<Expression>> bound_defaults;
+	auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema, bound_defaults);
+	auto new_storage = make_shared<DataTable>(context, *storage, info.new_column, *bound_defaults.back());
 	return make_uniq<DuckTableEntry>(catalog, schema, *bound_create_info, new_storage);
 }

diff --git a/src/include/duckdb/planner/binder.hpp b/src/include/duckdb/planner/binder.hpp
index ed646e6b04b7..9d992fb407e7 100644
--- a/src/include/duckdb/planner/binder.hpp
+++ b/src/include/duckdb/planner/binder.hpp
@@ -119,6 +119,8 @@ class Binder : public std::enable_shared_from_this<Binder> {
 	unique_ptr<BoundCreateTableInfo> BindCreateTableInfo(unique_ptr<CreateInfo> info);
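// Note (editorial aside, not a hunk line): with this patch the bound DEFAULT expressions are
// returned through an out-parameter instead of being stored on BoundCreateTableInfo. The caller
// shown earlier in duck_table_entry.cpp now reads, roughly:
//
//	vector<unique_ptr<Expression>> bound_defaults;
//	auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema, bound_defaults);
//	// bound_defaults now owns the bound defaults; bound_create_info no longer carries them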
unique_ptr BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema); + unique_ptr BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema, + vector> &bound_defaults); void BindCreateViewInfo(CreateViewInfo &base); SchemaCatalogEntry &BindSchema(CreateInfo &info); diff --git a/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp b/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp index fe71dfbda6fe..c7a1aac57705 100644 --- a/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp +++ b/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp @@ -38,8 +38,6 @@ struct BoundCreateTableInfo { vector> constraints; //! List of bound constraints on the table vector> bound_constraints; - //! Bound default values - vector> bound_defaults; //! Dependents of the table (in e.g. default values) LogicalDependencyList dependencies; //! The existing table data on disk (if any) diff --git a/src/planner/binder/statement/bind_create_table.cpp b/src/planner/binder/statement/bind_create_table.cpp index 4564cb0f04c6..780ef2380497 100644 --- a/src/planner/binder/statement/bind_create_table.cpp +++ b/src/planner/binder/statement/bind_create_table.cpp @@ -239,8 +239,8 @@ static void ExtractExpressionDependencies(Expression &expr, LogicalDependencyLis expr, [&](Expression &child) { ExtractExpressionDependencies(child, dependencies); }); } -static void ExtractDependencies(BoundCreateTableInfo &info) { - for (auto &default_value : info.bound_defaults) { +static void ExtractDependencies(BoundCreateTableInfo &info, vector> &bound_defaults) { + for (auto &default_value : bound_defaults) { if (default_value) { ExtractExpressionDependencies(*default_value, info.dependencies); } @@ -252,7 +252,14 @@ static void ExtractDependencies(BoundCreateTableInfo &info) { } } } + unique_ptr Binder::BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema) { + vector> bound_defaults; + return BindCreateTableInfo(std::move(info), schema, bound_defaults); +} + +unique_ptr Binder::BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema, + vector> &bound_defaults) { auto &base = info->Cast(); auto result = make_uniq(schema, std::move(info)); if (base.query) { @@ -279,10 +286,10 @@ unique_ptr Binder::BindCreateTableInfo(unique_ptrbound_defaults); + BindDefaultValues(base.columns, bound_defaults); } // extract dependencies from any default values or CHECK constraints - ExtractDependencies(*result); + ExtractDependencies(*result, bound_defaults); if (base.columns.PhysicalColumnCount() == 0) { throw BinderException("Creating a table without physical (non-generated) columns is not supported"); From 5adcf391dac6d12359e48e4884426901a17e4542 Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 13:14:01 +0200 Subject: [PATCH 161/201] Explicitly define DUCKDB_MODULE_DIR and use that in invoking scripts --- CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0064a8b4445d..c6678fdf710b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,6 +24,8 @@ project(DuckDB) find_package(Threads REQUIRED) +set(DUCKDB_MODULE_BASE_DIR "${CMAKE_CURRENT_LIST_DIR}") + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set (CMAKE_CXX_STANDARD 11) @@ -827,7 +829,7 @@ function(build_loadable_extension_directory NAME OUTPUT_DIRECTORY EXTENSION_VERS TARGET ${TARGET_NAME} POST_BUILD COMMAND - ${CMAKE_COMMAND} -DEXTENSION=$ -DPLATFORM_FILE=${DuckDB_BINARY_DIR}/duckdb_platform_out 
-DDUCKDB_VERSION="${DUCKDB_NORMALIZED_VERSION}" -DEXTENSION_VERSION="${EXTENSION_VERSION}" -DNULL_FILE=${CMAKE_CURRENT_FUNCTION_LIST_DIR}/scripts/null.txt -P ${CMAKE_CURRENT_FUNCTION_LIST_DIR}/scripts/append_metadata.cmake
+      ${CMAKE_COMMAND} -DEXTENSION=$<TARGET_FILE:${TARGET_NAME}> -DPLATFORM_FILE=${DuckDB_BINARY_DIR}/duckdb_platform_out -DDUCKDB_VERSION="${DUCKDB_NORMALIZED_VERSION}" -DEXTENSION_VERSION="${EXTENSION_VERSION}" -DNULL_FILE=${DUCKDB_MODULE_BASE_DIR}/scripts/null.txt -P ${DUCKDB_MODULE_BASE_DIR}/scripts/append_metadata.cmake
   )
   add_dependencies(${TARGET_NAME} duckdb_platform)
   if (NOT EXTENSION_CONFIG_BUILD AND NOT ${EXTENSION_TESTS_ONLY} AND NOT CLANG_TIDY)

From d2adbc8bb91008f3402143282292ae932c86682c Mon Sep 17 00:00:00 2001
From: Carlo Piovesan
Date: Thu, 18 Apr 2024 13:38:12 +0200
Subject: [PATCH 162/201] Properly avoid build-time dependency on Python

Triggered by https://github.com/duckdb/duckdb/pull/11710
---
 CMakeLists.txt | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0064a8b4445d..d9b7f2045b5a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -715,6 +715,8 @@ endif()
 set(LOCAL_EXTENSION_REPO FALSE)
 if (NOT EXTENSION_CONFIG_BUILD AND NOT ${EXTENSION_TESTS_ONLY} AND NOT CLANG_TIDY)
   if (NOT Python3_FOUND)
+    add_custom_target(
+            duckdb_local_extension_repo ALL)
     MESSAGE(STATUS "Could not find python3, create extension directory step will be skipped")
   else()
     add_custom_target(

From c9c03b2c8731c6908f49f5f108ae31813d167cad Mon Sep 17 00:00:00 2001
From: Tishj
Date: Thu, 18 Apr 2024 14:01:49 +0200
Subject: [PATCH 163/201] add duckdb::make_shared with static assert

---
 src/include/duckdb/common/helper.hpp | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/include/duckdb/common/helper.hpp b/src/include/duckdb/common/helper.hpp
index 52e0dd65ea63..d40395855a8e 100644
--- a/src/include/duckdb/common/helper.hpp
+++ b/src/include/duckdb/common/helper.hpp
@@ -152,6 +152,14 @@ static duckdb::unique_ptr<T> make_unique(ARGS &&... __args) { // NOLINT: mimic std style
 	return unique_ptr<T>(new T(std::forward<ARGS>(__args)...));
 }
 
+template <class T, class... ARGS>
+static duckdb::shared_ptr<T> make_shared(ARGS &&... __args) { // NOLINT: mimic std style
+#ifndef DUCKDB_ENABLE_DEPRECATED_API
+	static_assert(sizeof(T) == 0, "Use make_shared_ptr instead of make_shared!");
+#endif // DUCKDB_ENABLE_DEPRECATED_API
+	return shared_ptr<T>(new T(std::forward<ARGS>(__args)...));
+}
+
 template <class T>
 constexpr T MaxValue(T a, T b) {
 	return a > b ? a : b;
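// Note (editorial aside, not a hunk line): the static_assert in make_shared above is a "poisoned
// template" pattern. sizeof(T) == 0 is false for every complete type, but the assertion is only
// evaluated when duckdb::make_shared<T> is actually instantiated, so including the header stays
// legal while stale call sites fail at compile time unless DUCKDB_ENABLE_DEPRECATED_API is set.
// Hypothetical call sites (MyType is a placeholder):
//
//	auto ok  = make_shared_ptr<MyType>(42); // the replacement named in the assertion message
//	auto bad = make_shared<MyType>(42);     // static_assert fires at this instantiation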
}

From 9136c3de9a4248649a00566f3890e3dcf7f582b5 Mon Sep 17 00:00:00 2001
From: Carlo Piovesan
Date: Thu, 18 Apr 2024 14:53:13 +0200
Subject: [PATCH 164/201] Upload staging: from 'git describe --tags --long' to
 'git log -1 --format=%h'

This makes so that target is independent of actual status (pre or post
tagging)
---
 .github/workflows/Python.yml        | 8 ++++----
 .github/workflows/StagedUpload.yml  | 3 ++-
 .github/workflows/TwineUpload.yml   | 2 +-
 scripts/upload-assets-to-staging.sh | 2 +-
 4 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/Python.yml b/.github/workflows/Python.yml
index 9d8d43efa313..ced2a2c88304 100644
--- a/.github/workflows/Python.yml
+++ b/.github/workflows/Python.yml
@@ -81,7 +81,7 @@ jobs:
         elif [[ -z "${{ inputs.override_git_describe }}" ]]; then
           echo "No override_git_describe provided"
         else
-          echo "UPLOAD_ASSETS_TO_STAGING_TARGET=$(git describe --tags --long)" >> "$GITHUB_ENV"
+          echo "UPLOAD_ASSETS_TO_STAGING_TARGET=$(git log -1 --format=%h)" >> "$GITHUB_ENV"
           echo "override_git_describe ${{ inputs.override_git_describe }}: add tag"
           git tag ${{ inputs.override_git_describe }}
         fi
@@ -212,7 +212,7 @@ jobs:
         elif [[ -z "${{ inputs.override_git_describe }}" ]]; then
           echo "No override_git_describe provided"
         else
-          echo "UPLOAD_ASSETS_TO_STAGING_TARGET=$(git describe --tags --long)" >> "$GITHUB_ENV"
+          echo "UPLOAD_ASSETS_TO_STAGING_TARGET=$(git log -1 --format=%h)" >> "$GITHUB_ENV"
          echo "override_git_describe ${{ inputs.override_git_describe }}: add tag"
           git tag ${{ inputs.override_git_describe }}
         fi
@@ -302,7 +302,7 @@ jobs:
         elif [[ -z "${{ inputs.override_git_describe }}" ]]; then
          echo "No override_git_describe provided"
         else
-          echo "UPLOAD_ASSETS_TO_STAGING_TARGET=$(git describe --tags --long)" >> "$GITHUB_ENV"
+          echo "UPLOAD_ASSETS_TO_STAGING_TARGET=$(git log -1 --format=%h)" >> "$GITHUB_ENV"
           echo "override_git_describe ${{ inputs.override_git_describe }}: add tag"
           git tag ${{ inputs.override_git_describe }}
         fi
@@ -380,7 +380,7 @@ jobs:
         elif [[ -z "${{ inputs.override_git_describe }}" ]]; then
           echo "No override_git_describe provided"
         else
-          echo "UPLOAD_ASSETS_TO_STAGING_TARGET=$(git describe --tags --long)" >> "$GITHUB_ENV"
+          echo "UPLOAD_ASSETS_TO_STAGING_TARGET=$(git log -1 --format=%h)" >> "$GITHUB_ENV"
           echo "override_git_describe ${{ inputs.override_git_describe }}: add tag"
           git tag ${{ inputs.override_git_describe }}
         fi

diff --git a/.github/workflows/StagedUpload.yml b/.github/workflows/StagedUpload.yml
index 322f636421db..654070aeff2a 100644
--- a/.github/workflows/StagedUpload.yml
+++ b/.github/workflows/StagedUpload.yml
@@ -31,8 +31,9 @@ jobs:
           AWS_ACCESS_KEY_ID: ${{ secrets.S3_DUCKDB_STAGING_ID }}
           AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_DUCKDB_STAGING_KEY }}
         run: |
+          TARGET=$(git log -1 --format=%h)
           mkdir to_be_uploaded
-          aws s3 cp --recursive "s3://duckdb-staging/${{ inputs.target_git_describe }}/$GITHUB_REPOSITORY/github_release" to_be_uploaded --region us-east-2
+          aws s3 cp --recursive "s3://duckdb-staging/$TARGET/${{ inputs.target_git_describe }}/$GITHUB_REPOSITORY/github_release" to_be_uploaded --region us-east-2
 
       - name: Deploy
         shell: bash

diff --git a/.github/workflows/TwineUpload.yml b/.github/workflows/TwineUpload.yml
index 1343fb691f58..58d27c2ccd4e 100644
--- a/.github/workflows/TwineUpload.yml
+++ b/.github/workflows/TwineUpload.yml
@@ -37,7 +37,7 @@ jobs:
           AWS_ACCESS_KEY_ID: ${{ secrets.S3_DUCKDB_STAGING_ID }}
           AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_DUCKDB_STAGING_KEY }}
         run: |
-          TARGET=$(git describe --tags --long)
+
TARGET=$(git log -1 --format=%h) # decide target for staging if [ "$OVERRIDE_GIT_DESCRIBE" ]; then TARGET="$TARGET/$OVERRIDE_GIT_DESCRIBE" diff --git a/scripts/upload-assets-to-staging.sh b/scripts/upload-assets-to-staging.sh index 3b0d71ec828b..d2111d894800 100755 --- a/scripts/upload-assets-to-staging.sh +++ b/scripts/upload-assets-to-staging.sh @@ -45,7 +45,7 @@ if [ -z "$AWS_ACCESS_KEY_ID" ]; then fi -TARGET=$(git describe --tags --long) +TARGET=$(git log -1 --format=%h) if [ "$UPLOAD_ASSETS_TO_STAGING_TARGET" ]; then TARGET="$UPLOAD_ASSETS_TO_STAGING_TARGET" From e445c53157d136affe61cd94d3a279d781abf4b0 Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 15:07:39 +0200 Subject: [PATCH 165/201] Add OnTag.yml, to automatically trigger needed workflows on tag creations --- .github/workflows/OnTag.yml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 .github/workflows/OnTag.yml diff --git a/.github/workflows/OnTag.yml b/.github/workflows/OnTag.yml new file mode 100644 index 000000000000..a88b6eb3df66 --- /dev/null +++ b/.github/workflows/OnTag.yml @@ -0,0 +1,22 @@ +name: On Tag +on: + workflow_dispatch: + inputs: + override_git_describe: + type: string + push: + tags: + - 'v[0-9]+.[0-9]+.[0-9]+' + +jobs: + twine_upload: + uses: ./.github/workflows/TwineUpload.yml + secrets: inherit + with: + override_git_describe: ${{ inputs.override_git_describe || github.event.release.tag_name }} + + staged_upload: + uses: ./.github/workflows/StagedUpload.yml + secrets: inherit + with: + override_git_describe: ${{ inputs.override_git_describe || github.event.release.tag_name }} From 62de06159551912da70313cf946621432316cdeb Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 18 Apr 2024 16:07:34 +0200 Subject: [PATCH 166/201] wrongly translated settings from the sqllogic_test_runner.cpp file --- scripts/sqllogictest/result.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/sqllogictest/result.py b/scripts/sqllogictest/result.py index 68b92964b153..535a1a5ba28f 100644 --- a/scripts/sqllogictest/result.py +++ b/scripts/sqllogictest/result.py @@ -772,10 +772,9 @@ def execute_load(self, load: Load): # set up the config file additional_config = {} if readonly: - additional_config['temp_directory'] = False + additional_config['temp_directory'] = "" additional_config['access_mode'] = 'read_only' else: - additional_config['temp_directory'] = True additional_config['access_mode'] = 'automatic' self.pool = None From 59f3dcd3e8b2bf285a261c4238c8dd62a6583c50 Mon Sep 17 00:00:00 2001 From: Mark Raasveldt Date: Thu, 18 Apr 2024 20:30:14 +0200 Subject: [PATCH 167/201] Improve mkdir error reporting --- src/common/local_file_system.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/common/local_file_system.cpp b/src/common/local_file_system.cpp index 744bce887cb1..029aa02cf584 100644 --- a/src/common/local_file_system.cpp +++ b/src/common/local_file_system.cpp @@ -540,7 +540,8 @@ void LocalFileSystem::CreateDirectory(const string &directory, optional_ptr Date: Thu, 18 Apr 2024 20:31:20 +0200 Subject: [PATCH 168/201] Consistent quotes --- src/common/local_file_system.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/local_file_system.cpp b/src/common/local_file_system.cpp index 029aa02cf584..a2e7ab0dce74 100644 --- a/src/common/local_file_system.cpp +++ b/src/common/local_file_system.cpp @@ -991,7 +991,7 @@ void LocalFileSystem::CreateDirectory(const string &directory, optional_ptr Date: Thu, 
18 Apr 2024 20:26:43 +0200 Subject: [PATCH 169/201] Remove python3 build dependency from docker tests --- scripts/test_docker_images.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/test_docker_images.sh b/scripts/test_docker_images.sh index c2217a8b6488..ce12ef35ae47 100755 --- a/scripts/test_docker_images.sh +++ b/scripts/test_docker_images.sh @@ -1,5 +1,6 @@ #!/usr/bin/env bash make clean +docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja && GEN=ninja make && make clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja python3 && GEN=ninja make && make clean" 2>&1 -docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb ubuntu:20.04 <<< "apt-get update && export DEBIAN_FRONTEND=noninteractive && apt-get install g++ git make cmake ninja-build python3 -y && GEN=ninja make && make clean" 2>&1 +docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb ubuntu:20.04 <<< "apt-get update && export DEBIAN_FRONTEND=noninteractive && apt-get install g++ git make cmake ninja-build -y && GEN=ninja make && make clean" 2>&1 From cd0e1435f68db5ffea9af049fb863965dde14a2a Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 21:04:15 +0200 Subject: [PATCH 170/201] Allow overriding C++ standard from [C]Make --- CMakeLists.txt | 2 +- Makefile | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index fb45218c66fa..ec210f000160 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 +28,7 @@ set(DUCKDB_MODULE_BASE_DIR "${CMAKE_CURRENT_LIST_DIR}") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set (CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD "11" CACHE STRING "C++ standard to enforce") set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) diff --git a/Makefile b/Makefile index 31a4cdcb64de..248c541c2d5d 100644 --- a/Makefile +++ b/Makefile @@ -65,6 +65,9 @@ ifdef OVERRIDE_GIT_DESCRIBE else COMMON_CMAKE_VARS:=${COMMON_CMAKE_VARS} -DOVERRIDE_GIT_DESCRIBE="" endif +ifneq (${CXX_STANDARD}, ) + CMAKE_VARS:=${CMAKE_VARS} -DCMAKE_CXX_STANDARD="${CXX_STANDARD}" +endif ifneq (${DUCKDB_EXTENSIONS}, ) BUILD_EXTENSIONS:=${DUCKDB_EXTENSIONS} endif From da40bca1c2f5d8a57e50764ff754283e81a080ea Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 21:04:59 +0200 Subject: [PATCH 171/201] Fix asymmetrical operator== [C++20 error] --- src/include/duckdb/parser/base_expression.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/duckdb/parser/base_expression.hpp b/src/include/duckdb/parser/base_expression.hpp index aed84f5eeee3..a64baf7243fd 100644 --- a/src/include/duckdb/parser/base_expression.hpp +++ b/src/include/duckdb/parser/base_expression.hpp @@ -79,7 +79,7 @@ class BaseExpression { static bool Equals(const BaseExpression &left, const BaseExpression &right) { return left.Equals(right); } - bool operator==(const BaseExpression &rhs) { + bool operator==(const BaseExpression &rhs) const { return Equals(rhs); } From b92bbf90f819795909cb018427d4367781d1b278 Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 21:05:38 +0200 Subject: [PATCH 172/201] Fix compound assigment on volatile [C++20 error] --- third_party/mbedtls/library/constant_time.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/mbedtls/library/constant_time.cpp b/third_party/mbedtls/library/constant_time.cpp index a797dce390e3..d2a2239e5d87 100644 --- 
a/third_party/mbedtls/library/constant_time.cpp +++ b/third_party/mbedtls/library/constant_time.cpp @@ -61,7 +61,7 @@ int mbedtls_ct_memcmp( const void *a, * This avoids IAR compiler warning: * 'the order of volatile accesses is undefined ..' */ unsigned char x = A[i], y = B[i]; - diff |= x ^ y; + diff = (diff | (x ^ y)); } return( (int)diff ); From a802e6ca350dec56bd8358c971f17ff1b141dada Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 21:06:38 +0200 Subject: [PATCH 173/201] Add CI run on C++23 standard --- scripts/test_docker_images.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/test_docker_images.sh b/scripts/test_docker_images.sh index ce12ef35ae47..7f9deaf196ae 100755 --- a/scripts/test_docker_images.sh +++ b/scripts/test_docker_images.sh @@ -3,4 +3,5 @@ make clean docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja && GEN=ninja make && make clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja python3 && GEN=ninja make && make clean" 2>&1 +docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja && CXX_STANDARD=23 GEN=ninja make && make clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb ubuntu:20.04 <<< "apt-get update && export DEBIAN_FRONTEND=noninteractive && apt-get install g++ git make cmake ninja-build -y && GEN=ninja make && make clean" 2>&1 From 196b1d1fa875330fd8f93d9e5f9d1afcd008ba95 Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 21:11:06 +0200 Subject: [PATCH 174/201] Fix missing std:: namespaces [C++17 error] --- src/include/duckdb/common/shared_ptr.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/include/duckdb/common/shared_ptr.hpp b/src/include/duckdb/common/shared_ptr.hpp index f5ca7d762c3a..6d0910ca6bf7 100644 --- a/src/include/duckdb/common/shared_ptr.hpp +++ b/src/include/duckdb/common/shared_ptr.hpp @@ -23,13 +23,13 @@ namespace duckdb { #if _LIBCPP_STD_VER >= 17 template -struct __bounded_convertible_to_unbounded : false_type {}; +struct __bounded_convertible_to_unbounded : std::false_type {}; template -struct __bounded_convertible_to_unbounded<_Up[_Np], T> : is_same, _Up[]> {}; +struct __bounded_convertible_to_unbounded<_Up[_Np], T> : std::is_same, _Up[]> {}; template -struct compatible_with_t : _Or, __bounded_convertible_to_unbounded> {}; +struct compatible_with_t : std::_Or, __bounded_convertible_to_unbounded> {}; #else template struct compatible_with_t : std::is_convertible {}; // NOLINT: invalid case style From 044ec9e34128d65a95fcbbf78a6242a2ab2f3844 Mon Sep 17 00:00:00 2001 From: Mark Raasveldt Date: Thu, 18 Apr 2024 21:52:57 +0200 Subject: [PATCH 175/201] Initial version of removal of BoundConstraint from TableCatalogEntry --- .../catalog_entry/duck_table_entry.cpp | 39 +++--- .../catalog_entry/table_catalog_entry.cpp | 10 +- .../persistent/physical_batch_insert.cpp | 8 +- .../operator/persistent/physical_delete.cpp | 12 +- .../operator/persistent/physical_insert.cpp | 25 +++- .../operator/persistent/physical_update.cpp | 24 +++- .../table/system/duckdb_constraints.cpp | 32 +++-- .../catalog_entry/duck_table_entry.hpp | 6 +- .../catalog_entry/table_catalog_entry.hpp | 5 +- src/include/duckdb/planner/binder.hpp | 4 + .../parsed_data/bound_create_table_info.hpp | 2 - src/include/duckdb/storage/data_table.hpp | 25 ++-- .../duckdb/storage/table/append_state.hpp | 9 ++ 
.../duckdb/storage/table/delete_state.hpp | 23 ++++ .../duckdb/storage/table/update_state.hpp | 20 +++ .../binder/statement/bind_create_table.cpp | 130 +++++++++++------- src/planner/binder/statement/bind_update.cpp | 2 +- src/storage/data_table.cpp | 116 +++++++++------- src/storage/wal_replay.cpp | 3 +- 19 files changed, 326 insertions(+), 169 deletions(-) create mode 100644 src/include/duckdb/storage/table/delete_state.hpp create mode 100644 src/include/duckdb/storage/table/update_state.hpp diff --git a/src/catalog/catalog_entry/duck_table_entry.cpp b/src/catalog/catalog_entry/duck_table_entry.cpp index 90c4de3ca0b7..f77e69f65eba 100644 --- a/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/catalog/catalog_entry/duck_table_entry.cpp @@ -71,10 +71,13 @@ IndexStorageInfo GetIndexInfo(const IndexConstraintType &constraint_type, unique return IndexStorageInfo(constraint_name + create_table_info.table + "_" + to_string(idx)); } +vector GetUniqueConstraintKeys(const ColumnList &columns, const UniqueConstraint &constraint) { + throw InternalException("FIXME: GetUniqueConstraintKeys"); +} + DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, BoundCreateTableInfo &info, std::shared_ptr inherited_storage) : TableCatalogEntry(catalog, schema, info.Base()), storage(std::move(inherited_storage)), - bound_constraints(std::move(info.bound_constraints)), column_dependency_manager(std::move(info.column_dependency_manager)) { if (!storage) { @@ -88,21 +91,19 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou // create the unique indexes for the UNIQUE and PRIMARY KEY and FOREIGN KEY constraints idx_t indexes_idx = 0; - for (idx_t i = 0; i < bound_constraints.size(); i++) { - - auto &constraint = bound_constraints[i]; - + for (idx_t i = 0; i < constraints.size(); i++) { + auto &constraint = constraints[i]; if (constraint->type == ConstraintType::UNIQUE) { // unique constraint: create a unique index - auto &unique = constraint->Cast(); + auto &unique = constraint->Cast(); IndexConstraintType constraint_type = IndexConstraintType::UNIQUE; if (unique.is_primary_key) { constraint_type = IndexConstraintType::PRIMARY; } - + auto unique_keys = GetUniqueConstraintKeys(columns, unique); if (info.indexes.empty()) { - AddDataTableIndex(*storage, columns, unique.keys, constraint_type, + AddDataTableIndex(*storage, columns, unique_keys, constraint_type, GetIndexInfo(constraint_type, info.base, i)); } else { // we read the index from an old storage version, so we have to apply a dummy name @@ -112,13 +113,12 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou } // now add the index - AddDataTableIndex(*storage, columns, unique.keys, constraint_type, info.indexes[indexes_idx++]); + AddDataTableIndex(*storage, columns, unique_keys, constraint_type, info.indexes[indexes_idx++]); } } else if (constraint->type == ConstraintType::FOREIGN_KEY) { - // foreign key constraint: create a foreign key index - auto &bfk = constraint->Cast(); + auto &bfk = constraint->Cast(); if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { @@ -351,10 +351,10 @@ unique_ptr DuckTableEntry::AddColumn(ClientContext &context, AddCo void DuckTableEntry::UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_index, const vector &adjusted_indices, const RemoveColumnInfo &info, CreateTableInfo &create_info, - bool is_generated) { + const vector> &bound_constraints, + bool 
is_generated) { // handle constraints for the new table D_ASSERT(constraints.size() == bound_constraints.size()); - for (idx_t constr_idx = 0; constr_idx < constraints.size(); constr_idx++) { auto &constraint = constraints[constr_idx]; auto &bound_constraint = bound_constraints[constr_idx]; @@ -472,9 +472,11 @@ unique_ptr DuckTableEntry::RemoveColumn(ClientContext &context, Re } auto adjusted_indices = column_dependency_manager.RemoveColumn(removed_index, columns.LogicalColumnCount()); - UpdateConstraintsOnColumnDrop(removed_index, adjusted_indices, info, *create_info, dropped_column_is_generated); - auto binder = Binder::CreateBinder(context); + auto bound_constraints = binder->BindConstraints(constraints, name, columns); + + UpdateConstraintsOnColumnDrop(removed_index, adjusted_indices, info, *create_info, bound_constraints, dropped_column_is_generated); + auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); if (columns.GetColumn(LogicalIndex(removed_index)).Generated()) { return make_uniq(catalog, schema, *bound_create_info, storage); @@ -583,6 +585,8 @@ unique_ptr DuckTableEntry::ChangeColumnType(ClientContext &context create_info->temporary = temporary; create_info->comment = comment; + auto binder = Binder::CreateBinder(context); + auto bound_constraints = binder->BindConstraints(constraints, name, columns); for (auto &col : columns.Logical()) { auto copy = col.Copy(); if (change_idx == col.Logical()) { @@ -643,7 +647,6 @@ unique_ptr DuckTableEntry::ChangeColumnType(ClientContext &context create_info->constraints.push_back(std::move(constraint)); } - auto binder = Binder::CreateBinder(context); // bind the specified expression vector bound_columns; AlterBinder expr_binder(*binder, context, *this, bound_columns, info.target_type); @@ -785,10 +788,6 @@ DataTable &DuckTableEntry::GetStorage() { return *storage; } -const vector> &DuckTableEntry::GetBoundConstraints() { - return bound_constraints; -} - TableFunction DuckTableEntry::GetScanFunction(ClientContext &context, unique_ptr &bind_data) { bind_data = make_uniq(*this); return TableScanFunction::GetFunction(); diff --git a/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/catalog/catalog_entry/table_catalog_entry.cpp index f89d214f5465..493403830e2a 100644 --- a/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -172,11 +172,6 @@ const vector> &TableCatalogEntry::GetConstraints() { DataTable &TableCatalogEntry::GetStorage() { throw InternalException("Calling GetStorage on a TableCatalogEntry that is not a DuckTableEntry"); } - -const vector> &TableCatalogEntry::GetBoundConstraints() { - throw InternalException("Calling GetBoundConstraints on a TableCatalogEntry that is not a DuckTableEntry"); -} - // LCOV_EXCL_STOP static void BindExtraColumns(TableCatalogEntry &table, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, @@ -239,14 +234,15 @@ vector TableCatalogEntry::GetColumnSegmentInfo() { return {}; } -void TableCatalogEntry::BindUpdateConstraints(LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, +void TableCatalogEntry::BindUpdateConstraints(Binder &binder, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, ClientContext &context) { // check the constraints and indexes of the table to see if we need to project any additional columns // we do this for indexes with multiple columns and CHECK constraints in the UPDATE clause // suppose we have a constraint CHECK(i + j < 10); now we need both i 
and j to check the constraint // if we are only updating one of the two columns we add the other one to the UPDATE set // with a "useless" update (i.e. i=i) so we can verify that the CHECK constraint is not violated - for (auto &constraint : GetBoundConstraints()) { + auto bound_constraints = binder.BindConstraints(constraints, name, columns); + for (auto &constraint : bound_constraints) { if (constraint->type == ConstraintType::CHECK) { auto &check = constraint->Cast(); // check constraint! check if we need to add any extra columns to the UPDATE clause diff --git a/src/execution/operator/persistent/physical_batch_insert.cpp b/src/execution/operator/persistent/physical_batch_insert.cpp index 04a137aa3c47..210f0b3d201d 100644 --- a/src/execution/operator/persistent/physical_batch_insert.cpp +++ b/src/execution/operator/persistent/physical_batch_insert.cpp @@ -171,6 +171,7 @@ class BatchInsertLocalState : public LocalSinkState { TableAppendState current_append_state; unique_ptr current_collection; optional_ptr writer; + unique_ptr constraint_state; void CreateNewCollection(DuckTableEntry &table, const vector &insert_types) { auto &table_info = table.GetStorage().info; @@ -494,7 +495,10 @@ SinkResultType PhysicalBatchInsert::Sink(ExecutionContext &context, DataChunk &c throw InternalException("Current batch differs from batch - but NextBatch was not called!?"); } - table.GetStorage().VerifyAppendConstraints(table, context.client, lstate.insert_chunk); + if (!lstate.constraint_state) { + lstate.constraint_state = table.GetStorage().InitializeConstraintVerification(table, context.client); + } + table.GetStorage().VerifyAppendConstraints(*lstate.constraint_state, context.client, lstate.insert_chunk); auto new_row_group = lstate.current_collection->Append(lstate.insert_chunk, lstate.current_append_state); if (new_row_group) { @@ -595,7 +599,7 @@ SinkFinalizeType PhysicalBatchInsert::Finalize(Pipeline &pipeline, Event &event, auto &table = gstate.table; auto &storage = table.GetStorage(); LocalAppendState append_state; - storage.InitializeLocalAppend(append_state, context); + storage.InitializeLocalAppend(append_state, table, context); auto &transaction = DuckTransaction::Get(context, table.catalog); for (auto &entry : gstate.collections) { if (entry.type != RowGroupBatchType::NOT_FLUSHED) { diff --git a/src/execution/operator/persistent/physical_delete.cpp b/src/execution/operator/persistent/physical_delete.cpp index 4fc17049032a..300376ce27fe 100644 --- a/src/execution/operator/persistent/physical_delete.cpp +++ b/src/execution/operator/persistent/physical_delete.cpp @@ -6,6 +6,7 @@ #include "duckdb/storage/data_table.hpp" #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/storage/table/delete_state.hpp" namespace duckdb { @@ -25,10 +26,12 @@ class DeleteGlobalState : public GlobalSinkState { class DeleteLocalState : public LocalSinkState { public: - DeleteLocalState(Allocator &allocator, const vector &table_types) { - delete_chunk.Initialize(allocator, table_types); + DeleteLocalState(ClientContext &context, TableCatalogEntry &table) { + delete_chunk.Initialize(Allocator::Get(context), table.GetTypes()); + delete_state = table.GetStorage().InitializeDelete(table, context); } DataChunk delete_chunk; + unique_ptr delete_state; }; SinkResultType PhysicalDelete::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { @@ -52,8 +55,7 @@ SinkResultType PhysicalDelete::Sink(ExecutionContext &context, DataChunk 
&chunk, table.Fetch(transaction, ustate.delete_chunk, column_ids, row_identifiers, chunk.size(), cfs); gstate.return_collection.Append(ustate.delete_chunk); } - gstate.deleted_count += table.Delete(tableref, context.client, row_identifiers, chunk.size()); - + gstate.deleted_count += table.Delete(*ustate.delete_state, context.client, row_identifiers, chunk.size()); return SinkResultType::NEED_MORE_INPUT; } @@ -62,7 +64,7 @@ unique_ptr PhysicalDelete::GetGlobalSinkState(ClientContext &co } unique_ptr PhysicalDelete::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(Allocator::Get(context.client), table.GetTypes()); + return make_uniq(context.client, tableref); } //===--------------------------------------------------------------------===// diff --git a/src/execution/operator/persistent/physical_insert.cpp b/src/execution/operator/persistent/physical_insert.cpp index 768334b50abe..181216a15b7f 100644 --- a/src/execution/operator/persistent/physical_insert.cpp +++ b/src/execution/operator/persistent/physical_insert.cpp @@ -17,6 +17,7 @@ #include "duckdb/execution/index/art/art.hpp" #include "duckdb/transaction/duck_transaction.hpp" #include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/update_state.hpp" namespace duckdb { @@ -105,6 +106,14 @@ class InsertLocalState : public LocalSinkState { // Rows in the transaction-local storage that have been updated by a DO UPDATE conflict unordered_set updated_local_rows; idx_t update_count = 0; + unique_ptr constraint_state; + + ConstraintVerificationState &GetConstraintState(DataTable &table, TableCatalogEntry &tableref, ClientContext &context) { + if (!constraint_state) { + constraint_state = table.InitializeConstraintVerification(tableref, context); + } + return *constraint_state; + } }; unique_ptr PhysicalInsert::GetGlobalSinkState(ClientContext &context) const { @@ -278,7 +287,8 @@ static idx_t PerformOnConflictAction(ExecutionContext &context, DataChunk &chunk auto &data_table = table.GetStorage(); // Perform the update, using the results of the SET expressions if (GLOBAL) { - data_table.Update(table, context.client, row_ids, set_columns, update_chunk); + auto update_state = data_table.InitializeUpdate(table, context.client); + data_table.Update(*update_state, context.client, row_ids, set_columns, update_chunk); } else { auto &local_storage = LocalStorage::Get(context.client, data_table.db); // Perform the update, using the results of the SET expressions @@ -320,7 +330,8 @@ static idx_t HandleInsertConflicts(TableCatalogEntry &table, ExecutionContext &c ConflictInfo conflict_info(conflict_target); ConflictManager conflict_manager(VerifyExistenceType::APPEND, lstate.insert_chunk.size(), &conflict_info); if (GLOBAL) { - data_table.VerifyAppendConstraints(table, context.client, lstate.insert_chunk, &conflict_manager); + auto &constraint_state = lstate.GetConstraintState(data_table, table, context.client); + data_table.VerifyAppendConstraints(constraint_state, context.client, lstate.insert_chunk, &conflict_manager); } else { DataTable::VerifyUniqueIndexes(local_storage.GetIndexes(data_table), context.client, lstate.insert_chunk, &conflict_manager); @@ -380,7 +391,8 @@ static idx_t HandleInsertConflicts(TableCatalogEntry &table, ExecutionContext &c combined_chunk.Slice(sel.Selection(), sel.Count()); row_ids.Slice(sel.Selection(), sel.Count()); if (GLOBAL) { - data_table.VerifyAppendConstraints(table, context.client, combined_chunk, nullptr); + auto &constraint_state = lstate.GetConstraintState(data_table, 
table, context.client); + data_table.VerifyAppendConstraints(constraint_state, context.client, combined_chunk, nullptr); } else { DataTable::VerifyUniqueIndexes(local_storage.GetIndexes(data_table), context.client, lstate.insert_chunk, nullptr); @@ -406,7 +418,8 @@ idx_t PhysicalInsert::OnConflictHandling(TableCatalogEntry &table, ExecutionCont InsertLocalState &lstate) const { auto &data_table = table.GetStorage(); if (action_type == OnConflictAction::THROW) { - data_table.VerifyAppendConstraints(table, context.client, lstate.insert_chunk, nullptr); + auto &constraint_state = lstate.GetConstraintState(data_table, table, context.client); + data_table.VerifyAppendConstraints(constraint_state, context.client, lstate.insert_chunk, nullptr); return 0; } // Check whether any conflicts arise, and if they all meet the conflict_target + condition @@ -429,7 +442,7 @@ SinkResultType PhysicalInsert::Sink(ExecutionContext &context, DataChunk &chunk, if (!parallel) { if (!gstate.initialized) { - storage.InitializeLocalAppend(gstate.append_state, context.client); + storage.InitializeLocalAppend(gstate.append_state, table, context.client); gstate.initialized = true; } @@ -487,7 +500,7 @@ SinkCombineResultType PhysicalInsert::Combine(ExecutionContext &context, Operato // we have few rows - append to the local storage directly auto &table = gstate.table; auto &storage = table.GetStorage(); - storage.InitializeLocalAppend(gstate.append_state, context.client); + storage.InitializeLocalAppend(gstate.append_state, table, context.client); auto &transaction = DuckTransaction::Get(context.client, table.catalog); lstate.local_collection->Scan(transaction, [&](DataChunk &insert_chunk) { storage.LocalAppend(gstate.append_state, table, context.client, insert_chunk); diff --git a/src/execution/operator/persistent/physical_update.cpp b/src/execution/operator/persistent/physical_update.cpp index c8bab2854c06..6c8c6c041028 100644 --- a/src/execution/operator/persistent/physical_update.cpp +++ b/src/execution/operator/persistent/physical_update.cpp @@ -8,6 +8,8 @@ #include "duckdb/parallel/thread_context.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/table/delete_state.hpp" +#include "duckdb/storage/table/update_state.hpp" namespace duckdb { @@ -55,6 +57,22 @@ class UpdateLocalState : public LocalSinkState { DataChunk update_chunk; DataChunk mock_chunk; ExpressionExecutor default_executor; + unique_ptr delete_state; + unique_ptr update_state; + + TableDeleteState &GetDeleteState(DataTable &table, TableCatalogEntry &tableref, ClientContext &context) { + if (!delete_state) { + delete_state = table.InitializeDelete(tableref, context); + } + return *delete_state; + } + + TableUpdateState &GetUpdateState(DataTable &table, TableCatalogEntry &tableref, ClientContext &context) { + if (!update_state) { + update_state = table.InitializeUpdate(tableref, context); + } + return *update_state; + } }; SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { @@ -106,7 +124,8 @@ SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, // we need to slice here update_chunk.Slice(sel, update_count); } - table.Delete(tableref, context.client, row_ids, update_chunk.size()); + auto &delete_state = lstate.GetDeleteState(table, tableref, context.client); + table.Delete(delete_state, context.client, row_ids, update_chunk.size()); // for the append we need to arrange the 
columns in a specific manner (namely the "standard table order") mock_chunk.SetCardinality(update_chunk); for (idx_t i = 0; i < columns.size(); i++) { @@ -120,7 +139,8 @@ SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, mock_chunk.data[columns[i].index].Reference(update_chunk.data[i]); } } - table.Update(tableref, context.client, row_ids, columns, update_chunk); + auto &update_state = lstate.GetUpdateState(table, tableref, context.client); + table.Update(update_state, context.client, row_ids, columns, update_chunk); } if (return_chunk) { diff --git a/src/function/table/system/duckdb_constraints.cpp b/src/function/table/system/duckdb_constraints.cpp index 467aa2ae7605..c35eaf0da9ad 100644 --- a/src/function/table/system/duckdb_constraints.cpp +++ b/src/function/table/system/duckdb_constraints.cpp @@ -15,6 +15,7 @@ #include "duckdb/planner/constraints/bound_not_null_constraint.hpp" #include "duckdb/planner/constraints/bound_foreign_key_constraint.hpp" #include "duckdb/storage/data_table.hpp" +#include "duckdb/planner/binder.hpp" namespace duckdb { @@ -49,11 +50,24 @@ struct hash { namespace duckdb { +struct ConstraintEntry { + ConstraintEntry(ClientContext &context, TableCatalogEntry &table) : table(table) { + if (!table.IsDuckTable()) { + return; + } + auto binder = Binder::CreateBinder(context); + bound_constraints = binder->BindConstraints(table.GetConstraints(), table.name, table.GetColumns()); + } + + TableCatalogEntry &table; + vector> bound_constraints; +}; + struct DuckDBConstraintsData : public GlobalTableFunctionState { DuckDBConstraintsData() : offset(0), constraint_offset(0), unique_constraint_offset(0) { } - vector> entries; + vector entries; idx_t offset; idx_t constraint_offset; idx_t unique_constraint_offset; @@ -118,8 +132,9 @@ unique_ptr DuckDBConstraintsInit(ClientContext &contex }); sort(entries.begin(), entries.end(), [&](CatalogEntry &x, CatalogEntry &y) { return (x.name < y.name); }); - - result->entries.insert(result->entries.end(), entries.begin(), entries.end()); + for(auto &entry : entries) { + result->entries.emplace_back(context, entry.get().Cast()); + } }; return std::move(result); @@ -135,10 +150,9 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ // either fill up the chunk or return all the remaining columns idx_t count = 0; while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset].get(); - D_ASSERT(entry.type == CatalogType::TABLE_ENTRY); + auto &entry = data.entries[data.offset]; - auto &table = entry.Cast(); + auto &table = entry.table; auto &constraints = table.GetConstraints(); bool is_duck_table = table.IsDuckTable(); for (; data.constraint_offset < constraints.size() && count < STANDARD_VECTOR_SIZE; data.constraint_offset++) { @@ -163,7 +177,7 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ if (!is_duck_table) { continue; } - auto &bound_constraints = table.GetBoundConstraints(); + auto &bound_constraints = entry.bound_constraints; auto &bound_foreign_key = bound_constraints[data.constraint_offset]->Cast(); if (bound_foreign_key.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE) { // Those are already covered by PRIMARY KEY and UNIQUE entries @@ -194,7 +208,7 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ UniqueKeyInfo uk_info; if (is_duck_table) { - auto &bound_constraint = *table.GetBoundConstraints()[data.constraint_offset]; + auto &bound_constraint 
= *entry.bound_constraints[data.constraint_offset]; switch (bound_constraint.type) { case ConstraintType::UNIQUE: { auto &bound_unique = bound_constraint.Cast(); @@ -251,7 +265,7 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ vector column_index_list; if (is_duck_table) { - auto &bound_constraint = *table.GetBoundConstraints()[data.constraint_offset]; + auto &bound_constraint = *entry.bound_constraints[data.constraint_offset]; switch (bound_constraint.type) { case ConstraintType::CHECK: { auto &bound_check = bound_constraint.Cast(); diff --git a/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp b/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp index 0890ce829ab4..7c1323a4f300 100644 --- a/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp +++ b/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp @@ -24,8 +24,6 @@ class DuckTableEntry : public TableCatalogEntry { void UndoAlter(ClientContext &context, AlterInfo &info) override; //! Returns the underlying storage of the table DataTable &GetStorage() override; - //! Returns a list of the bound constraints of the table - const vector> &GetBoundConstraints() override; //! Get statistics of a column (physical or virtual) within the table unique_ptr GetStatistics(ClientContext &context, column_t column_id) override; @@ -60,13 +58,11 @@ class DuckTableEntry : public TableCatalogEntry { unique_ptr SetColumnComment(ClientContext &context, SetColumnCommentInfo &info); void UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_index, const vector &adjusted_indices, - const RemoveColumnInfo &info, CreateTableInfo &create_info, bool is_generated); + const RemoveColumnInfo &info, CreateTableInfo &create_info, const vector> &bound_constraints, bool is_generated); private: //! A reference to the underlying storage unit used for this table std::shared_ptr storage; - //! A list of constraints that are part of this table - vector> bound_constraints; //! Manages dependencies of the individual columns of the table ColumnDependencyManager column_dependency_manager; }; diff --git a/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp b/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp index 243765a45006..fa9a795f2b0f 100644 --- a/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp +++ b/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp @@ -37,6 +37,7 @@ struct SetColumnCommentInfo; class TableFunction; struct FunctionData; +class Binder; class TableColumnInfo; struct ColumnSegmentInfo; class TableStorageInfo; @@ -74,8 +75,6 @@ class TableCatalogEntry : public StandardEntry { DUCKDB_API const ColumnList &GetColumns() const; //! Returns the underlying storage of the table virtual DataTable &GetStorage(); - //! Returns a list of the bound constraints of the table - virtual const vector> &GetBoundConstraints(); //! Returns a list of the constraints of the table DUCKDB_API const vector> &GetConstraints(); @@ -105,7 +104,7 @@ class TableCatalogEntry : public StandardEntry { //! 
Returns the storage info of this table virtual TableStorageInfo GetStorageInfo(ClientContext &context) = 0; - virtual void BindUpdateConstraints(LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, + virtual void BindUpdateConstraints(Binder &binder, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, ClientContext &context); protected: diff --git a/src/include/duckdb/planner/binder.hpp b/src/include/duckdb/planner/binder.hpp index 9d992fb407e7..67e2bc1ec6b1 100644 --- a/src/include/duckdb/planner/binder.hpp +++ b/src/include/duckdb/planner/binder.hpp @@ -43,6 +43,7 @@ class ColumnList; class ExternalDependency; class TableFunction; class TableStorageInfo; +class BoundConstraint; struct CreateInfo; struct BoundCreateTableInfo; @@ -121,6 +122,9 @@ class Binder : public std::enable_shared_from_this { unique_ptr BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema); unique_ptr BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema, vector> &bound_defaults); + static vector> BindConstraints(ClientContext &context, const vector> &constraints, const string &table_name, const ColumnList &columns); + vector> BindConstraints(const vector> &constraints, const string &table_name, const ColumnList &columns); + vector> BindNewConstraints(vector> &constraints, const string &table_name, const ColumnList &columns); void BindCreateViewInfo(CreateViewInfo &base); SchemaCatalogEntry &BindSchema(CreateInfo &info); diff --git a/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp b/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp index c7a1aac57705..049aca82246d 100644 --- a/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp +++ b/src/include/duckdb/planner/parsed_data/bound_create_table_info.hpp @@ -36,8 +36,6 @@ struct BoundCreateTableInfo { ColumnDependencyManager column_dependency_manager; //! List of constraints on the table vector> constraints; - //! List of bound constraints on the table - vector> bound_constraints; //! Dependents of the table (in e.g. default values) LogicalDependencyList dependencies; //! The existing table data on disk (if any) diff --git a/src/include/duckdb/storage/data_table.hpp b/src/include/duckdb/storage/data_table.hpp index c6ae98a9f24e..74a5c2283830 100644 --- a/src/include/duckdb/storage/data_table.hpp +++ b/src/include/duckdb/storage/data_table.hpp @@ -41,6 +41,9 @@ class WriteAheadLog; class TableDataWriter; class ConflictManager; class TableScanState; +struct TableDeleteState; +struct ConstraintVerificationState; +struct TableUpdateState; enum class VerifyExistenceType : uint8_t; //! DataTable represents a physical table on disk @@ -92,7 +95,7 @@ class DataTable { const Vector &row_ids, idx_t fetch_count, ColumnFetchState &state); //! Initializes an append to transaction-local storage - void InitializeLocalAppend(LocalAppendState &state, ClientContext &context); + void InitializeLocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context); //! Append a DataChunk to the transaction-local storage of the table. void LocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, bool unsafe = false); @@ -108,10 +111,14 @@ class DataTable { OptimisticDataWriter &CreateOptimisticWriter(ClientContext &context); void FinalizeOptimisticWriter(ClientContext &context, OptimisticDataWriter &writer); + unique_ptr InitializeDelete(TableCatalogEntry &table, ClientContext &context); //! 
Delete the entries with the specified row identifier from the table - idx_t Delete(TableCatalogEntry &table, ClientContext &context, Vector &row_ids, idx_t count); + idx_t Delete(TableDeleteState &state, ClientContext &context, Vector &row_ids, idx_t count); + + + unique_ptr InitializeUpdate(TableCatalogEntry &table, ClientContext &context); //! Update the entries with the specified row identifier from the table - void Update(TableCatalogEntry &table, ClientContext &context, Vector &row_ids, + void Update(TableUpdateState &state, ClientContext &context, Vector &row_ids, const vector &column_ids, DataChunk &data); //! Update a single (sub-)column along a column path //! The column_path vector is a *path* towards a column within the table @@ -186,22 +193,24 @@ class DataTable { //! FIXME: This is only necessary until we treat all indexes as catalog entries, allowing to alter constraints bool IndexNameIsUnique(const string &name); + //! Initialize constraint verification + unique_ptr InitializeConstraintVerification(TableCatalogEntry &table, ClientContext &context); //! Verify constraints with a chunk from the Append containing all columns of the table - void VerifyAppendConstraints(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, - ConflictManager *conflict_manager = nullptr); + void VerifyAppendConstraints(ConstraintVerificationState &state, ClientContext &context, DataChunk &chunk, + optional_ptr conflict_manager = nullptr); public: static void VerifyUniqueIndexes(TableIndexList &indexes, ClientContext &context, DataChunk &chunk, - ConflictManager *conflict_manager); + optional_ptr conflict_manager); private: //! Verify the new added constraints against current persistent&local data void VerifyNewConstraint(ClientContext &context, DataTable &parent, const BoundConstraint *constraint); //! Verify constraints with a chunk from the Update containing only the specified column_ids - void VerifyUpdateConstraints(ClientContext &context, TableCatalogEntry &table, DataChunk &chunk, + void VerifyUpdateConstraints(ConstraintVerificationState &state, ClientContext &context, DataChunk &chunk, const vector &column_ids); //! 
Verify constraints with a chunk from the Delete containing all columns of the table - void VerifyDeleteConstraints(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk); + void VerifyDeleteConstraints(TableDeleteState &state, ClientContext &context, DataChunk &chunk); void InitializeScanWithOffset(TableScanState &state, const vector &column_ids, idx_t start_row, idx_t end_row); diff --git a/src/include/duckdb/storage/table/append_state.hpp b/src/include/duckdb/storage/table/append_state.hpp index 42cd29befad0..475a59187f1f 100644 --- a/src/include/duckdb/storage/table/append_state.hpp +++ b/src/include/duckdb/storage/table/append_state.hpp @@ -14,6 +14,7 @@ #include "duckdb/common/vector.hpp" #include "duckdb/function/compression_function.hpp" #include "duckdb/transaction/transaction_data.hpp" +#include "duckdb/planner/bound_constraint.hpp" namespace duckdb { class ColumnSegment; @@ -69,9 +70,17 @@ struct TableAppendState { TransactionData transaction; }; +struct ConstraintVerificationState { + explicit ConstraintVerificationState(TableCatalogEntry &table_p) : table(table_p) {} + + TableCatalogEntry &table; + vector> bound_constraints; +}; + struct LocalAppendState { TableAppendState append_state; LocalTableStorage *storage; + unique_ptr constraint_state; }; } // namespace duckdb diff --git a/src/include/duckdb/storage/table/delete_state.hpp b/src/include/duckdb/storage/table/delete_state.hpp new file mode 100644 index 000000000000..ec2c2d57c547 --- /dev/null +++ b/src/include/duckdb/storage/table/delete_state.hpp @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/delete_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/append_state.hpp" + +namespace duckdb { +class TableCatalogEntry; + +struct TableDeleteState { + vector> bound_constraints; + bool has_delete_constraints = false; + DataChunk verify_chunk; + vector col_ids; +}; + +} // namespace duckdb diff --git a/src/include/duckdb/storage/table/update_state.hpp b/src/include/duckdb/storage/table/update_state.hpp new file mode 100644 index 000000000000..50ce404e0132 --- /dev/null +++ b/src/include/duckdb/storage/table/update_state.hpp @@ -0,0 +1,20 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/update_state.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/append_state.hpp" + +namespace duckdb { +class TableCatalogEntry; + +struct TableUpdateState { + unique_ptr constraint_state; +}; + +} // namespace duckdb diff --git a/src/planner/binder/statement/bind_create_table.cpp b/src/planner/binder/statement/bind_create_table.cpp index 780ef2380497..10ef50599a08 100644 --- a/src/planner/binder/statement/bind_create_table.cpp +++ b/src/planner/binder/statement/bind_create_table.cpp @@ -35,61 +35,56 @@ static void CreateColumnDependencyManager(BoundCreateTableInfo &info) { } } -static void BindCheckConstraint(Binder &binder, BoundCreateTableInfo &info, const unique_ptr &cond) { - auto &base = info.base->Cast(); - +static unique_ptr BindCheckConstraint(Binder &binder, const string &table_name, const ColumnList &columns, const unique_ptr &cond) { auto bound_constraint = make_uniq(); // check constraint: bind the expression - CheckBinder check_binder(binder, binder.context, 
base.table, base.columns, bound_constraint->bound_columns); + CheckBinder check_binder(binder, binder.context, table_name, columns, bound_constraint->bound_columns); auto &check = cond->Cast(); // create a copy of the unbound expression because the binding destroys the constraint auto unbound_expression = check.expression->Copy(); // now bind the constraint and create a new BoundCheckConstraint - bound_constraint->expression = check_binder.Bind(check.expression); - info.bound_constraints.push_back(std::move(bound_constraint)); - // move the unbound constraint back into the original check expression - check.expression = std::move(unbound_expression); + bound_constraint->expression = check_binder.Bind(unbound_expression); + return std::move(bound_constraint); } -static void BindConstraints(Binder &binder, BoundCreateTableInfo &info) { - auto &base = info.base->Cast(); +vector> Binder::BindConstraints(ClientContext &context, const vector> &constraints, const string &table_name, const ColumnList &columns) { + auto binder = Binder::CreateBinder(context); + return binder->BindConstraints(constraints, table_name, columns); +} - bool has_primary_key = false; - logical_index_set_t not_null_columns; - vector primary_keys; - for (idx_t i = 0; i < base.constraints.size(); i++) { - auto &cond = base.constraints[i]; - switch (cond->type) { +vector> Binder::BindConstraints(const vector> &constraints, const string &table_name, const ColumnList &columns) { + vector> bound_constraints; + for (auto &constr : constraints) { + switch (constr->type) { case ConstraintType::CHECK: { - BindCheckConstraint(binder, info, cond); + bound_constraints.push_back(BindCheckConstraint(*this, table_name, columns, constr)); break; } case ConstraintType::NOT_NULL: { - auto ¬_null = cond->Cast(); - auto &col = base.columns.GetColumn(LogicalIndex(not_null.index)); - info.bound_constraints.push_back(make_uniq(PhysicalIndex(col.StorageOid()))); - not_null_columns.insert(not_null.index); + auto ¬_null = constr->Cast(); + auto &col = columns.GetColumn(LogicalIndex(not_null.index)); + bound_constraints.push_back(make_uniq(PhysicalIndex(col.StorageOid()))); break; } case ConstraintType::UNIQUE: { - auto &unique = cond->Cast(); + auto &unique = constr->Cast(); // have to resolve columns of the unique constraint vector keys; logical_index_set_t key_set; if (unique.HasIndex()) { - D_ASSERT(unique.GetIndex().index < base.columns.LogicalColumnCount()); + D_ASSERT(unique.GetIndex().index < columns.LogicalColumnCount()); // unique constraint is given by single index - unique.SetColumnName(base.columns.GetColumn(unique.GetIndex()).Name()); + unique.SetColumnName(columns.GetColumn(unique.GetIndex()).Name()); keys.push_back(unique.GetIndex()); key_set.insert(unique.GetIndex()); } else { // unique constraint is given by list of names // have to resolve names for (auto &keyname : unique.GetColumnNames()) { - if (!base.columns.ColumnExists(keyname)) { + if (!columns.ColumnExists(keyname)) { throw ParserException("column \"%s\" named in key does not exist", keyname); } - auto &column = base.columns.GetColumn(keyname); + auto &column = columns.GetColumn(keyname); auto column_index = column.Logical(); if (key_set.find(column_index) != key_set.end()) { throw ParserException("column \"%s\" appears twice in " @@ -100,38 +95,29 @@ static void BindConstraints(Binder &binder, BoundCreateTableInfo &info) { key_set.insert(column_index); } } - - if (unique.IsPrimaryKey()) { - // we can only have one primary key per table - if (has_primary_key) { - throw 
ParserException("table \"%s\" has more than one primary key", base.table); - } - has_primary_key = true; - primary_keys = keys; - } - info.bound_constraints.push_back( + bound_constraints.push_back( make_uniq(std::move(keys), std::move(key_set), unique.IsPrimaryKey())); break; } case ConstraintType::FOREIGN_KEY: { - auto &fk = cond->Cast(); + auto &fk = constr->Cast(); D_ASSERT((fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE && !fk.info.pk_keys.empty()) || (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE && !fk.info.pk_keys.empty()) || fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE); physical_index_set_t fk_key_set, pk_key_set; - for (idx_t i = 0; i < fk.info.pk_keys.size(); i++) { - if (pk_key_set.find(fk.info.pk_keys[i]) != pk_key_set.end()) { + for (auto &pk_key : fk.info.pk_keys) { + if (pk_key_set.find(pk_key) != pk_key_set.end()) { throw BinderException("Duplicate primary key referenced in FOREIGN KEY constraint"); } - pk_key_set.insert(fk.info.pk_keys[i]); + pk_key_set.insert(pk_key); } - for (idx_t i = 0; i < fk.info.fk_keys.size(); i++) { - if (fk_key_set.find(fk.info.fk_keys[i]) != fk_key_set.end()) { + for (auto &fk_key : fk.info.fk_keys) { + if (fk_key_set.find(fk_key) != fk_key_set.end()) { throw BinderException("Duplicate key specified in FOREIGN KEY constraint"); } - fk_key_set.insert(fk.info.fk_keys[i]); + fk_key_set.insert(fk_key); } - info.bound_constraints.push_back( + bound_constraints.push_back( make_uniq(fk.info, std::move(pk_key_set), std::move(fk_key_set))); break; } @@ -139,6 +125,43 @@ static void BindConstraints(Binder &binder, BoundCreateTableInfo &info) { throw NotImplementedException("unrecognized constraint type in bind"); } } + return bound_constraints; +} + +vector> Binder::BindNewConstraints(vector> &constraints, const string &table_name, const ColumnList &columns) { + auto bound_constraints = BindConstraints(constraints, table_name, columns); + + // handle primary keys/not null constraints + bool has_primary_key = false; + logical_index_set_t not_null_columns; + vector primary_keys; + for(idx_t c = 0; c < constraints.size(); c++) { + auto &constr = constraints[c]; + switch(constr->type) { + case ConstraintType::NOT_NULL: { + auto ¬_null = constr->Cast(); + auto &col = columns.GetColumn(LogicalIndex(not_null.index)); + bound_constraints.push_back(make_uniq(PhysicalIndex(col.StorageOid()))); + not_null_columns.insert(not_null.index); + break; + } + case ConstraintType::UNIQUE: { + auto &unique = constr->Cast(); + auto &bound_unique = bound_constraints[c]->Cast(); + if (unique.IsPrimaryKey()) { + // we can only have one primary key per table + if (has_primary_key) { + throw ParserException("table \"%s\" has more than one primary key", table_name); + } + has_primary_key = true; + primary_keys = bound_unique.keys; + } + break; + } + default: + break; + } + } if (has_primary_key) { // if there is a primary key index, also create a NOT NULL constraint for each of the columns for (auto &column_index : primary_keys) { @@ -146,11 +169,12 @@ static void BindConstraints(Binder &binder, BoundCreateTableInfo &info) { //! 
No need to create a NotNullConstraint, it's already present continue; } - auto physical_index = base.columns.LogicalToPhysical(column_index); - base.constraints.push_back(make_uniq(column_index)); - info.bound_constraints.push_back(make_uniq(physical_index)); + auto physical_index = columns.LogicalToPhysical(column_index); + constraints.push_back(make_uniq(column_index)); + bound_constraints.push_back(make_uniq(physical_index)); } } + return bound_constraints; } void Binder::BindGeneratedColumns(BoundCreateTableInfo &info) { @@ -239,13 +263,13 @@ static void ExtractExpressionDependencies(Expression &expr, LogicalDependencyLis expr, [&](Expression &child) { ExtractExpressionDependencies(child, dependencies); }); } -static void ExtractDependencies(BoundCreateTableInfo &info, vector> &bound_defaults) { - for (auto &default_value : bound_defaults) { +static void ExtractDependencies(BoundCreateTableInfo &info, vector> &defaults, vector> &constraints) { + for (auto &default_value : defaults) { if (default_value) { ExtractExpressionDependencies(*default_value, info.dependencies); } } - for (auto &constraint : info.bound_constraints) { + for (auto &constraint : constraints) { if (constraint->type == ConstraintType::CHECK) { auto &bound_check = constraint->Cast(); ExtractExpressionDependencies(*bound_check.expression, info.dependencies); @@ -262,6 +286,8 @@ unique_ptr Binder::BindCreateTableInfo(unique_ptr> &bound_defaults) { auto &base = info->Cast(); auto result = make_uniq(schema, std::move(info)); + + vector> bound_constraints; if (base.query) { // construct the result object auto query_obj = Bind(*base.query); @@ -284,12 +310,12 @@ unique_ptr Binder::BindCreateTableInfo(unique_ptr(std::move(proj_tmp)); // bind any extra columns necessary for CHECK constraints or indexes - table.BindUpdateConstraints(*get, *proj, *update, context); + table.BindUpdateConstraints(*this, *get, *proj, *update, context); // finally add the row id column to the projection list proj->expressions.push_back(make_uniq( diff --git a/src/storage/data_table.cpp b/src/storage/data_table.cpp index 770a18ed43b2..8dab59b0e1a8 100644 --- a/src/storage/data_table.cpp +++ b/src/storage/data_table.cpp @@ -23,7 +23,9 @@ #include "duckdb/common/types/conflict_manager.hpp" #include "duckdb/common/types/constraint_conflict_info.hpp" #include "duckdb/storage/table/append_state.hpp" +#include "duckdb/storage/table/delete_state.hpp" #include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/table/update_state.hpp" #include "duckdb/common/exception/transaction_exception.hpp" namespace duckdb { @@ -549,7 +551,7 @@ bool HasUniqueIndexes(TableIndexList &list) { bool has_unique_index = false; list.Scan([&](Index &index) { if (index.IsUnique()) { - return has_unique_index = true; + has_unique_index = true; return true; } return false; @@ -558,7 +560,7 @@ bool HasUniqueIndexes(TableIndexList &list) { } void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, ClientContext &context, DataChunk &chunk, - ConflictManager *conflict_manager) { + optional_ptr conflict_manager) { //! 
check whether or not the chunk can be inserted into the indexes if (!conflict_manager) { // Only need to verify that no unique constraints are violated @@ -614,8 +616,9 @@ void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, ClientContext &cont }); } -void DataTable::VerifyAppendConstraints(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, - ConflictManager *conflict_manager) { +void DataTable::VerifyAppendConstraints(ConstraintVerificationState &state, ClientContext &context, DataChunk &chunk, + optional_ptr conflict_manager) { + auto &table = state.table; if (table.HasGeneratedColumns()) { // Verify that the generated columns expression work with the inserted values auto binder = Binder::CreateBinder(context); @@ -638,10 +641,9 @@ void DataTable::VerifyAppendConstraints(TableCatalogEntry &table, ClientContext } auto &constraints = table.GetConstraints(); - auto &bound_constraints = table.GetBoundConstraints(); - for (idx_t i = 0; i < bound_constraints.size(); i++) { + for (idx_t i = 0; i < state.bound_constraints.size(); i++) { auto &base_constraint = constraints[i]; - auto &constraint = bound_constraints[i]; + auto &constraint = state.bound_constraints[i]; switch (base_constraint->type) { case ConstraintType::NOT_NULL: { auto &bound_not_null = *reinterpret_cast(constraint.get()); @@ -673,12 +675,21 @@ void DataTable::VerifyAppendConstraints(TableCatalogEntry &table, ClientContext } } -void DataTable::InitializeLocalAppend(LocalAppendState &state, ClientContext &context) { +unique_ptr DataTable::InitializeConstraintVerification(TableCatalogEntry &table, ClientContext &context) { + auto result = make_uniq(table); + auto binder = Binder::CreateBinder(context); + result->bound_constraints = binder->BindConstraints(table.GetConstraints(), table.name, table.GetColumns()); + return result; +} + +void DataTable::InitializeLocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context) { if (!is_root) { throw TransactionException("Transaction conflict: adding entries to a table that has been altered!"); } auto &local_storage = LocalStorage::Get(context, db); local_storage.InitializeAppend(state, *this); + + state.constraint_state = InitializeConstraintVerification(table, context); } void DataTable::LocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, @@ -695,7 +706,7 @@ void DataTable::LocalAppend(LocalAppendState &state, TableCatalogEntry &table, C // verify any constraints on the new chunk if (!unsafe) { - VerifyAppendConstraints(table, context, chunk); + VerifyAppendConstraints(*state.constraint_state, context, chunk); } // append to the transaction local data @@ -724,7 +735,7 @@ void DataTable::LocalMerge(ClientContext &context, RowGroupCollection &collectio void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk) { LocalAppendState append_state; auto &storage = table.GetStorage(); - storage.InitializeLocalAppend(append_state, context); + storage.InitializeLocalAppend(append_state, table, context); storage.LocalAppend(append_state, table, context, chunk); storage.FinalizeLocalAppend(append_state); } @@ -732,7 +743,7 @@ void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, Da void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, ColumnDataCollection &collection) { LocalAppendState append_state; auto &storage = table.GetStorage(); - storage.InitializeLocalAppend(append_state, context); + 
storage.InitializeLocalAppend(append_state, table, context); for (auto &chunk : collection.Chunks()) { storage.LocalAppend(append_state, table, context, chunk); } @@ -956,15 +967,14 @@ void DataTable::RemoveFromIndexes(Vector &row_identifiers, idx_t count) { // Delete //===--------------------------------------------------------------------===// static bool TableHasDeleteConstraints(TableCatalogEntry &table) { - auto &bound_constraints = table.GetBoundConstraints(); - for (auto &constraint : bound_constraints) { + for (auto &constraint : table.GetConstraints()) { switch (constraint->type) { case ConstraintType::NOT_NULL: case ConstraintType::CHECK: case ConstraintType::UNIQUE: break; case ConstraintType::FOREIGN_KEY: { - auto &bfk = *reinterpret_cast(constraint.get()); + auto &bfk = constraint->Cast(); if (bfk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE || bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { return true; @@ -978,9 +988,8 @@ static bool TableHasDeleteConstraints(TableCatalogEntry &table) { return false; } -void DataTable::VerifyDeleteConstraints(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk) { - auto &bound_constraints = table.GetBoundConstraints(); - for (auto &constraint : bound_constraints) { +void DataTable::VerifyDeleteConstraints(TableDeleteState &state, ClientContext &context, DataChunk &chunk) { + for (auto &constraint : state.bound_constraints) { switch (constraint->type) { case ConstraintType::NOT_NULL: case ConstraintType::CHECK: @@ -1000,33 +1009,39 @@ void DataTable::VerifyDeleteConstraints(TableCatalogEntry &table, ClientContext } } -idx_t DataTable::Delete(TableCatalogEntry &table, ClientContext &context, Vector &row_identifiers, idx_t count) { +unique_ptr DataTable::InitializeDelete(TableCatalogEntry &table, ClientContext &context) { + // initialize indexes (if any) + info->InitializeIndexes(context, true); + + auto binder = Binder::CreateBinder(context); + vector> bound_constraints; + vector types; + auto result = make_uniq(); + result->has_delete_constraints = TableHasDeleteConstraints(table); + if (result->has_delete_constraints) { + // initialize the chunk if there are any constraints to verify + for (idx_t i = 0; i < column_definitions.size(); i++) { + result->col_ids.push_back(column_definitions[i].StorageOid()); + types.emplace_back(column_definitions[i].Type()); + } + result->verify_chunk.Initialize(Allocator::Get(context), types); + result->bound_constraints = binder->BindConstraints(table.GetConstraints(), table.name, table.GetColumns()); + } + return result; +} + +idx_t DataTable::Delete(TableDeleteState &state, ClientContext &context, Vector &row_identifiers, idx_t count) { D_ASSERT(row_identifiers.GetType().InternalType() == ROW_TYPE); if (count == 0) { return 0; } - info->InitializeIndexes(context, true); - auto &transaction = DuckTransaction::Get(context, db); auto &local_storage = LocalStorage::Get(transaction); - bool has_delete_constraints = TableHasDeleteConstraints(table); row_identifiers.Flatten(count); auto ids = FlatVector::GetData(row_identifiers); - DataChunk verify_chunk; - vector col_ids; - vector types; - ColumnFetchState fetch_state; - if (has_delete_constraints) { - // initialize the chunk if there are any constraints to verify - for (idx_t i = 0; i < column_definitions.size(); i++) { - col_ids.push_back(column_definitions[i].StorageOid()); - types.emplace_back(column_definitions[i].Type()); - } - verify_chunk.Initialize(Allocator::Get(context), types); - } idx_t pos = 0; idx_t 
delete_count = 0; while (pos < count) { @@ -1045,18 +1060,20 @@ idx_t DataTable::Delete(TableCatalogEntry &table, ClientContext &context, Vector Vector offset_ids(row_identifiers, current_offset, pos); if (is_transaction_delete) { // transaction-local delete - if (has_delete_constraints) { + if (state.has_delete_constraints) { // perform the constraint verification - local_storage.FetchChunk(*this, offset_ids, current_count, col_ids, verify_chunk, fetch_state); - VerifyDeleteConstraints(table, context, verify_chunk); + ColumnFetchState fetch_state; + local_storage.FetchChunk(*this, offset_ids, current_count, state.col_ids, state.verify_chunk, fetch_state); + VerifyDeleteConstraints(state, context, state.verify_chunk); } delete_count += local_storage.Delete(*this, offset_ids, current_count); } else { // regular table delete - if (has_delete_constraints) { + if (state.has_delete_constraints) { // perform the constraint verification - Fetch(transaction, verify_chunk, col_ids, offset_ids, current_count, fetch_state); - VerifyDeleteConstraints(table, context, verify_chunk); + ColumnFetchState fetch_state; + Fetch(transaction, state.verify_chunk, state.col_ids, offset_ids, current_count, fetch_state); + VerifyDeleteConstraints(state, context, state.verify_chunk); } delete_count += row_groups->Delete(transaction, *this, ids + current_offset, current_count); } @@ -1101,10 +1118,11 @@ static bool CreateMockChunk(TableCatalogEntry &table, const vector &column_ids) { + auto &table = state.table; auto &constraints = table.GetConstraints(); - auto &bound_constraints = table.GetBoundConstraints(); + auto &bound_constraints = state.bound_constraints; for (idx_t constr_idx = 0; constr_idx < bound_constraints.size(); constr_idx++) { auto &base_constraint = constraints[constr_idx]; auto &constraint = bound_constraints[constr_idx]; @@ -1150,7 +1168,16 @@ void DataTable::VerifyUpdateConstraints(ClientContext &context, TableCatalogEntr #endif } -void DataTable::Update(TableCatalogEntry &table, ClientContext &context, Vector &row_ids, +unique_ptr DataTable::InitializeUpdate(TableCatalogEntry &table, ClientContext &context) { + // check that there are no unknown indexes + info->InitializeIndexes(context, true); + + auto result = make_uniq(); + result->constraint_state = InitializeConstraintVerification(table, context); + return result; +} + +void DataTable::Update(TableUpdateState &state, ClientContext &context, Vector &row_ids, const vector &column_ids, DataChunk &updates) { D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE); D_ASSERT(column_ids.size() == updates.ColumnCount()); @@ -1165,11 +1192,8 @@ void DataTable::Update(TableCatalogEntry &table, ClientContext &context, Vector throw TransactionException("Transaction conflict: cannot update a table that has been altered!"); } - // check that there are no unknown indexes - info->InitializeIndexes(context, true); - // first verify that no constraints are violated - VerifyUpdateConstraints(context, table, updates, column_ids); + VerifyUpdateConstraints(*state.constraint_state, context, updates, column_ids); // now perform the actual update Vector max_row_id_vec(Value::BIGINT(MAX_ROW_ID)); diff --git a/src/storage/wal_replay.cpp b/src/storage/wal_replay.cpp index 9699bbac3e9d..6f5357e3d9c2 100644 --- a/src/storage/wal_replay.cpp +++ b/src/storage/wal_replay.cpp @@ -678,9 +678,10 @@ void WriteAheadLogDeserializer::ReplayDelete() { auto source_ids = FlatVector::GetData(chunk.data[0]); // delete the tuples from the current table + TableDeleteState delete_state; 
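	// (A hedged sketch, not part of the patch itself; `table` stands in for a bound
	// TableCatalogEntry such as *state.current_table, and `count` is assumed in scope.
	// Outside of WAL replay, callers are expected to use the new two-step API so that
	// delete constraints are bound once up front and the state is reused across calls:
	//
	//     auto &storage = table.GetStorage();
	//     auto delete_state = storage.InitializeDelete(table, context);
	//     idx_t deleted = storage.Delete(*delete_state, context, row_identifiers, count);
	//
	// Replay instead constructs a default TableDeleteState; its has_delete_constraints
	// flag defaults to false, so DataTable::Delete skips constraint verification here.)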
for (idx_t i = 0; i < chunk.size(); i++) { row_ids[0] = source_ids[i]; - state.current_table->GetStorage().Delete(*state.current_table, context, row_identifiers, 1); + state.current_table->GetStorage().Delete(delete_state, context, row_identifiers, 1); } } From c98a18eb9b2cf574868da91c0516e4c8af4b06a8 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 18 Apr 2024 22:13:41 +0200 Subject: [PATCH 176/201] accept NULL type, if any of the inputs are NULL, the result becomes NULL --- src/core_functions/scalar/map/map.cpp | 33 +++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/core_functions/scalar/map/map.cpp b/src/core_functions/scalar/map/map.cpp index e27fe3fd6503..62f54cdd25d0 100644 --- a/src/core_functions/scalar/map/map.cpp +++ b/src/core_functions/scalar/map/map.cpp @@ -28,6 +28,13 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { // - STRUCTs have exactly two fields, a key-field, and a value-field // - key names are unique + if (result.GetType().id() == LogicalTypeId::SQLNULL) { + auto &validity = FlatVector::Validity(result); + validity.SetInvalid(0); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + return; + } + D_ASSERT(result.GetType().id() == LogicalTypeId::MAP); auto row_count = args.size(); @@ -63,13 +70,15 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { UnifiedVectorFormat result_data; result.ToUnifiedFormat(row_count, result_data); auto result_entries = UnifiedVectorFormat::GetDataNoConst(result_data); - result_data.validity.SetAllValid(row_count); + + auto &result_validity = FlatVector::Validity(result); // get the resulting size of the key/value child lists idx_t result_child_size = 0; for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { auto keys_idx = keys_data.sel->get_index(row_idx); - if (!keys_data.validity.RowIsValid(keys_idx)) { + auto values_idx = values_data.sel->get_index(row_idx); + if (!keys_data.validity.RowIsValid(keys_idx) || !values_data.validity.RowIsValid(values_idx)) { continue; } auto keys_entry = keys_entries[keys_idx]; @@ -87,22 +96,15 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { auto values_idx = values_data.sel->get_index(row_idx); auto result_idx = result_data.sel->get_index(row_idx); - // empty map - if (!keys_data.validity.RowIsValid(keys_idx) && !values_data.validity.RowIsValid(values_idx)) { - result_entries[result_idx] = list_entry_t(); + // NULL MAP + if (!keys_data.validity.RowIsValid(keys_idx) || !values_data.validity.RowIsValid(values_idx)) { + result_validity.SetInvalid(row_idx); continue; } auto keys_entry = keys_entries[keys_idx]; auto values_entry = values_entries[values_idx]; - // validity checks - if (!keys_data.validity.RowIsValid(keys_idx)) { - MapVector::EvalMapInvalidReason(MapInvalidReason::NULL_KEY_LIST); - } - if (!values_data.validity.RowIsValid(values_idx)) { - MapVector::EvalMapInvalidReason(MapInvalidReason::NULL_VALUE_LIST); - } if (keys_entry.length != values_entry.length) { MapVector::EvalMapInvalidReason(MapInvalidReason::NOT_ALIGNED); } @@ -166,6 +168,13 @@ static unique_ptr MapBind(ClientContext &, ScalarFunction &bound_f return make_uniq(bound_function.return_type); } + auto key_id = arguments[0]->return_type.id(); + auto value_id = arguments[1]->return_type.id(); + if (key_id == LogicalTypeId::SQLNULL || value_id == LogicalTypeId::SQLNULL) { + bound_function.return_type = LogicalTypeId::SQLNULL; + return make_uniq(bound_function.return_type); + } + // bind a MAP with 
key-value pairs D_ASSERT(arguments.size() == 2); if (arguments[0]->return_type.id() != LogicalTypeId::LIST) { From 253bd13f90aaa067757baae47c771c980ef31933 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 18 Apr 2024 22:14:54 +0200 Subject: [PATCH 177/201] add test for NULL maps --- test/sql/types/map/map_null.test | 35 ++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 test/sql/types/map/map_null.test diff --git a/test/sql/types/map/map_null.test b/test/sql/types/map/map_null.test new file mode 100644 index 000000000000..6bc138488856 --- /dev/null +++ b/test/sql/types/map/map_null.test @@ -0,0 +1,35 @@ +# name: test/sql/types/map/map_null.test +# group: [map] + +statement ok +pragma enable_verification; + +query I +select map(NULL::INT[], [1,2,3]) +---- +NULL + +query I +select map(NULL, [1,2,3]) +---- +NULL + +query I +select map(NULL, NULL) +---- +NULL + +query I +select map(NULL, [1,2,3]) IS NULL +---- +true + +query I +select map([1,2,3], NULL) +---- +NULL + +query I +select map([1,2,3], NULL::INT[]) +---- +NULL From a371e61859ff26c6ca65359a78f60c31a4667d5a Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 18 Apr 2024 22:22:38 +0200 Subject: [PATCH 178/201] change existing tests --- test/sql/types/nested/map/map_error.test | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/sql/types/nested/map/map_error.test b/test/sql/types/nested/map/map_error.test index 315a1620ea7a..f67c19d4ead2 100644 --- a/test/sql/types/nested/map/map_error.test +++ b/test/sql/types/nested/map/map_error.test @@ -75,10 +75,11 @@ CREATE TABLE null_keys_list (k INT[], v INT[]); statement ok INSERT INTO null_keys_list VALUES ([1], [2]), (NULL, [4]); -statement error +query I SELECT MAP(k, v) FROM null_keys_list; ---- -The list of map keys must not be NULL. +{1=2} +NULL statement ok CREATE TABLE null_values_list (k INT[], v INT[]); @@ -86,7 +87,8 @@ CREATE TABLE null_values_list (k INT[], v INT[]); statement ok INSERT INTO null_values_list VALUES ([1], [2]), ([4], NULL); -statement error +query I SELECT MAP(k, v) FROM null_values_list; ---- -The list of map values must not be NULL. 
\ No newline at end of file +{1=2} +NULL From c7b80a695e4b3784e14a5a35bce4196c2108bc95 Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 22:29:19 +0200 Subject: [PATCH 179/201] Add also amazonlinux container tests --- scripts/test_docker_images.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/test_docker_images.sh b/scripts/test_docker_images.sh index 7f9deaf196ae..86febb0328fd 100755 --- a/scripts/test_docker_images.sh +++ b/scripts/test_docker_images.sh @@ -1,7 +1,10 @@ #!/usr/bin/env bash make clean +docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb amazonlinux:2 <<< "yum install g++ git make cmake ninja-build -y && GEN=ninja make && make clean" 2>&1 +docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb amazonlinux:latest <<< "yum install clang git make cmake ninja-build -y && GEN=ninja make && make clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja && GEN=ninja make && make clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja python3 && GEN=ninja make && make clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja && CXX_STANDARD=23 GEN=ninja make && make clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb ubuntu:20.04 <<< "apt-get update && export DEBIAN_FRONTEND=noninteractive && apt-get install g++ git make cmake ninja-build -y && GEN=ninja make && make clean" 2>&1 + From 3ed2ae9372c57e7bb34213f39e36ff306228dbf0 Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 22:29:38 +0200 Subject: [PATCH 180/201] Add test also on ubuntu:devel --- scripts/test_docker_images.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/test_docker_images.sh b/scripts/test_docker_images.sh index 86febb0328fd..173d0b9f340a 100755 --- a/scripts/test_docker_images.sh +++ b/scripts/test_docker_images.sh @@ -7,4 +7,5 @@ docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk ad docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja python3 && GEN=ninja make && make clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja && CXX_STANDARD=23 GEN=ninja make && make clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb ubuntu:20.04 <<< "apt-get update && export DEBIAN_FRONTEND=noninteractive && apt-get install g++ git make cmake ninja-build -y && GEN=ninja make && make clean" 2>&1 +docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb ubuntu:devel <<< "apt-get update && export DEBIAN_FRONTEND=noninteractive && apt-get install g++ git make cmake ninja-build -y && GEN=ninja make && make clean" 2>&1 From d1883806b1537054d4890d0c53ae6619f9261057 Mon Sep 17 00:00:00 2001 From: Carlo Piovesan Date: Thu, 18 Apr 2024 22:43:07 +0200 Subject: [PATCH 181/201] Add centos to containerized tests --- scripts/test_docker_images.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/test_docker_images.sh b/scripts/test_docker_images.sh index 173d0b9f340a..3b180e963784 100755 --- a/scripts/test_docker_images.sh +++ b/scripts/test_docker_images.sh @@ -8,4 +8,4 @@ docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk ad docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb alpine:latest <<< "apk add g++ git make cmake ninja && CXX_STANDARD=23 GEN=ninja make && make 
clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb ubuntu:20.04 <<< "apt-get update && export DEBIAN_FRONTEND=noninteractive && apt-get install g++ git make cmake ninja-build -y && GEN=ninja make && make clean" 2>&1 docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb ubuntu:devel <<< "apt-get update && export DEBIAN_FRONTEND=noninteractive && apt-get install g++ git make cmake ninja-build -y && GEN=ninja make && make clean" 2>&1 - +docker run -i --rm -v $(pwd):/duckdb --workdir /duckdb centos <<< "sed -i 's/mirrorlist/#mirrorlist/g' /etc/yum.repos.d/CentOS-* && sed -i 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-* && yum install git make cmake clang -y && make && make clean" 2>&1 From 15803cfdd6eb54d17b1a5316503435fffcc0fa2c Mon Sep 17 00:00:00 2001 From: stephaniewang Date: Thu, 18 Apr 2024 16:54:57 -0400 Subject: [PATCH 182/201] feat: rewrite which_secret() into a table function --- src/core_functions/function_list.cpp | 1 - src/core_functions/scalar/CMakeLists.txt | 1 - .../scalar/secret/CMakeLists.txt | 4 - .../scalar/secret/functions.json | 9 --- .../scalar/secret/which_secret.cpp | 28 ------- src/function/table/system/CMakeLists.txt | 1 + .../table/system/duckdb_which_secret.cpp | 75 +++++++++++++++++++ src/function/table/system_functions.cpp | 1 + .../function/table/system_functions.hpp | 4 + test/secrets/test_custom_secret_storage.cpp | 14 ++-- .../secrets/create_secret_scope_matching.test | 23 +++++- 11 files changed, 107 insertions(+), 54 deletions(-) delete mode 100644 src/core_functions/scalar/secret/CMakeLists.txt delete mode 100644 src/core_functions/scalar/secret/functions.json delete mode 100644 src/core_functions/scalar/secret/which_secret.cpp create mode 100644 src/function/table/system/duckdb_which_secret.cpp diff --git a/src/core_functions/function_list.cpp b/src/core_functions/function_list.cpp index 540752a4f48c..23bf82242562 100644 --- a/src/core_functions/function_list.cpp +++ b/src/core_functions/function_list.cpp @@ -385,7 +385,6 @@ static const StaticFunctionDefinition internal_functions[] = { DUCKDB_SCALAR_FUNCTION_SET(WeekFun), DUCKDB_SCALAR_FUNCTION_SET(WeekDayFun), DUCKDB_SCALAR_FUNCTION_SET(WeekOfYearFun), - DUCKDB_SCALAR_FUNCTION(WhichSecretFun), DUCKDB_SCALAR_FUNCTION_SET(BitwiseXorFun), DUCKDB_SCALAR_FUNCTION_SET(YearFun), DUCKDB_SCALAR_FUNCTION_SET(YearWeekFun), diff --git a/src/core_functions/scalar/CMakeLists.txt b/src/core_functions/scalar/CMakeLists.txt index c6dd785c5b37..b4f0ff7c7bc4 100644 --- a/src/core_functions/scalar/CMakeLists.txt +++ b/src/core_functions/scalar/CMakeLists.txt @@ -9,7 +9,6 @@ add_subdirectory(map) add_subdirectory(math) add_subdirectory(operators) add_subdirectory(random) -add_subdirectory(secret) add_subdirectory(string) add_subdirectory(struct) add_subdirectory(union) diff --git a/src/core_functions/scalar/secret/CMakeLists.txt b/src/core_functions/scalar/secret/CMakeLists.txt deleted file mode 100644 index 6937dcd6003e..000000000000 --- a/src/core_functions/scalar/secret/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -add_library_unity(duckdb_func_secret OBJECT which_secret.cpp) -set(ALL_OBJECT_FILES - ${ALL_OBJECT_FILES} $ - PARENT_SCOPE) diff --git a/src/core_functions/scalar/secret/functions.json b/src/core_functions/scalar/secret/functions.json deleted file mode 100644 index fb24476be3d7..000000000000 --- a/src/core_functions/scalar/secret/functions.json +++ /dev/null @@ -1,9 +0,0 @@ -[ - { - "name": "which_secret", - "parameters": "path,type", - 
"description": "Print out the name of the secret that will be used for reading a path", - "example": "which_secret('s3://some/authenticated/path.csv', 's3')", - "type": "scalar_function" - } -] diff --git a/src/core_functions/scalar/secret/which_secret.cpp b/src/core_functions/scalar/secret/which_secret.cpp deleted file mode 100644 index dfa54e6278cc..000000000000 --- a/src/core_functions/scalar/secret/which_secret.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "duckdb/core_functions/scalar/secret_functions.hpp" -#include "duckdb/main/secret/secret_manager.hpp" - -namespace duckdb { - -static void WhichSecretFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - - auto &secret_manager = SecretManager::Get(state.GetContext()); - auto transaction = CatalogTransaction::GetSystemCatalogTransaction(state.GetContext()); - - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t path, string_t type) { - auto secret_match = secret_manager.LookupSecret(transaction, path.GetString(), type.GetString()); - if (!secret_match.HasMatch()) { - return string_t(); - } - return StringVector::AddString(result, secret_match.GetSecret().GetName()); - }); -} - -ScalarFunction WhichSecretFun::GetFunction() { - ScalarFunction which_secret("which_secret", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - WhichSecretFunction, nullptr, nullptr, nullptr, nullptr); - return which_secret; -} - -} // namespace duckdb diff --git a/src/function/table/system/CMakeLists.txt b/src/function/table/system/CMakeLists.txt index 066f216de150..14d15264ca17 100644 --- a/src/function/table/system/CMakeLists.txt +++ b/src/function/table/system/CMakeLists.txt @@ -13,6 +13,7 @@ add_library_unity( duckdb_optimizers.cpp duckdb_schemas.cpp duckdb_secrets.cpp + duckdb_which_secret.cpp duckdb_sequences.cpp duckdb_settings.cpp duckdb_tables.cpp diff --git a/src/function/table/system/duckdb_which_secret.cpp b/src/function/table/system/duckdb_which_secret.cpp new file mode 100644 index 000000000000..3314fee95e69 --- /dev/null +++ b/src/function/table/system/duckdb_which_secret.cpp @@ -0,0 +1,75 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/multi_file_reader.hpp" +#include "duckdb/function/function_set.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/extension_helper.hpp" +#include "duckdb/main/secret/secret_manager.hpp" + +namespace duckdb { + +struct DuckDBWhichSecretData : public GlobalTableFunctionState { + DuckDBWhichSecretData() : finished(false) { + } + bool finished; +}; + +struct DuckDBWhichSecretBindData : public TableFunctionData { + explicit DuckDBWhichSecretBindData(TableFunctionBindInput &tf_input) : inputs(tf_input.inputs) {}; + + duckdb::vector inputs; +}; + +static unique_ptr DuckDBWhichSecretBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("persistent"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("storage"); + return_types.emplace_back(LogicalType::VARCHAR); + + return make_uniq(input); +} + +unique_ptr DuckDBWhichSecretInit(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); +} + +void 
DuckDBWhichSecretFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.finished) { + // finished returning values + return; + } + auto &bind_data = data_p.bind_data->Cast(); + + auto &secret_manager = SecretManager::Get(context); + auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context); + + auto &inputs = bind_data.inputs; + auto path = inputs[0].ToString(); + auto type = inputs[1].ToString(); + auto secret_match = secret_manager.LookupSecret(transaction, path, type); + if (secret_match.HasMatch()) { + auto &secret_entry = *secret_match.secret_entry; + output.SetCardinality(1); + output.SetValue(0, 0, secret_entry.secret->GetName()); + output.SetValue(1, 0, EnumUtil::ToString(secret_entry.persist_type)); + output.SetValue(2, 0, secret_entry.storage_mode); + } + data.finished = true; +} + +void DuckDBWhichSecretFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("which_secret", {duckdb::LogicalType::VARCHAR, duckdb::LogicalType::VARCHAR}, + DuckDBWhichSecretFunction, DuckDBWhichSecretBind, DuckDBWhichSecretInit)); +} + +} // namespace duckdb diff --git a/src/function/table/system_functions.cpp b/src/function/table/system_functions.cpp index a9c191543083..a1a821db79f4 100644 --- a/src/function/table/system_functions.cpp +++ b/src/function/table/system_functions.cpp @@ -30,6 +30,7 @@ void BuiltinFunctions::RegisterSQLiteFunctions() { DuckDBMemoryFun::RegisterFunction(*this); DuckDBOptimizersFun::RegisterFunction(*this); DuckDBSecretsFun::RegisterFunction(*this); + DuckDBWhichSecretFun::RegisterFunction(*this); DuckDBSequencesFun::RegisterFunction(*this); DuckDBSettingsFun::RegisterFunction(*this); DuckDBTablesFun::RegisterFunction(*this); diff --git a/src/include/duckdb/function/table/system_functions.hpp b/src/include/duckdb/function/table/system_functions.hpp index 1ab8f87690df..23f4d9cb9caf 100644 --- a/src/include/duckdb/function/table/system_functions.hpp +++ b/src/include/duckdb/function/table/system_functions.hpp @@ -57,6 +57,10 @@ struct DuckDBSecretsFun { static void RegisterFunction(BuiltinFunctions &set); }; +struct DuckDBWhichSecretFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + struct DuckDBDatabasesFun { static void RegisterFunction(BuiltinFunctions &set); }; diff --git a/test/secrets/test_custom_secret_storage.cpp b/test/secrets/test_custom_secret_storage.cpp index e0fa40c8140c..ba344a954470 100644 --- a/test/secrets/test_custom_secret_storage.cpp +++ b/test/secrets/test_custom_secret_storage.cpp @@ -91,8 +91,8 @@ TEST_CASE("Test secret lookups by secret type", "[secret][.]") { REQUIRE_NO_FAIL(con.Query("CREATE SECRET s2 (TYPE secret_type_2, SCOPE '')")); // Note that the secrets collide completely, except for their types - auto res1 = con.Query("SELECT which_secret('blablabla', 'secret_type_1')"); - auto res2 = con.Query("SELECT which_secret('blablabla', 'secret_type_2')"); + auto res1 = con.Query("SELECT name FROM which_secret('blablabla', 'secret_type_1')"); + auto res2 = con.Query("SELECT name FROM which_secret('blablabla', 'secret_type_2')"); // Correct secret is selected REQUIRE(res1->GetValue(0, 0).ToString() == "s1"); @@ -154,14 +154,14 @@ TEST_CASE("Test adding a custom secret storage", "[secret][.]") { REQUIRE(secret_ptr->secret->GetName() == "s2"); // Now try resolve secret by path -> this will return s1 because its scope matches best - auto which_secret_result = con.Query("SELECT which_secret('s3://foo/bar.csv', 
'S3');"); + auto which_secret_result = con.Query("SELECT name FROM which_secret('s3://foo/bar.csv', 'S3');"); REQUIRE(which_secret_result->GetValue(0, 0).ToString() == "s1"); // Exclude the storage from lookups storage_ref.include_in_lookups = false; // Now the lookup will choose the other storage - which_secret_result = con.Query("SELECT which_secret('s3://foo/bar.csv', 's3');"); + which_secret_result = con.Query("SELECT name FROM which_secret('s3://foo/bar.csv', 's3');"); REQUIRE(which_secret_result->GetValue(0, 0).ToString() == "s2"); // Lets drop stuff now @@ -222,17 +222,17 @@ TEST_CASE("Test tie-break behaviour for custom secret storage", "[secret][.]") { REQUIRE(result->GetValue(0, 2).ToString() == "s3"); REQUIRE(result->GetValue(1, 2).ToString() == "test_storage_before"); - result = con.Query("SELECT which_secret('s3://', 's3');"); + result = con.Query("SELECT name FROM which_secret('s3://', 's3');"); REQUIRE(result->GetValue(0, 0).ToString() == "s3"); REQUIRE_NO_FAIL(con.Query("DROP SECRET s3")); - result = con.Query("SELECT which_secret('s3://', 's3');"); + result = con.Query("SELECT name FROM which_secret('s3://', 's3');"); REQUIRE(result->GetValue(0, 0).ToString() == "s1"); REQUIRE_NO_FAIL(con.Query("DROP SECRET s1")); - result = con.Query("SELECT which_secret('s3://', 's3');"); + result = con.Query("SELECT name FROM which_secret('s3://', 's3');"); REQUIRE(result->GetValue(0, 0).ToString() == "s2"); REQUIRE_NO_FAIL(con.Query("DROP SECRET s2")); diff --git a/test/sql/secrets/create_secret_scope_matching.test b/test/sql/secrets/create_secret_scope_matching.test index a026ad7cfd50..8f27e1d8adf4 100644 --- a/test/sql/secrets/create_secret_scope_matching.test +++ b/test/sql/secrets/create_secret_scope_matching.test @@ -7,11 +7,21 @@ load __TEST_DIR__/create_secret_scope_matching.db statement ok PRAGMA enable_verification; -require httpfs +statement ok +SET autoinstall_known_extensions=1; + +statement ok +SET autoload_known_extensions=1; statement ok set secret_directory='__TEST_DIR__/create_secret_scope_matching' +# No match +query I +SELECT name FROM which_secret('s3://', 's3') +---- + + statement ok CREATE TEMPORARY SECRET t1 ( TYPE S3 ) @@ -24,16 +34,21 @@ CREATE SECRET p1 IN LOCAL_FILE ( TYPE S3 ) # This ties within the same storage: the two temporary secrets s1 and s2 both score identically. 
We solve this by
 # tie-breaking on secret name alphabetical ordering
 query I
-SELECT which_secret('s3://', 's3')
+SELECT name FROM which_secret('s3://', 's3')
 ----
 t1
 
+query III
+FROM which_secret('s3://', 's3')
+----
+t1	TEMPORARY	memory
+
 statement ok
 DROP SECRET t1
 
 # Temporary secrets take preference over persistent ones
 query I
-SELECT which_secret('s3://', 's3')
+SELECT name FROM which_secret('s3://', 's3')
 ----
 t2
 
 statement ok
 DROP SECRET t2
 
 query I
-SELECT which_secret('s3://', 's3')
+SELECT name FROM which_secret('s3://', 's3')
 ----
 p1

From 38216a2c7972e309f3ad0fcfd48f80f7e4345eec Mon Sep 17 00:00:00 2001
From: Tishj
Date: Thu, 18 Apr 2024 22:57:32 +0200
Subject: [PATCH 183/201] add extra test with pyarrow MapArrays with None

---
 tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py b/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py
index 592778146835..9c6ceb06b4fe 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py
@@ -183,6 +183,15 @@ def test_map_arrow_to_duckdb(self, duckdb_cursor):
         ):
             rel = duckdb.from_arrow(arrow_table).fetchall()
 
+    def test_null_map_arrow_to_duckdb(self, duckdb_cursor):
+        if not can_run:
+            return
+        map_type = pa.map_(pa.int32(), pa.int32())
+        values = [None, [(5, 42)]]
+        arrow_table = pa.table({'detail': pa.array(values, map_type)})
+        res = duckdb_cursor.sql("select * from arrow_table").fetchall()
+        assert res == [(None,), ({'key': [5], 'value': [42]},)]
+
     def test_map_arrow_to_pandas(self, duckdb_cursor):
         if not can_run:
             return

From 2f9bc1e3ab5e51788e6bf60d86af109c500952e8 Mon Sep 17 00:00:00 2001
From: stephaniewang
Date: Thu, 18 Apr 2024 18:06:13 -0400
Subject: [PATCH 184/201] update generate_functions.py

---
 scripts/generate_functions.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/scripts/generate_functions.py b/scripts/generate_functions.py
index 572d7b703221..f708eb1f5b72 100644
--- a/scripts/generate_functions.py
+++ b/scripts/generate_functions.py
@@ -15,7 +15,6 @@
     'math',
    'operators',
     'random',
-    'secret',
     'string',
     'debug',
     'struct',

From 9718fb052f82b83e9beb3b18e0a96adb7ca3f15a Mon Sep 17 00:00:00 2001
From: stephaniewang
Date: Thu, 18 Apr 2024 19:49:50 -0400
Subject: [PATCH 185/201] fix test

---
 test/sql/secrets/create_secret_scope_matching.test | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/test/sql/secrets/create_secret_scope_matching.test b/test/sql/secrets/create_secret_scope_matching.test
index 8f27e1d8adf4..3d5dd2aac35a 100644
--- a/test/sql/secrets/create_secret_scope_matching.test
+++ b/test/sql/secrets/create_secret_scope_matching.test
@@ -7,11 +7,7 @@ load __TEST_DIR__/create_secret_scope_matching.db
 statement ok
 PRAGMA enable_verification;
 
-statement ok
-SET autoinstall_known_extensions=1;
-
-statement ok
-SET autoload_known_extensions=1;
+require httpfs
 
 statement ok
 set secret_directory='__TEST_DIR__/create_secret_scope_matching'
@@ -21,7 +17,6 @@
 query I
 SELECT name FROM which_secret('s3://', 's3')
 ----
-
 
 statement ok
 CREATE TEMPORARY SECRET t1 ( TYPE S3 )

From 4f7b3dee5512038dbadd8e4ba2cdeee62b7d0636 Mon Sep 17 00:00:00 2001
From: Mark Raasveldt
Date: Fri, 19 Apr 2024 09:30:35 +0200
Subject: [PATCH 186/201] Implement GetUniqueConstraintKeys

---
 src/catalog/catalog_entry/duck_table_entry.cpp | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git 
a/src/catalog/catalog_entry/duck_table_entry.cpp b/src/catalog/catalog_entry/duck_table_entry.cpp index f77e69f65eba..e453df2e66b7 100644 --- a/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/catalog/catalog_entry/duck_table_entry.cpp @@ -72,7 +72,15 @@ IndexStorageInfo GetIndexInfo(const IndexConstraintType &constraint_type, unique } vector GetUniqueConstraintKeys(const ColumnList &columns, const UniqueConstraint &constraint) { - throw InternalException("FIXME: GetUniqueConstraintKeys"); + vector indexes; + if (constraint.HasIndex()) { + indexes.push_back(columns.LogicalToPhysical(constraint.GetIndex())); + } else { + for(auto &keyname : constraint.GetColumnNames()) { + indexes.push_back(columns.GetColumn(keyname).Physical()); + } + } + return indexes; } DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, BoundCreateTableInfo &info, From 8668a2f8e7b0015bfbfbdefeae8cbace92c2147e Mon Sep 17 00:00:00 2001 From: Mark Raasveldt Date: Fri, 19 Apr 2024 09:30:55 +0200 Subject: [PATCH 187/201] Format fix --- .../catalog_entry/duck_table_entry.cpp | 7 ++++--- .../catalog_entry/table_catalog_entry.cpp | 4 ++-- .../operator/persistent/physical_insert.cpp | 3 ++- .../table/system/duckdb_constraints.cpp | 2 +- .../catalog_entry/duck_table_entry.hpp | 3 ++- src/include/duckdb/planner/binder.hpp | 10 +++++++--- src/include/duckdb/storage/data_table.hpp | 8 ++++---- .../duckdb/storage/table/append_state.hpp | 3 ++- .../binder/statement/bind_create_table.cpp | 20 ++++++++++++------- src/storage/data_table.cpp | 10 ++++++---- 10 files changed, 43 insertions(+), 27 deletions(-) diff --git a/src/catalog/catalog_entry/duck_table_entry.cpp b/src/catalog/catalog_entry/duck_table_entry.cpp index e453df2e66b7..c25c76bb279b 100644 --- a/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/catalog/catalog_entry/duck_table_entry.cpp @@ -76,7 +76,7 @@ vector GetUniqueConstraintKeys(const ColumnList &columns, const U if (constraint.HasIndex()) { indexes.push_back(columns.LogicalToPhysical(constraint.GetIndex())); } else { - for(auto &keyname : constraint.GetColumnNames()) { + for (auto &keyname : constraint.GetColumnNames()) { indexes.push_back(columns.GetColumn(keyname).Physical()); } } @@ -360,7 +360,7 @@ void DuckTableEntry::UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_i const vector &adjusted_indices, const RemoveColumnInfo &info, CreateTableInfo &create_info, const vector> &bound_constraints, - bool is_generated) { + bool is_generated) { // handle constraints for the new table D_ASSERT(constraints.size() == bound_constraints.size()); for (idx_t constr_idx = 0; constr_idx < constraints.size(); constr_idx++) { @@ -483,7 +483,8 @@ unique_ptr DuckTableEntry::RemoveColumn(ClientContext &context, Re auto binder = Binder::CreateBinder(context); auto bound_constraints = binder->BindConstraints(constraints, name, columns); - UpdateConstraintsOnColumnDrop(removed_index, adjusted_indices, info, *create_info, bound_constraints, dropped_column_is_generated); + UpdateConstraintsOnColumnDrop(removed_index, adjusted_indices, info, *create_info, bound_constraints, + dropped_column_is_generated); auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); if (columns.GetColumn(LogicalIndex(removed_index)).Generated()) { diff --git a/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/catalog/catalog_entry/table_catalog_entry.cpp index 493403830e2a..8f73ec416e62 100644 --- a/src/catalog/catalog_entry/table_catalog_entry.cpp +++ 
b/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -234,8 +234,8 @@ vector TableCatalogEntry::GetColumnSegmentInfo() { return {}; } -void TableCatalogEntry::BindUpdateConstraints(Binder &binder, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, - ClientContext &context) { +void TableCatalogEntry::BindUpdateConstraints(Binder &binder, LogicalGet &get, LogicalProjection &proj, + LogicalUpdate &update, ClientContext &context) { // check the constraints and indexes of the table to see if we need to project any additional columns // we do this for indexes with multiple columns and CHECK constraints in the UPDATE clause // suppose we have a constraint CHECK(i + j < 10); now we need both i and j to check the constraint diff --git a/src/execution/operator/persistent/physical_insert.cpp b/src/execution/operator/persistent/physical_insert.cpp index 181216a15b7f..c7b499855c5a 100644 --- a/src/execution/operator/persistent/physical_insert.cpp +++ b/src/execution/operator/persistent/physical_insert.cpp @@ -108,7 +108,8 @@ class InsertLocalState : public LocalSinkState { idx_t update_count = 0; unique_ptr constraint_state; - ConstraintVerificationState &GetConstraintState(DataTable &table, TableCatalogEntry &tableref, ClientContext &context) { + ConstraintVerificationState &GetConstraintState(DataTable &table, TableCatalogEntry &tableref, + ClientContext &context) { if (!constraint_state) { constraint_state = table.InitializeConstraintVerification(tableref, context); } diff --git a/src/function/table/system/duckdb_constraints.cpp b/src/function/table/system/duckdb_constraints.cpp index c35eaf0da9ad..71fabb16b5ce 100644 --- a/src/function/table/system/duckdb_constraints.cpp +++ b/src/function/table/system/duckdb_constraints.cpp @@ -132,7 +132,7 @@ unique_ptr DuckDBConstraintsInit(ClientContext &contex }); sort(entries.begin(), entries.end(), [&](CatalogEntry &x, CatalogEntry &y) { return (x.name < y.name); }); - for(auto &entry : entries) { + for (auto &entry : entries) { result->entries.emplace_back(context, entry.get().Cast()); } }; diff --git a/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp b/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp index 7c1323a4f300..5396f31115cc 100644 --- a/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp +++ b/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp @@ -58,7 +58,8 @@ class DuckTableEntry : public TableCatalogEntry { unique_ptr SetColumnComment(ClientContext &context, SetColumnCommentInfo &info); void UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_index, const vector &adjusted_indices, - const RemoveColumnInfo &info, CreateTableInfo &create_info, const vector> &bound_constraints, bool is_generated); + const RemoveColumnInfo &info, CreateTableInfo &create_info, + const vector> &bound_constraints, bool is_generated); private: //! 
A reference to the underlying storage unit used for this table diff --git a/src/include/duckdb/planner/binder.hpp b/src/include/duckdb/planner/binder.hpp index 67e2bc1ec6b1..705f35595017 100644 --- a/src/include/duckdb/planner/binder.hpp +++ b/src/include/duckdb/planner/binder.hpp @@ -122,9 +122,13 @@ class Binder : public std::enable_shared_from_this { unique_ptr BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema); unique_ptr BindCreateTableInfo(unique_ptr info, SchemaCatalogEntry &schema, vector> &bound_defaults); - static vector> BindConstraints(ClientContext &context, const vector> &constraints, const string &table_name, const ColumnList &columns); - vector> BindConstraints(const vector> &constraints, const string &table_name, const ColumnList &columns); - vector> BindNewConstraints(vector> &constraints, const string &table_name, const ColumnList &columns); + static vector> BindConstraints(ClientContext &context, + const vector> &constraints, + const string &table_name, const ColumnList &columns); + vector> BindConstraints(const vector> &constraints, + const string &table_name, const ColumnList &columns); + vector> BindNewConstraints(vector> &constraints, + const string &table_name, const ColumnList &columns); void BindCreateViewInfo(CreateViewInfo &base); SchemaCatalogEntry &BindSchema(CreateInfo &info); diff --git a/src/include/duckdb/storage/data_table.hpp b/src/include/duckdb/storage/data_table.hpp index 74a5c2283830..e6aedf464bc9 100644 --- a/src/include/duckdb/storage/data_table.hpp +++ b/src/include/duckdb/storage/data_table.hpp @@ -115,7 +115,6 @@ class DataTable { //! Delete the entries with the specified row identifier from the table idx_t Delete(TableDeleteState &state, ClientContext &context, Vector &row_ids, idx_t count); - unique_ptr InitializeUpdate(TableCatalogEntry &table, ClientContext &context); //! Update the entries with the specified row identifier from the table void Update(TableUpdateState &state, ClientContext &context, Vector &row_ids, @@ -194,14 +193,15 @@ class DataTable { bool IndexNameIsUnique(const string &name); //! Initialize constraint verification - unique_ptr InitializeConstraintVerification(TableCatalogEntry &table, ClientContext &context); + unique_ptr InitializeConstraintVerification(TableCatalogEntry &table, + ClientContext &context); //! Verify constraints with a chunk from the Append containing all columns of the table void VerifyAppendConstraints(ConstraintVerificationState &state, ClientContext &context, DataChunk &chunk, - optional_ptr conflict_manager = nullptr); + optional_ptr conflict_manager = nullptr); public: static void VerifyUniqueIndexes(TableIndexList &indexes, ClientContext &context, DataChunk &chunk, - optional_ptr conflict_manager); + optional_ptr conflict_manager); private: //! 
Verify the new added constraints against current persistent&local data diff --git a/src/include/duckdb/storage/table/append_state.hpp b/src/include/duckdb/storage/table/append_state.hpp index 475a59187f1f..8f9a56c5cba2 100644 --- a/src/include/duckdb/storage/table/append_state.hpp +++ b/src/include/duckdb/storage/table/append_state.hpp @@ -71,7 +71,8 @@ struct TableAppendState { }; struct ConstraintVerificationState { - explicit ConstraintVerificationState(TableCatalogEntry &table_p) : table(table_p) {} + explicit ConstraintVerificationState(TableCatalogEntry &table_p) : table(table_p) { + } TableCatalogEntry &table; vector> bound_constraints; diff --git a/src/planner/binder/statement/bind_create_table.cpp b/src/planner/binder/statement/bind_create_table.cpp index 10ef50599a08..8a032b9fe1ee 100644 --- a/src/planner/binder/statement/bind_create_table.cpp +++ b/src/planner/binder/statement/bind_create_table.cpp @@ -35,7 +35,8 @@ static void CreateColumnDependencyManager(BoundCreateTableInfo &info) { } } -static unique_ptr BindCheckConstraint(Binder &binder, const string &table_name, const ColumnList &columns, const unique_ptr &cond) { +static unique_ptr BindCheckConstraint(Binder &binder, const string &table_name, + const ColumnList &columns, const unique_ptr &cond) { auto bound_constraint = make_uniq(); // check constraint: bind the expression CheckBinder check_binder(binder, binder.context, table_name, columns, bound_constraint->bound_columns); @@ -47,12 +48,15 @@ static unique_ptr BindCheckConstraint(Binder &binder, const str return std::move(bound_constraint); } -vector> Binder::BindConstraints(ClientContext &context, const vector> &constraints, const string &table_name, const ColumnList &columns) { +vector> Binder::BindConstraints(ClientContext &context, + const vector> &constraints, + const string &table_name, const ColumnList &columns) { auto binder = Binder::CreateBinder(context); return binder->BindConstraints(constraints, table_name, columns); } -vector> Binder::BindConstraints(const vector> &constraints, const string &table_name, const ColumnList &columns) { +vector> Binder::BindConstraints(const vector> &constraints, + const string &table_name, const ColumnList &columns) { vector> bound_constraints; for (auto &constr : constraints) { switch (constr->type) { @@ -128,16 +132,17 @@ vector> Binder::BindConstraints(const vector> Binder::BindNewConstraints(vector> &constraints, const string &table_name, const ColumnList &columns) { +vector> Binder::BindNewConstraints(vector> &constraints, + const string &table_name, const ColumnList &columns) { auto bound_constraints = BindConstraints(constraints, table_name, columns); // handle primary keys/not null constraints bool has_primary_key = false; logical_index_set_t not_null_columns; vector primary_keys; - for(idx_t c = 0; c < constraints.size(); c++) { + for (idx_t c = 0; c < constraints.size(); c++) { auto &constr = constraints[c]; - switch(constr->type) { + switch (constr->type) { case ConstraintType::NOT_NULL: { auto ¬_null = constr->Cast(); auto &col = columns.GetColumn(LogicalIndex(not_null.index)); @@ -263,7 +268,8 @@ static void ExtractExpressionDependencies(Expression &expr, LogicalDependencyLis expr, [&](Expression &child) { ExtractExpressionDependencies(child, dependencies); }); } -static void ExtractDependencies(BoundCreateTableInfo &info, vector> &defaults, vector> &constraints) { +static void ExtractDependencies(BoundCreateTableInfo &info, vector> &defaults, + vector> &constraints) { for (auto &default_value : defaults) { if 
(default_value) { ExtractExpressionDependencies(*default_value, info.dependencies); diff --git a/src/storage/data_table.cpp b/src/storage/data_table.cpp index 8dab59b0e1a8..8a610b38c7c8 100644 --- a/src/storage/data_table.cpp +++ b/src/storage/data_table.cpp @@ -617,7 +617,7 @@ void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, ClientContext &cont } void DataTable::VerifyAppendConstraints(ConstraintVerificationState &state, ClientContext &context, DataChunk &chunk, - optional_ptr conflict_manager) { + optional_ptr conflict_manager) { auto &table = state.table; if (table.HasGeneratedColumns()) { // Verify that the generated columns expression work with the inserted values @@ -675,7 +675,8 @@ void DataTable::VerifyAppendConstraints(ConstraintVerificationState &state, Clie } } -unique_ptr DataTable::InitializeConstraintVerification(TableCatalogEntry &table, ClientContext &context) { +unique_ptr DataTable::InitializeConstraintVerification(TableCatalogEntry &table, + ClientContext &context) { auto result = make_uniq(table); auto binder = Binder::CreateBinder(context); result->bound_constraints = binder->BindConstraints(table.GetConstraints(), table.name, table.GetColumns()); @@ -1063,7 +1064,8 @@ idx_t DataTable::Delete(TableDeleteState &state, ClientContext &context, Vector if (state.has_delete_constraints) { // perform the constraint verification ColumnFetchState fetch_state; - local_storage.FetchChunk(*this, offset_ids, current_count, state.col_ids, state.verify_chunk, fetch_state); + local_storage.FetchChunk(*this, offset_ids, current_count, state.col_ids, state.verify_chunk, + fetch_state); VerifyDeleteConstraints(state, context, state.verify_chunk); } delete_count += local_storage.Delete(*this, offset_ids, current_count); @@ -1118,7 +1120,7 @@ static bool CreateMockChunk(TableCatalogEntry &table, const vector &column_ids) { auto &table = state.table; auto &constraints = table.GetConstraints(); From 3e8d2ec4901b8aec742c1f37d2be99697b1bd241 Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 19 Apr 2024 10:27:20 +0200 Subject: [PATCH 188/201] slightly more extended tests --- test/sql/types/map/map_null.test | 34 ++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/sql/types/map/map_null.test b/test/sql/types/map/map_null.test index 6bc138488856..68dc58be808a 100644 --- a/test/sql/types/map/map_null.test +++ b/test/sql/types/map/map_null.test @@ -33,3 +33,37 @@ query I select map([1,2,3], NULL::INT[]) ---- NULL + +query I +SELECT * FROM ( VALUES + (MAP(NULL, NULL)), + (MAP(NULL::INT[], NULL::INT[])), + (MAP([1,2,3], [1,2,3])) +) +---- +NULL +NULL +{1=1, 2=2, 3=3} + +query I +select MAP(a, b) FROM ( VALUES + (NULL, ['b', 'c']), + (NULL::INT[], NULL), + (NULL::INT[], NULL::VARCHAR[]), + (NULL::INT[], ['a', 'b', 'c']), + (NULL, ['longer string than inlined', 'smol']), + (NULL, NULL), + ([1,2,3], NULL), + ([1,2,3], ['z', 'y', 'x']), + ([1,2,3], NULL::VARCHAR[]), +) t(a, b) +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +{1=z, 2=y, 3=x} +NULL From 8838f07cb47885cdcd9e3a467df048154d4eb9d8 Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 19 Apr 2024 10:41:00 +0200 Subject: [PATCH 189/201] removed unused enum constants --- src/common/enum_util.cpp | 10 ---------- src/common/types/vector.cpp | 4 ---- src/function/table/arrow_conversion.cpp | 3 --- src/include/duckdb/common/types/vector.hpp | 10 +--------- 4 files changed, 1 insertion(+), 26 deletions(-) diff --git a/src/common/enum_util.cpp b/src/common/enum_util.cpp index b2db02f412c9..12a94a2b0f1c 100644 --- 
a/src/common/enum_util.cpp +++ b/src/common/enum_util.cpp @@ -3802,14 +3802,10 @@ const char* EnumUtil::ToChars(MapInvalidReason value) { switch(value) { case MapInvalidReason::VALID: return "VALID"; - case MapInvalidReason::NULL_KEY_LIST: - return "NULL_KEY_LIST"; case MapInvalidReason::NULL_KEY: return "NULL_KEY"; case MapInvalidReason::DUPLICATE_KEY: return "DUPLICATE_KEY"; - case MapInvalidReason::NULL_VALUE_LIST: - return "NULL_VALUE_LIST"; case MapInvalidReason::NOT_ALIGNED: return "NOT_ALIGNED"; case MapInvalidReason::INVALID_PARAMS: @@ -3824,18 +3820,12 @@ MapInvalidReason EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "VALID")) { return MapInvalidReason::VALID; } - if (StringUtil::Equals(value, "NULL_KEY_LIST")) { - return MapInvalidReason::NULL_KEY_LIST; - } if (StringUtil::Equals(value, "NULL_KEY")) { return MapInvalidReason::NULL_KEY; } if (StringUtil::Equals(value, "DUPLICATE_KEY")) { return MapInvalidReason::DUPLICATE_KEY; } - if (StringUtil::Equals(value, "NULL_VALUE_LIST")) { - return MapInvalidReason::NULL_VALUE_LIST; - } if (StringUtil::Equals(value, "NOT_ALIGNED")) { return MapInvalidReason::NOT_ALIGNED; } diff --git a/src/common/types/vector.cpp b/src/common/types/vector.cpp index e61ad46dd259..112dc0de96ab 100644 --- a/src/common/types/vector.cpp +++ b/src/common/types/vector.cpp @@ -2089,10 +2089,6 @@ void MapVector::EvalMapInvalidReason(MapInvalidReason reason) { throw InvalidInputException("Map keys must be unique."); case MapInvalidReason::NULL_KEY: throw InvalidInputException("Map keys can not be NULL."); - case MapInvalidReason::NULL_KEY_LIST: - throw InvalidInputException("The list of map keys must not be NULL."); - case MapInvalidReason::NULL_VALUE_LIST: - throw InvalidInputException("The list of map values must not be NULL."); case MapInvalidReason::NOT_ALIGNED: throw InvalidInputException("The map key list does not align with the map value list."); case MapInvalidReason::INVALID_PARAMS: diff --git a/src/function/table/arrow_conversion.cpp b/src/function/table/arrow_conversion.cpp index 78d4cca859f6..c1759ef8484c 100644 --- a/src/function/table/arrow_conversion.cpp +++ b/src/function/table/arrow_conversion.cpp @@ -317,9 +317,6 @@ static void ArrowToDuckDBMapVerify(Vector &vector, idx_t count) { case MapInvalidReason::NULL_KEY: { throw InvalidInputException("Arrow map contains NULL as map key, which isn't supported by DuckDB map type"); } - case MapInvalidReason::NULL_KEY_LIST: { - throw InvalidInputException("Arrow map contains NULL as key list, which isn't supported by DuckDB map type"); - } default: { throw InternalException("MapInvalidReason not implemented"); } diff --git a/src/include/duckdb/common/types/vector.hpp b/src/include/duckdb/common/types/vector.hpp index b0786597a662..49cb9111c464 100644 --- a/src/include/duckdb/common/types/vector.hpp +++ b/src/include/duckdb/common/types/vector.hpp @@ -464,15 +464,7 @@ struct FSSTVector { DUCKDB_API static idx_t GetCount(Vector &vector); }; -enum class MapInvalidReason : uint8_t { - VALID, - NULL_KEY_LIST, - NULL_KEY, - DUPLICATE_KEY, - NULL_VALUE_LIST, - NOT_ALIGNED, - INVALID_PARAMS -}; +enum class MapInvalidReason : uint8_t { VALID, NULL_KEY, DUPLICATE_KEY, NOT_ALIGNED, INVALID_PARAMS }; struct MapVector { DUCKDB_API static const Vector &GetKeys(const Vector &vector); From 738b30192333b3d36a510f3fe6d1d9f247e7a220 Mon Sep 17 00:00:00 2001 From: Tishj Date: Fri, 19 Apr 2024 11:15:22 +0200 Subject: [PATCH 190/201] remove NULL_KEY_LIST error from numpy scan --- 
tools/pythonpkg/src/numpy/numpy_scan.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/tools/pythonpkg/src/numpy/numpy_scan.cpp b/tools/pythonpkg/src/numpy/numpy_scan.cpp index 032d3b97f014..b4b1d3dbe276 100644 --- a/tools/pythonpkg/src/numpy/numpy_scan.cpp +++ b/tools/pythonpkg/src/numpy/numpy_scan.cpp @@ -153,8 +153,6 @@ static void VerifyMapConstraints(Vector &vec, idx_t count) { return; case MapInvalidReason::DUPLICATE_KEY: throw InvalidInputException("Dict->Map conversion failed because 'key' list contains duplicates"); - case MapInvalidReason::NULL_KEY_LIST: - throw InvalidInputException("Dict->Map conversion failed because 'key' list is None"); case MapInvalidReason::NULL_KEY: throw InvalidInputException("Dict->Map conversion failed because 'key' list contains None"); default: From 87f3fe3fb9ffe97553367963f9a731b827535fca Mon Sep 17 00:00:00 2001 From: Mark Raasveldt Date: Fri, 19 Apr 2024 11:20:17 +0200 Subject: [PATCH 191/201] Move binding of constraints out of execution and into the binding phase of insert/update/delete --- .../catalog_entry/table_catalog_entry.cpp | 2 +- .../persistent/physical_batch_insert.cpp | 20 +++---- .../operator/persistent/physical_delete.cpp | 14 +++-- .../operator/persistent/physical_insert.cpp | 53 +++++++++---------- .../operator/persistent/physical_update.cpp | 21 +++++--- src/execution/physical_plan/plan_delete.cpp | 4 +- src/execution/physical_plan/plan_insert.cpp | 11 ++-- src/execution/physical_plan/plan_update.cpp | 6 +-- .../catalog_entry/table_catalog_entry.hpp | 2 +- .../persistent/physical_batch_insert.hpp | 4 +- .../operator/persistent/physical_delete.hpp | 9 ++-- .../operator/persistent/physical_insert.hpp | 13 +++-- .../operator/persistent/physical_update.hpp | 4 +- src/include/duckdb/planner/binder.hpp | 1 + .../planner/operator/logical_delete.hpp | 1 + .../planner/operator/logical_insert.hpp | 2 + .../planner/operator/logical_update.hpp | 1 + src/include/duckdb/storage/data_table.hpp | 27 ++++++---- .../duckdb/storage/table/append_state.hpp | 9 ++-- .../duckdb/storage/table/delete_state.hpp | 2 +- .../duckdb/storage/table/update_state.hpp | 2 +- src/main/appender.cpp | 4 +- src/main/client_context.cpp | 4 +- .../binder/statement/bind_create_table.cpp | 4 ++ src/planner/binder/statement/bind_delete.cpp | 1 + src/planner/binder/statement/bind_insert.cpp | 1 + src/planner/binder/statement/bind_update.cpp | 1 + src/planner/operator/logical_delete.cpp | 2 + src/planner/operator/logical_insert.cpp | 2 + src/planner/operator/logical_update.cpp | 2 + src/storage/data_table.cpp | 42 ++++++++------- src/storage/wal_replay.cpp | 4 +- 32 files changed, 164 insertions(+), 111 deletions(-) diff --git a/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/catalog/catalog_entry/table_catalog_entry.cpp index 8f73ec416e62..271970a957de 100644 --- a/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -164,7 +164,7 @@ const ColumnDefinition &TableCatalogEntry::GetColumn(LogicalIndex idx) { return columns.GetColumn(idx); } -const vector> &TableCatalogEntry::GetConstraints() { +const vector> &TableCatalogEntry::GetConstraints() const { return constraints; } diff --git a/src/execution/operator/persistent/physical_batch_insert.cpp b/src/execution/operator/persistent/physical_batch_insert.cpp index 210f0b3d201d..12eda91c7e82 100644 --- a/src/execution/operator/persistent/physical_batch_insert.cpp +++ b/src/execution/operator/persistent/physical_batch_insert.cpp @@ -13,12 +13,14 @@ namespace 
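// Sketch of the pattern this commit introduces (a paraphrase; identifiers
// follow the diffs below). Rather than having operator state lazily invoke
// Binder::BindConstraints during execution, each DML logical operator now
// binds once at plan time, roughly:
//
//   auto binder = Binder::CreateBinder(context);
//   bound_constraints = binder->BindConstraints(table);
//
// and the resulting vector is moved through the physical planner into the
// operator constructors below.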
duckdb { -PhysicalBatchInsert::PhysicalBatchInsert(vector types, TableCatalogEntry &table, - physical_index_vector_t column_index_map, - vector> bound_defaults, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::BATCH_INSERT, std::move(types), estimated_cardinality), - column_index_map(std::move(column_index_map)), insert_table(&table), insert_types(table.GetTypes()), - bound_defaults(std::move(bound_defaults)) { +PhysicalBatchInsert::PhysicalBatchInsert(vector types_p, TableCatalogEntry &table, + physical_index_vector_t column_index_map_p, + vector> bound_defaults_p, + vector> bound_constraints_p, + idx_t estimated_cardinality) + : PhysicalOperator(PhysicalOperatorType::BATCH_INSERT, std::move(types_p), estimated_cardinality), + column_index_map(std::move(column_index_map_p)), insert_table(&table), insert_types(table.GetTypes()), + bound_defaults(std::move(bound_defaults_p)), bound_constraints(std::move(bound_constraints_p)) { } PhysicalBatchInsert::PhysicalBatchInsert(LogicalOperator &op, SchemaCatalogEntry &schema, @@ -171,7 +173,7 @@ class BatchInsertLocalState : public LocalSinkState { TableAppendState current_append_state; unique_ptr current_collection; optional_ptr writer; - unique_ptr constraint_state; + unique_ptr constraint_state; void CreateNewCollection(DuckTableEntry &table, const vector &insert_types) { auto &table_info = table.GetStorage().info; @@ -496,7 +498,7 @@ SinkResultType PhysicalBatchInsert::Sink(ExecutionContext &context, DataChunk &c } if (!lstate.constraint_state) { - lstate.constraint_state = table.GetStorage().InitializeConstraintVerification(table, context.client); + lstate.constraint_state = table.GetStorage().InitializeConstraintState(table, bound_constraints); } table.GetStorage().VerifyAppendConstraints(*lstate.constraint_state, context.client, lstate.insert_chunk); @@ -599,7 +601,7 @@ SinkFinalizeType PhysicalBatchInsert::Finalize(Pipeline &pipeline, Event &event, auto &table = gstate.table; auto &storage = table.GetStorage(); LocalAppendState append_state; - storage.InitializeLocalAppend(append_state, table, context); + storage.InitializeLocalAppend(append_state, table, context, bound_constraints); auto &transaction = DuckTransaction::Get(context, table.catalog); for (auto &entry : gstate.collections) { if (entry.type != RowGroupBatchType::NOT_FLUSHED) { diff --git a/src/execution/operator/persistent/physical_delete.cpp b/src/execution/operator/persistent/physical_delete.cpp index 300376ce27fe..bb2b1a76057a 100644 --- a/src/execution/operator/persistent/physical_delete.cpp +++ b/src/execution/operator/persistent/physical_delete.cpp @@ -10,6 +10,13 @@ namespace duckdb { +PhysicalDelete::PhysicalDelete(vector types, TableCatalogEntry &tableref, DataTable &table, + vector> bound_constraints, idx_t row_id_index, + idx_t estimated_cardinality, bool return_chunk) + : PhysicalOperator(PhysicalOperatorType::DELETE_OPERATOR, std::move(types), estimated_cardinality), + tableref(tableref), table(table), bound_constraints(std::move(bound_constraints)), row_id_index(row_id_index), + return_chunk(return_chunk) { +} //===--------------------------------------------------------------------===// // Sink //===--------------------------------------------------------------------===// @@ -26,9 +33,10 @@ class DeleteGlobalState : public GlobalSinkState { class DeleteLocalState : public LocalSinkState { public: - DeleteLocalState(ClientContext &context, TableCatalogEntry &table) { + DeleteLocalState(ClientContext &context, TableCatalogEntry &table, + const 
vector> &bound_constraints) { delete_chunk.Initialize(Allocator::Get(context), table.GetTypes()); - delete_state = table.GetStorage().InitializeDelete(table, context); + delete_state = table.GetStorage().InitializeDelete(table, context, bound_constraints); } DataChunk delete_chunk; unique_ptr delete_state; @@ -64,7 +72,7 @@ unique_ptr PhysicalDelete::GetGlobalSinkState(ClientContext &co } unique_ptr PhysicalDelete::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, tableref); + return make_uniq(context.client, tableref, bound_constraints); } //===--------------------------------------------------------------------===// diff --git a/src/execution/operator/persistent/physical_insert.cpp b/src/execution/operator/persistent/physical_insert.cpp index c7b499855c5a..a2d3e9c61110 100644 --- a/src/execution/operator/persistent/physical_insert.cpp +++ b/src/execution/operator/persistent/physical_insert.cpp @@ -21,22 +21,20 @@ namespace duckdb { -PhysicalInsert::PhysicalInsert(vector types_p, TableCatalogEntry &table, - physical_index_vector_t column_index_map, - vector> bound_defaults, - vector> set_expressions, vector set_columns, - vector set_types, idx_t estimated_cardinality, bool return_chunk, - bool parallel, OnConflictAction action_type, - unique_ptr on_conflict_condition_p, - unique_ptr do_update_condition_p, unordered_set conflict_target_p, - vector columns_to_fetch_p) +PhysicalInsert::PhysicalInsert( + vector types_p, TableCatalogEntry &table, physical_index_vector_t column_index_map, + vector> bound_defaults, vector> bound_constraints_p, + vector> set_expressions, vector set_columns, vector set_types, + idx_t estimated_cardinality, bool return_chunk, bool parallel, OnConflictAction action_type, + unique_ptr on_conflict_condition_p, unique_ptr do_update_condition_p, + unordered_set conflict_target_p, vector columns_to_fetch_p) : PhysicalOperator(PhysicalOperatorType::INSERT, std::move(types_p), estimated_cardinality), column_index_map(std::move(column_index_map)), insert_table(&table), insert_types(table.GetTypes()), - bound_defaults(std::move(bound_defaults)), return_chunk(return_chunk), parallel(parallel), - action_type(action_type), set_expressions(std::move(set_expressions)), set_columns(std::move(set_columns)), - set_types(std::move(set_types)), on_conflict_condition(std::move(on_conflict_condition_p)), - do_update_condition(std::move(do_update_condition_p)), conflict_target(std::move(conflict_target_p)), - columns_to_fetch(std::move(columns_to_fetch_p)) { + bound_defaults(std::move(bound_defaults)), bound_constraints(std::move(bound_constraints_p)), + return_chunk(return_chunk), parallel(parallel), action_type(action_type), + set_expressions(std::move(set_expressions)), set_columns(std::move(set_columns)), set_types(std::move(set_types)), + on_conflict_condition(std::move(on_conflict_condition_p)), do_update_condition(std::move(do_update_condition_p)), + conflict_target(std::move(conflict_target_p)), columns_to_fetch(std::move(columns_to_fetch_p)) { if (action_type == OnConflictAction::THROW) { return; @@ -91,8 +89,9 @@ class InsertGlobalState : public GlobalSinkState { class InsertLocalState : public LocalSinkState { public: InsertLocalState(ClientContext &context, const vector &types, - const vector> &bound_defaults) - : default_executor(context, bound_defaults) { + const vector> &bound_defaults, + const vector> &bound_constraints) + : default_executor(context, bound_defaults), bound_constraints(bound_constraints) { 
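// Note: `bound_constraints` is held by reference here; it is owned by the
// PhysicalInsert operator, which outlives this per-pipeline sink state. The
// ConstraintState itself is created lazily on first use, see
// GetConstraintState below.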
insert_chunk.Initialize(Allocator::Get(context), types); } @@ -106,12 +105,12 @@ class InsertLocalState : public LocalSinkState { // Rows in the transaction-local storage that have been updated by a DO UPDATE conflict unordered_set updated_local_rows; idx_t update_count = 0; - unique_ptr constraint_state; + unique_ptr constraint_state; + const vector> &bound_constraints; - ConstraintVerificationState &GetConstraintState(DataTable &table, TableCatalogEntry &tableref, - ClientContext &context) { + ConstraintState &GetConstraintState(DataTable &table, TableCatalogEntry &tableref) { if (!constraint_state) { - constraint_state = table.InitializeConstraintVerification(tableref, context); + constraint_state = table.InitializeConstraintState(tableref, bound_constraints); } return *constraint_state; } @@ -135,7 +134,7 @@ unique_ptr PhysicalInsert::GetGlobalSinkState(ClientContext &co } unique_ptr PhysicalInsert::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, insert_types, bound_defaults); + return make_uniq(context.client, insert_types, bound_defaults, bound_constraints); } void PhysicalInsert::ResolveDefaults(const TableCatalogEntry &table, DataChunk &chunk, @@ -288,7 +287,7 @@ static idx_t PerformOnConflictAction(ExecutionContext &context, DataChunk &chunk auto &data_table = table.GetStorage(); // Perform the update, using the results of the SET expressions if (GLOBAL) { - auto update_state = data_table.InitializeUpdate(table, context.client); + auto update_state = data_table.InitializeUpdate(table, context.client, op.bound_constraints); data_table.Update(*update_state, context.client, row_ids, set_columns, update_chunk); } else { auto &local_storage = LocalStorage::Get(context.client, data_table.db); @@ -331,7 +330,7 @@ static idx_t HandleInsertConflicts(TableCatalogEntry &table, ExecutionContext &c ConflictInfo conflict_info(conflict_target); ConflictManager conflict_manager(VerifyExistenceType::APPEND, lstate.insert_chunk.size(), &conflict_info); if (GLOBAL) { - auto &constraint_state = lstate.GetConstraintState(data_table, table, context.client); + auto &constraint_state = lstate.GetConstraintState(data_table, table); data_table.VerifyAppendConstraints(constraint_state, context.client, lstate.insert_chunk, &conflict_manager); } else { DataTable::VerifyUniqueIndexes(local_storage.GetIndexes(data_table), context.client, lstate.insert_chunk, @@ -392,7 +391,7 @@ static idx_t HandleInsertConflicts(TableCatalogEntry &table, ExecutionContext &c combined_chunk.Slice(sel.Selection(), sel.Count()); row_ids.Slice(sel.Selection(), sel.Count()); if (GLOBAL) { - auto &constraint_state = lstate.GetConstraintState(data_table, table, context.client); + auto &constraint_state = lstate.GetConstraintState(data_table, table); data_table.VerifyAppendConstraints(constraint_state, context.client, combined_chunk, nullptr); } else { DataTable::VerifyUniqueIndexes(local_storage.GetIndexes(data_table), context.client, @@ -419,7 +418,7 @@ idx_t PhysicalInsert::OnConflictHandling(TableCatalogEntry &table, ExecutionCont InsertLocalState &lstate) const { auto &data_table = table.GetStorage(); if (action_type == OnConflictAction::THROW) { - auto &constraint_state = lstate.GetConstraintState(data_table, table, context.client); + auto &constraint_state = lstate.GetConstraintState(data_table, table); data_table.VerifyAppendConstraints(constraint_state, context.client, lstate.insert_chunk, nullptr); return 0; } @@ -443,7 +442,7 @@ SinkResultType PhysicalInsert::Sink(ExecutionContext 
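// Rough shape of OnConflictHandling above (a paraphrase, not normative): with
// OnConflictAction::THROW the chunk is verified as-is and any conflict raises
// an exception; otherwise conflicts are collected through a ConflictManager,
// and conflicting rows are either filtered out (DO NOTHING) or routed through
// PerformOnConflictAction (DO UPDATE), against global or transaction-local
// storage as appropriate.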
&context, DataChunk &chunk, if (!parallel) { if (!gstate.initialized) { - storage.InitializeLocalAppend(gstate.append_state, table, context.client); + storage.InitializeLocalAppend(gstate.append_state, table, context.client, bound_constraints); gstate.initialized = true; } @@ -501,7 +500,7 @@ SinkCombineResultType PhysicalInsert::Combine(ExecutionContext &context, Operato // we have few rows - append to the local storage directly auto &table = gstate.table; auto &storage = table.GetStorage(); - storage.InitializeLocalAppend(gstate.append_state, table, context.client); + storage.InitializeLocalAppend(gstate.append_state, table, context.client, bound_constraints); auto &transaction = DuckTransaction::Get(context.client, table.catalog); lstate.local_collection->Scan(transaction, [&](DataChunk &insert_chunk) { storage.LocalAppend(gstate.append_state, table, context.client, insert_chunk); diff --git a/src/execution/operator/persistent/physical_update.cpp b/src/execution/operator/persistent/physical_update.cpp index 6c8c6c041028..8dc663430642 100644 --- a/src/execution/operator/persistent/physical_update.cpp +++ b/src/execution/operator/persistent/physical_update.cpp @@ -15,11 +15,13 @@ namespace duckdb { PhysicalUpdate::PhysicalUpdate(vector types, TableCatalogEntry &tableref, DataTable &table, vector columns, vector> expressions, - vector> bound_defaults, idx_t estimated_cardinality, + vector> bound_defaults, + vector> bound_constraints, idx_t estimated_cardinality, bool return_chunk) : PhysicalOperator(PhysicalOperatorType::UPDATE, std::move(types), estimated_cardinality), tableref(tableref), table(table), columns(std::move(columns)), expressions(std::move(expressions)), - bound_defaults(std::move(bound_defaults)), return_chunk(return_chunk) { + bound_defaults(std::move(bound_defaults)), bound_constraints(std::move(bound_constraints)), + return_chunk(return_chunk) { } //===--------------------------------------------------------------------===// @@ -40,8 +42,9 @@ class UpdateGlobalState : public GlobalSinkState { class UpdateLocalState : public LocalSinkState { public: UpdateLocalState(ClientContext &context, const vector> &expressions, - const vector &table_types, const vector> &bound_defaults) - : default_executor(context, bound_defaults) { + const vector &table_types, const vector> &bound_defaults, + const vector> &bound_constraints) + : default_executor(context, bound_defaults), bound_constraints(bound_constraints) { // initialize the update chunk auto &allocator = Allocator::Get(context); vector update_types; @@ -59,17 +62,18 @@ class UpdateLocalState : public LocalSinkState { ExpressionExecutor default_executor; unique_ptr delete_state; unique_ptr update_state; + const vector> &bound_constraints; TableDeleteState &GetDeleteState(DataTable &table, TableCatalogEntry &tableref, ClientContext &context) { if (!delete_state) { - delete_state = table.InitializeDelete(tableref, context); + delete_state = table.InitializeDelete(tableref, context, bound_constraints); } return *delete_state; } TableUpdateState &GetUpdateState(DataTable &table, TableCatalogEntry &tableref, ClientContext &context) { if (!update_state) { - update_state = table.InitializeUpdate(tableref, context); + update_state = table.InitializeUpdate(tableref, context, bound_constraints); } return *update_state; } @@ -131,7 +135,7 @@ SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, for (idx_t i = 0; i < columns.size(); i++) { mock_chunk.data[columns[i].index].Reference(update_chunk.data[i]); } - 
table.LocalAppend(tableref, context.client, mock_chunk); + table.LocalAppend(tableref, context.client, mock_chunk, bound_constraints); } else { if (return_chunk) { mock_chunk.SetCardinality(update_chunk); @@ -157,7 +161,8 @@ unique_ptr PhysicalUpdate::GetGlobalSinkState(ClientContext &co } unique_ptr PhysicalUpdate::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, expressions, table.GetTypes(), bound_defaults); + return make_uniq(context.client, expressions, table.GetTypes(), bound_defaults, + bound_constraints); } SinkCombineResultType PhysicalUpdate::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { diff --git a/src/execution/physical_plan/plan_delete.cpp b/src/execution/physical_plan/plan_delete.cpp index 3a748c00da4b..7550da083d44 100644 --- a/src/execution/physical_plan/plan_delete.cpp +++ b/src/execution/physical_plan/plan_delete.cpp @@ -12,8 +12,8 @@ unique_ptr DuckCatalog::PlanDelete(ClientContext &context, Log // get the index of the row_id column auto &bound_ref = op.expressions[0]->Cast(); - auto del = make_uniq(op.types, op.table, op.table.GetStorage(), bound_ref.index, - op.estimated_cardinality, op.return_chunk); + auto del = make_uniq(op.types, op.table, op.table.GetStorage(), std::move(op.bound_constraints), + bound_ref.index, op.estimated_cardinality, op.return_chunk); del->children.push_back(std::move(plan)); return std::move(del); } diff --git a/src/execution/physical_plan/plan_insert.cpp b/src/execution/physical_plan/plan_insert.cpp index 3aa87d9abd7e..0ce031d07b0b 100644 --- a/src/execution/physical_plan/plan_insert.cpp +++ b/src/execution/physical_plan/plan_insert.cpp @@ -94,13 +94,14 @@ unique_ptr DuckCatalog::PlanInsert(ClientContext &context, Log unique_ptr insert; if (use_batch_index && !parallel_streaming_insert) { insert = make_uniq(op.types, op.table, op.column_index_map, std::move(op.bound_defaults), - op.estimated_cardinality); + std::move(op.bound_constraints), op.estimated_cardinality); } else { insert = make_uniq( - op.types, op.table, op.column_index_map, std::move(op.bound_defaults), std::move(op.expressions), - std::move(op.set_columns), std::move(op.set_types), op.estimated_cardinality, op.return_chunk, - parallel_streaming_insert && num_threads > 1, op.action_type, std::move(op.on_conflict_condition), - std::move(op.do_update_condition), std::move(op.on_conflict_filter), std::move(op.columns_to_fetch)); + op.types, op.table, op.column_index_map, std::move(op.bound_defaults), std::move(op.bound_constraints), + std::move(op.expressions), std::move(op.set_columns), std::move(op.set_types), op.estimated_cardinality, + op.return_chunk, parallel_streaming_insert && num_threads > 1, op.action_type, + std::move(op.on_conflict_condition), std::move(op.do_update_condition), std::move(op.on_conflict_filter), + std::move(op.columns_to_fetch)); } D_ASSERT(plan); insert->children.push_back(std::move(plan)); diff --git a/src/execution/physical_plan/plan_update.cpp b/src/execution/physical_plan/plan_update.cpp index 2789008347b6..b3591423303f 100644 --- a/src/execution/physical_plan/plan_update.cpp +++ b/src/execution/physical_plan/plan_update.cpp @@ -8,9 +8,9 @@ namespace duckdb { unique_ptr DuckCatalog::PlanUpdate(ClientContext &context, LogicalUpdate &op, unique_ptr plan) { - auto update = - make_uniq(op.types, op.table, op.table.GetStorage(), op.columns, std::move(op.expressions), - std::move(op.bound_defaults), op.estimated_cardinality, op.return_chunk); + auto update = make_uniq(op.types, 
op.table, op.table.GetStorage(), op.columns, + std::move(op.expressions), std::move(op.bound_defaults), + std::move(op.bound_constraints), op.estimated_cardinality, op.return_chunk); update->update_is_del_and_insert = op.update_is_del_and_insert; update->children.push_back(std::move(plan)); diff --git a/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp b/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp index fa9a795f2b0f..6e8c0f248944 100644 --- a/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp +++ b/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp @@ -77,7 +77,7 @@ class TableCatalogEntry : public StandardEntry { virtual DataTable &GetStorage(); //! Returns a list of the constraints of the table - DUCKDB_API const vector> &GetConstraints(); + DUCKDB_API const vector> &GetConstraints() const; DUCKDB_API string ToSQL() const override; //! Get statistics of a column (physical or virtual) within the table diff --git a/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp b/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp index 1a7155f37076..bce3ac633cf9 100644 --- a/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp +++ b/src/include/duckdb/execution/operator/persistent/physical_batch_insert.hpp @@ -20,7 +20,7 @@ class PhysicalBatchInsert : public PhysicalOperator { //! INSERT INTO PhysicalBatchInsert(vector types, TableCatalogEntry &table, physical_index_vector_t column_index_map, vector> bound_defaults, - idx_t estimated_cardinality); + vector> bound_constraints, idx_t estimated_cardinality); //! CREATE TABLE AS PhysicalBatchInsert(LogicalOperator &op, SchemaCatalogEntry &schema, unique_ptr info, idx_t estimated_cardinality); @@ -33,6 +33,8 @@ class PhysicalBatchInsert : public PhysicalOperator { vector insert_types; //! The default expressions of the columns for which no value is provided vector> bound_defaults; + //! The bound constraints for the table + vector> bound_constraints; //! Table schema, in case of CREATE TABLE AS optional_ptr schema; //! 
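// Illustrative note, mirroring PhysicalBatchInsert::Sink earlier in this
// commit: keeping the bound constraints as an operator member lets a sink
// state build its verification state without constructing a Binder, e.g.
//
//   lstate.constraint_state = table.GetStorage().InitializeConstraintState(table, bound_constraints);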
Create table info, in case of CREATE TABLE AS diff --git a/src/include/duckdb/execution/operator/persistent/physical_delete.hpp b/src/include/duckdb/execution/operator/persistent/physical_delete.hpp index 387df24d05bf..740db21ad759 100644 --- a/src/include/duckdb/execution/operator/persistent/physical_delete.hpp +++ b/src/include/duckdb/execution/operator/persistent/physical_delete.hpp @@ -19,14 +19,13 @@ class PhysicalDelete : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::DELETE_OPERATOR; public: - PhysicalDelete(vector types, TableCatalogEntry &tableref, DataTable &table, idx_t row_id_index, - idx_t estimated_cardinality, bool return_chunk) - : PhysicalOperator(PhysicalOperatorType::DELETE_OPERATOR, std::move(types), estimated_cardinality), - tableref(tableref), table(table), row_id_index(row_id_index), return_chunk(return_chunk) { - } + PhysicalDelete(vector types, TableCatalogEntry &tableref, DataTable &table, + vector> bound_constraints, idx_t row_id_index, + idx_t estimated_cardinality, bool return_chunk); TableCatalogEntry &tableref; DataTable &table; + vector> bound_constraints; idx_t row_id_index; bool return_chunk; diff --git a/src/include/duckdb/execution/operator/persistent/physical_insert.hpp b/src/include/duckdb/execution/operator/persistent/physical_insert.hpp index 8609a928a6fd..9680b5837b71 100644 --- a/src/include/duckdb/execution/operator/persistent/physical_insert.hpp +++ b/src/include/duckdb/execution/operator/persistent/physical_insert.hpp @@ -26,11 +26,12 @@ class PhysicalInsert : public PhysicalOperator { public: //! INSERT INTO PhysicalInsert(vector types, TableCatalogEntry &table, physical_index_vector_t column_index_map, - vector> bound_defaults, vector> set_expressions, - vector set_columns, vector set_types, idx_t estimated_cardinality, - bool return_chunk, bool parallel, OnConflictAction action_type, - unique_ptr on_conflict_condition, unique_ptr do_update_condition, - unordered_set on_conflict_filter, vector columns_to_fetch); + vector> bound_defaults, vector> bound_constraints, + vector> set_expressions, vector set_columns, + vector set_types, idx_t estimated_cardinality, bool return_chunk, bool parallel, + OnConflictAction action_type, unique_ptr on_conflict_condition, + unique_ptr do_update_condition, unordered_set on_conflict_filter, + vector columns_to_fetch); //! CREATE TABLE AS PhysicalInsert(LogicalOperator &op, SchemaCatalogEntry &schema, unique_ptr info, idx_t estimated_cardinality, bool parallel); @@ -43,6 +44,8 @@ class PhysicalInsert : public PhysicalOperator { vector insert_types; //! The default expressions of the columns for which no value is provided vector> bound_defaults; + //! The bound constraints for the table + vector> bound_constraints; //! If the returning statement is present, return the whole chunk bool return_chunk; //! 
Table schema, in case of CREATE TABLE AS diff --git a/src/include/duckdb/execution/operator/persistent/physical_update.hpp b/src/include/duckdb/execution/operator/persistent/physical_update.hpp index 0cccd98d51ee..b064c8b2e98f 100644 --- a/src/include/duckdb/execution/operator/persistent/physical_update.hpp +++ b/src/include/duckdb/execution/operator/persistent/physical_update.hpp @@ -22,13 +22,15 @@ class PhysicalUpdate : public PhysicalOperator { public: PhysicalUpdate(vector types, TableCatalogEntry &tableref, DataTable &table, vector columns, vector> expressions, - vector> bound_defaults, idx_t estimated_cardinality, bool return_chunk); + vector> bound_defaults, vector> bound_constraints, + idx_t estimated_cardinality, bool return_chunk); TableCatalogEntry &tableref; DataTable &table; vector columns; vector> expressions; vector> bound_defaults; + vector> bound_constraints; bool update_is_del_and_insert; //! If the returning statement is present, return the whole chunk bool return_chunk; diff --git a/src/include/duckdb/planner/binder.hpp b/src/include/duckdb/planner/binder.hpp index 088b4ae858b8..c2afbc1b8f5a 100644 --- a/src/include/duckdb/planner/binder.hpp +++ b/src/include/duckdb/planner/binder.hpp @@ -127,6 +127,7 @@ class Binder : public enable_shared_from_this { const string &table_name, const ColumnList &columns); vector> BindConstraints(const vector> &constraints, const string &table_name, const ColumnList &columns); + vector> BindConstraints(const TableCatalogEntry &table); vector> BindNewConstraints(vector> &constraints, const string &table_name, const ColumnList &columns); diff --git a/src/include/duckdb/planner/operator/logical_delete.hpp b/src/include/duckdb/planner/operator/logical_delete.hpp index 005d955b6b82..9a243a1f730b 100644 --- a/src/include/duckdb/planner/operator/logical_delete.hpp +++ b/src/include/duckdb/planner/operator/logical_delete.hpp @@ -23,6 +23,7 @@ class LogicalDelete : public LogicalOperator { TableCatalogEntry &table; idx_t table_index; bool return_chunk; + vector> bound_constraints; public: void Serialize(Serializer &serializer) const override; diff --git a/src/include/duckdb/planner/operator/logical_insert.hpp b/src/include/duckdb/planner/operator/logical_insert.hpp index 0c31fd6e4bd4..c8eb3997f844 100644 --- a/src/include/duckdb/planner/operator/logical_insert.hpp +++ b/src/include/duckdb/planner/operator/logical_insert.hpp @@ -37,6 +37,8 @@ class LogicalInsert : public LogicalOperator { bool return_chunk; //! The default statements used by the table vector> bound_defaults; + //! The constraints used by the table + vector> bound_constraints; //! 
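// Illustrative use of the BindConstraints(table) overload declared in the
// binder.hpp hunk above (`context` is a placeholder): each DML logical
// operator populates this member in its constructor, roughly:
//
//   auto binder = Binder::CreateBinder(context);
//   bound_constraints = binder->BindConstraints(table);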
Which action to take on conflict OnConflictAction action_type; diff --git a/src/include/duckdb/planner/operator/logical_update.hpp b/src/include/duckdb/planner/operator/logical_update.hpp index 9215ff694685..e0356419b0d2 100644 --- a/src/include/duckdb/planner/operator/logical_update.hpp +++ b/src/include/duckdb/planner/operator/logical_update.hpp @@ -28,6 +28,7 @@ class LogicalUpdate : public LogicalOperator { bool return_chunk; vector columns; vector> bound_defaults; + vector> bound_constraints; bool update_is_del_and_insert; public: diff --git a/src/include/duckdb/storage/data_table.hpp b/src/include/duckdb/storage/data_table.hpp index e6aedf464bc9..dcde5af382b5 100644 --- a/src/include/duckdb/storage/data_table.hpp +++ b/src/include/duckdb/storage/data_table.hpp @@ -42,7 +42,7 @@ class TableDataWriter; class ConflictManager; class TableScanState; struct TableDeleteState; -struct ConstraintVerificationState; +struct ConstraintState; struct TableUpdateState; enum class VerifyExistenceType : uint8_t; @@ -95,27 +95,32 @@ class DataTable { const Vector &row_ids, idx_t fetch_count, ColumnFetchState &state); //! Initializes an append to transaction-local storage - void InitializeLocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context); + void InitializeLocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context, + const vector> &bound_constraints); //! Append a DataChunk to the transaction-local storage of the table. void LocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, bool unsafe = false); //! Finalizes a transaction-local append void FinalizeLocalAppend(LocalAppendState &state); //! Append a chunk to the transaction-local storage of this table - void LocalAppend(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk); + void LocalAppend(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, + const vector> &bound_constraints); //! Append a column data collection to the transaction-local storage of this table - void LocalAppend(TableCatalogEntry &table, ClientContext &context, ColumnDataCollection &collection); + void LocalAppend(TableCatalogEntry &table, ClientContext &context, ColumnDataCollection &collection, + const vector> &bound_constraints); //! Merge a row group collection into the transaction-local storage void LocalMerge(ClientContext &context, RowGroupCollection &collection); //! Creates an optimistic writer for this table - used for optimistically writing parallel appends OptimisticDataWriter &CreateOptimisticWriter(ClientContext &context); void FinalizeOptimisticWriter(ClientContext &context, OptimisticDataWriter &writer); - unique_ptr InitializeDelete(TableCatalogEntry &table, ClientContext &context); + unique_ptr InitializeDelete(TableCatalogEntry &table, ClientContext &context, + const vector> &bound_constraints); //! Delete the entries with the specified row identifier from the table idx_t Delete(TableDeleteState &state, ClientContext &context, Vector &row_ids, idx_t count); - unique_ptr InitializeUpdate(TableCatalogEntry &table, ClientContext &context); + unique_ptr InitializeUpdate(TableCatalogEntry &table, ClientContext &context, + const vector> &bound_constraints); //! Update the entries with the specified row identifier from the table void Update(TableUpdateState &state, ClientContext &context, Vector &row_ids, const vector &column_ids, DataChunk &data); @@ -192,11 +197,11 @@ class DataTable { //! 
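// End-to-end sketch of the refactored append path (hypothetical caller; the
// same sequence appears in DataTable::LocalAppend further down in this
// commit):
//
//   auto binder = Binder::CreateBinder(context);
//   auto bound_constraints = binder->BindConstraints(table);
//   LocalAppendState state;
//   storage.InitializeLocalAppend(state, table, context, bound_constraints);
//   storage.LocalAppend(state, table, context, chunk);
//   storage.FinalizeLocalAppend(state);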
FIXME: This is only necessary until we treat all indexes as catalog entries, allowing to alter constraints bool IndexNameIsUnique(const string &name); - //! Initialize constraint verification - unique_ptr InitializeConstraintVerification(TableCatalogEntry &table, - ClientContext &context); + //! Initialize constraint verification state + unique_ptr InitializeConstraintState(TableCatalogEntry &table, + const vector> &bound_constraints); //! Verify constraints with a chunk from the Append containing all columns of the table - void VerifyAppendConstraints(ConstraintVerificationState &state, ClientContext &context, DataChunk &chunk, + void VerifyAppendConstraints(ConstraintState &state, ClientContext &context, DataChunk &chunk, optional_ptr conflict_manager = nullptr); public: @@ -207,7 +212,7 @@ class DataTable { //! Verify the new added constraints against current persistent&local data void VerifyNewConstraint(ClientContext &context, DataTable &parent, const BoundConstraint *constraint); //! Verify constraints with a chunk from the Update containing only the specified column_ids - void VerifyUpdateConstraints(ConstraintVerificationState &state, ClientContext &context, DataChunk &chunk, + void VerifyUpdateConstraints(ConstraintState &state, ClientContext &context, DataChunk &chunk, const vector &column_ids); //! Verify constraints with a chunk from the Delete containing all columns of the table void VerifyDeleteConstraints(TableDeleteState &state, ClientContext &context, DataChunk &chunk); diff --git a/src/include/duckdb/storage/table/append_state.hpp b/src/include/duckdb/storage/table/append_state.hpp index 8f9a56c5cba2..4488e0aaa791 100644 --- a/src/include/duckdb/storage/table/append_state.hpp +++ b/src/include/duckdb/storage/table/append_state.hpp @@ -70,18 +70,19 @@ struct TableAppendState { TransactionData transaction; }; -struct ConstraintVerificationState { - explicit ConstraintVerificationState(TableCatalogEntry &table_p) : table(table_p) { +struct ConstraintState { + explicit ConstraintState(TableCatalogEntry &table_p, const vector> &bound_constraints) + : table(table_p), bound_constraints(bound_constraints) { } TableCatalogEntry &table; - vector> bound_constraints; + const vector> &bound_constraints; }; struct LocalAppendState { TableAppendState append_state; LocalTableStorage *storage; - unique_ptr constraint_state; + unique_ptr constraint_state; }; } // namespace duckdb diff --git a/src/include/duckdb/storage/table/delete_state.hpp b/src/include/duckdb/storage/table/delete_state.hpp index ec2c2d57c547..6d25df4a5aee 100644 --- a/src/include/duckdb/storage/table/delete_state.hpp +++ b/src/include/duckdb/storage/table/delete_state.hpp @@ -14,7 +14,7 @@ namespace duckdb { class TableCatalogEntry; struct TableDeleteState { - vector> bound_constraints; + unique_ptr constraint_state; bool has_delete_constraints = false; DataChunk verify_chunk; vector col_ids; diff --git a/src/include/duckdb/storage/table/update_state.hpp b/src/include/duckdb/storage/table/update_state.hpp index 50ce404e0132..0a8193c54776 100644 --- a/src/include/duckdb/storage/table/update_state.hpp +++ b/src/include/duckdb/storage/table/update_state.hpp @@ -14,7 +14,7 @@ namespace duckdb { class TableCatalogEntry; struct TableUpdateState { - unique_ptr constraint_state; + unique_ptr constraint_state; }; } // namespace duckdb diff --git a/src/main/appender.cpp b/src/main/appender.cpp index 28a622755b3a..1a8fae39cf96 100644 --- a/src/main/appender.cpp +++ b/src/main/appender.cpp @@ -376,7 +376,9 @@ void 
Appender::FlushInternal(ColumnDataCollection &collection) { } void InternalAppender::FlushInternal(ColumnDataCollection &collection) { - table.GetStorage().LocalAppend(table, context, collection); + auto binder = Binder::CreateBinder(context); + auto bound_constraints = binder->BindConstraints(table); + table.GetStorage().LocalAppend(table, context, collection, bound_constraints); } void BaseAppender::Close() { diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp index 6cd4573d9168..e6d7eddda401 100644 --- a/src/main/client_context.cpp +++ b/src/main/client_context.cpp @@ -1121,7 +1121,9 @@ void ClientContext::Append(TableDescription &description, ColumnDataCollection & throw InvalidInputException("Failed to append: table entry has different number of columns!"); } } - table_entry.GetStorage().LocalAppend(table_entry, *this, collection); + auto binder = Binder::CreateBinder(*this); + auto bound_constraints = binder->BindConstraints(table_entry); + table_entry.GetStorage().LocalAppend(table_entry, *this, collection, bound_constraints); }); } diff --git a/src/planner/binder/statement/bind_create_table.cpp b/src/planner/binder/statement/bind_create_table.cpp index 8a032b9fe1ee..4651ede86deb 100644 --- a/src/planner/binder/statement/bind_create_table.cpp +++ b/src/planner/binder/statement/bind_create_table.cpp @@ -55,6 +55,10 @@ vector> Binder::BindConstraints(ClientContext &conte return binder->BindConstraints(constraints, table_name, columns); } +vector> Binder::BindConstraints(const TableCatalogEntry &table) { + return BindConstraints(table.GetConstraints(), table.name, table.GetColumns()); +} + vector> Binder::BindConstraints(const vector> &constraints, const string &table_name, const ColumnList &columns) { vector> bound_constraints; diff --git a/src/planner/binder/statement/bind_delete.cpp b/src/planner/binder/statement/bind_delete.cpp index 7ae86228250a..da1707160776 100644 --- a/src/planner/binder/statement/bind_delete.cpp +++ b/src/planner/binder/statement/bind_delete.cpp @@ -69,6 +69,7 @@ BoundStatement Binder::Bind(DeleteStatement &stmt) { } // create the delete node auto del = make_uniq(table, GenerateTableIndex()); + del->bound_constraints = BindConstraints(table); del->AddChild(std::move(root)); // set up the delete expression diff --git a/src/planner/binder/statement/bind_insert.cpp b/src/planner/binder/statement/bind_insert.cpp index cb3782b822db..13ad3ea7ca9a 100644 --- a/src/planner/binder/statement/bind_insert.cpp +++ b/src/planner/binder/statement/bind_insert.cpp @@ -472,6 +472,7 @@ BoundStatement Binder::Bind(InsertStatement &stmt) { // bind the default values BindDefaultValues(table.GetColumns(), insert->bound_defaults); + insert->bound_constraints = BindConstraints(table); if (!stmt.select_statement && !stmt.default_values) { result.plan = std::move(insert); return result; diff --git a/src/planner/binder/statement/bind_update.cpp b/src/planner/binder/statement/bind_update.cpp index d3ea3a5f4774..2d0f6687ca61 100644 --- a/src/planner/binder/statement/bind_update.cpp +++ b/src/planner/binder/statement/bind_update.cpp @@ -106,6 +106,7 @@ BoundStatement Binder::Bind(UpdateStatement &stmt) { } // bind the default values BindDefaultValues(table.GetColumns(), update->bound_defaults); + update->bound_constraints = BindConstraints(table); // project any additional columns required for the condition/expressions if (stmt.set_info->condition) { diff --git a/src/planner/operator/logical_delete.cpp b/src/planner/operator/logical_delete.cpp index 
950f2eaa2ebb..d82c3bfd86f2 100644 --- a/src/planner/operator/logical_delete.cpp +++ b/src/planner/operator/logical_delete.cpp @@ -15,6 +15,8 @@ LogicalDelete::LogicalDelete(ClientContext &context, const unique_ptr(context, table_info->catalog, table_info->schema, table_info->Cast().table)) { + auto binder = Binder::CreateBinder(context); + bound_constraints = binder->BindConstraints(table); } idx_t LogicalDelete::EstimateCardinality(ClientContext &context) { diff --git a/src/planner/operator/logical_insert.cpp b/src/planner/operator/logical_insert.cpp index 518661616058..dd5bf92aaeab 100644 --- a/src/planner/operator/logical_insert.cpp +++ b/src/planner/operator/logical_insert.cpp @@ -15,6 +15,8 @@ LogicalInsert::LogicalInsert(ClientContext &context, const unique_ptr(context, table_info->catalog, table_info->schema, table_info->Cast().table)) { + auto binder = Binder::CreateBinder(context); + bound_constraints = binder->BindConstraints(table); } idx_t LogicalInsert::EstimateCardinality(ClientContext &context) { diff --git a/src/planner/operator/logical_update.cpp b/src/planner/operator/logical_update.cpp index edcfd5d891be..386264478417 100644 --- a/src/planner/operator/logical_update.cpp +++ b/src/planner/operator/logical_update.cpp @@ -13,6 +13,8 @@ LogicalUpdate::LogicalUpdate(ClientContext &context, const unique_ptr(context, table_info->catalog, table_info->schema, table_info->Cast().table)) { + auto binder = Binder::CreateBinder(context); + bound_constraints = binder->BindConstraints(table); } idx_t LogicalUpdate::EstimateCardinality(ClientContext &context) { diff --git a/src/storage/data_table.cpp b/src/storage/data_table.cpp index 39c855a20e2b..588b7c38139a 100644 --- a/src/storage/data_table.cpp +++ b/src/storage/data_table.cpp @@ -616,7 +616,7 @@ void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, ClientContext &cont }); } -void DataTable::VerifyAppendConstraints(ConstraintVerificationState &state, ClientContext &context, DataChunk &chunk, +void DataTable::VerifyAppendConstraints(ConstraintState &state, ClientContext &context, DataChunk &chunk, optional_ptr conflict_manager) { auto &table = state.table; if (table.HasGeneratedColumns()) { @@ -675,22 +675,21 @@ void DataTable::VerifyAppendConstraints(ConstraintVerificationState &state, Clie } } -unique_ptr DataTable::InitializeConstraintVerification(TableCatalogEntry &table, - ClientContext &context) { - auto result = make_uniq(table); - auto binder = Binder::CreateBinder(context); - result->bound_constraints = binder->BindConstraints(table.GetConstraints(), table.name, table.GetColumns()); - return result; +unique_ptr +DataTable::InitializeConstraintState(TableCatalogEntry &table, + const vector> &bound_constraints) { + return make_uniq(table, bound_constraints); } -void DataTable::InitializeLocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context) { +void DataTable::InitializeLocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context, + const vector> &bound_constraints) { if (!is_root) { throw TransactionException("Transaction conflict: adding entries to a table that has been altered!"); } auto &local_storage = LocalStorage::Get(context, db); local_storage.InitializeAppend(state, *this); - state.constraint_state = InitializeConstraintVerification(table, context); + state.constraint_state = InitializeConstraintState(table, bound_constraints); } void DataTable::LocalAppend(LocalAppendState &state, TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, @@ 
-733,18 +732,20 @@ void DataTable::LocalMerge(ClientContext &context, RowGroupCollection &collectio local_storage.LocalMerge(*this, collection); } -void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk) { +void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, DataChunk &chunk, + const vector> &bound_constraints) { LocalAppendState append_state; auto &storage = table.GetStorage(); - storage.InitializeLocalAppend(append_state, table, context); + storage.InitializeLocalAppend(append_state, table, context, bound_constraints); storage.LocalAppend(append_state, table, context, chunk); storage.FinalizeLocalAppend(append_state); } -void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, ColumnDataCollection &collection) { +void DataTable::LocalAppend(TableCatalogEntry &table, ClientContext &context, ColumnDataCollection &collection, + const vector> &bound_constraints) { LocalAppendState append_state; auto &storage = table.GetStorage(); - storage.InitializeLocalAppend(append_state, table, context); + storage.InitializeLocalAppend(append_state, table, context, bound_constraints); for (auto &chunk : collection.Chunks()) { storage.LocalAppend(append_state, table, context, chunk); } @@ -990,7 +991,7 @@ static bool TableHasDeleteConstraints(TableCatalogEntry &table) { } void DataTable::VerifyDeleteConstraints(TableDeleteState &state, ClientContext &context, DataChunk &chunk) { - for (auto &constraint : state.bound_constraints) { + for (auto &constraint : state.constraint_state->bound_constraints) { switch (constraint->type) { case ConstraintType::NOT_NULL: case ConstraintType::CHECK: @@ -1010,12 +1011,12 @@ void DataTable::VerifyDeleteConstraints(TableDeleteState &state, ClientContext & } } -unique_ptr DataTable::InitializeDelete(TableCatalogEntry &table, ClientContext &context) { +unique_ptr DataTable::InitializeDelete(TableCatalogEntry &table, ClientContext &context, + const vector> &bound_constraints) { // initialize indexes (if any) info->InitializeIndexes(context, true); auto binder = Binder::CreateBinder(context); - vector> bound_constraints; vector types; auto result = make_uniq(); result->has_delete_constraints = TableHasDeleteConstraints(table); @@ -1026,7 +1027,7 @@ unique_ptr DataTable::InitializeDelete(TableCatalogEntry &tabl types.emplace_back(column_definitions[i].Type()); } result->verify_chunk.Initialize(Allocator::Get(context), types); - result->bound_constraints = binder->BindConstraints(table.GetConstraints(), table.name, table.GetColumns()); + result->constraint_state = make_uniq(table, bound_constraints); } return result; } @@ -1120,7 +1121,7 @@ static bool CreateMockChunk(TableCatalogEntry &table, const vector &column_ids) { auto &table = state.table; auto &constraints = table.GetConstraints(); @@ -1170,12 +1171,13 @@ void DataTable::VerifyUpdateConstraints(ConstraintVerificationState &state, Clie #endif } -unique_ptr DataTable::InitializeUpdate(TableCatalogEntry &table, ClientContext &context) { +unique_ptr DataTable::InitializeUpdate(TableCatalogEntry &table, ClientContext &context, + const vector> &bound_constraints) { // check that there are no unknown indexes info->InitializeIndexes(context, true); auto result = make_uniq(); - result->constraint_state = InitializeConstraintVerification(table, context); + result->constraint_state = InitializeConstraintState(table, bound_constraints); return result; } diff --git a/src/storage/wal_replay.cpp b/src/storage/wal_replay.cpp index 
From 2bbebda1e09ef81969d9814a3521b30a74048ae6 Mon Sep 17 00:00:00 2001
From: Mark Raasveldt
Date: Fri, 19 Apr 2024 13:11:19 +0200
Subject: [PATCH 192/201] Add missing includes

---
 src/execution/operator/persistent/physical_delete.cpp         | 1 +
 .../duckdb/execution/operator/persistent/physical_delete.hpp  | 1 +
 .../duckdb/execution/operator/persistent/physical_update.hpp  | 1 +
 src/include/duckdb/planner/operator/logical_delete.hpp        | 1 +
 src/include/duckdb/planner/operator/logical_update.hpp        | 1 +
 src/storage/wal_replay.cpp                                    | 1 +
 6 files changed, 6 insertions(+)

diff --git a/src/execution/operator/persistent/physical_delete.cpp b/src/execution/operator/persistent/physical_delete.cpp
index bb2b1a76057a..83c92b67f655 100644
--- a/src/execution/operator/persistent/physical_delete.cpp
+++ b/src/execution/operator/persistent/physical_delete.cpp
@@ -7,6 +7,7 @@
 #include "duckdb/storage/table/scan_state.hpp"
 #include "duckdb/transaction/duck_transaction.hpp"
 #include "duckdb/storage/table/delete_state.hpp"
+#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
 
 namespace duckdb {
 
diff --git a/src/include/duckdb/execution/operator/persistent/physical_delete.hpp b/src/include/duckdb/execution/operator/persistent/physical_delete.hpp
index 740db21ad759..e9c9c24c0ce0 100644
--- a/src/include/duckdb/execution/operator/persistent/physical_delete.hpp
+++ b/src/include/duckdb/execution/operator/persistent/physical_delete.hpp
@@ -9,6 +9,7 @@
 #pragma once
 
 #include "duckdb/execution/physical_operator.hpp"
+#include "duckdb/planner/bound_constraint.hpp"
 
 namespace duckdb {
 class DataTable;
diff --git a/src/include/duckdb/execution/operator/persistent/physical_update.hpp b/src/include/duckdb/execution/operator/persistent/physical_update.hpp
index b064c8b2e98f..9556b48c92b6 100644
--- a/src/include/duckdb/execution/operator/persistent/physical_update.hpp
+++ b/src/include/duckdb/execution/operator/persistent/physical_update.hpp
@@ -10,6 +10,7 @@
 
 #include "duckdb/execution/physical_operator.hpp"
 #include "duckdb/planner/expression.hpp"
+#include "duckdb/planner/bound_constraint.hpp"
 
 namespace duckdb {
 class DataTable;
diff --git a/src/include/duckdb/planner/operator/logical_delete.hpp b/src/include/duckdb/planner/operator/logical_delete.hpp
index 9a243a1f730b..513370ea254b 100644
--- a/src/include/duckdb/planner/operator/logical_delete.hpp
+++ b/src/include/duckdb/planner/operator/logical_delete.hpp
@@ -9,6 +9,7 @@
 #pragma once
 
 #include "duckdb/planner/logical_operator.hpp"
+#include "duckdb/planner/bound_constraint.hpp"
 
 namespace duckdb {
 class TableCatalogEntry;
diff --git a/src/include/duckdb/planner/operator/logical_update.hpp b/src/include/duckdb/planner/operator/logical_update.hpp
index e0356419b0d2..587341ab00c2 100644
--- a/src/include/duckdb/planner/operator/logical_update.hpp
+++ b/src/include/duckdb/planner/operator/logical_update.hpp
@@ -9,6 +9,7 @@
 #pragma once
 
 #include "duckdb/planner/logical_operator.hpp"
+#include "duckdb/planner/bound_constraint.hpp"
 
 namespace duckdb {
 class TableCatalogEntry;
diff --git a/src/storage/wal_replay.cpp b/src/storage/wal_replay.cpp
index ddd0e7422187..f0d78083e422 100644
--- a/src/storage/wal_replay.cpp
+++ b/src/storage/wal_replay.cpp
@@ -25,6 +25,7 @@
 #include "duckdb/common/checksum.hpp"
 #include "duckdb/execution/index/index_type_set.hpp"
 #include "duckdb/execution/index/art/art.hpp"
+#include "duckdb/storage/table/delete_state.hpp"
 
 namespace duckdb {
 
From 239a1d11a5143b8c9aa9ce68fbf09e5b0284a43d Mon Sep 17 00:00:00 2001
From: Pedro Holanda
Date: Fri, 19 Apr 2024 13:22:04 +0200
Subject: [PATCH 193/201] Basics of implicit time/timestamp cast

---
 .../scanner/string_value_scanner.cpp           | 43 ++++++++++---------
 src/function/scalar/strftime_format.cpp        | 23 +++++++---
 .../function/scalar/strftime_format.hpp        |  4 ++
 3 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
index 9e6271c82818..438fb1b8f0c9 100644
--- a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
+++ b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
@@ -215,16 +215,32 @@ void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size
 		                   false, state_machine.options.decimal_separator[0]);
 		break;
 	case LogicalTypeId::DATE: {
-		idx_t pos;
-		bool special;
-		success = Date::TryConvertDate(value_ptr, size, pos,
-		                               static_cast<date_t *>(vector_ptr[chunk_col_id])[number_of_rows], special, false);
+		if (!state_machine.dialect_options.date_format.find(LogicalTypeId::DATE)->second.GetValue().Empty()) {
+			string error_message;
+			success = state_machine.dialect_options.date_format.find(LogicalTypeId::DATE)
+			              ->second.GetValue()
+			              .TryParseDate(value_ptr, size,
+			                            static_cast<date_t *>(vector_ptr[chunk_col_id])[number_of_rows], error_message);
+		} else {
+			idx_t pos;
+			bool special;
+			success = Date::TryConvertDate(
+			    value_ptr, size, pos, static_cast<date_t *>(vector_ptr[chunk_col_id])[number_of_rows], special, false);
+		}
 		break;
 	}
 	case LogicalTypeId::TIMESTAMP: {
-		success = Timestamp::TryConvertTimestamp(
-		              value_ptr, size, static_cast<timestamp_t *>(vector_ptr[chunk_col_id])[number_of_rows]) ==
-		          TimestampCastResult::SUCCESS;
+		if (!state_machine.dialect_options.date_format.find(LogicalTypeId::TIMESTAMP)->second.GetValue().Empty()) {
+			timestamp_t result;
+			string error_message;
+			// success = state_machine.dialect_options.date_format.find(LogicalTypeId::TIMESTAMP)
+			//               ->second.GetValue()
+			//               .TryParseTimestamp(value, result, error_message);
+		} else {
+			success = Timestamp::TryConvertTimestamp(
+			              value_ptr, size, static_cast<timestamp_t *>(vector_ptr[chunk_col_id])[number_of_rows]) ==
+			          TimestampCastResult::SUCCESS;
+		}
 		break;
 	}
 	default: {
@@ -1185,7 +1201,6 @@ bool StringValueScanner::CanDirectlyCast(const LogicalType &type, const map
 bool StringValueScanner::CanDirectlyCast(const LogicalType &type,
                                          const map<LogicalTypeId, CSVOption<StrpTimeFormat>> &format_options) {
 	switch (type.id()) {
-	// All Integers (Except HugeInt)
 	case LogicalTypeId::TINYINT:
 	case LogicalTypeId::SMALLINT:
 	case LogicalTypeId::INTEGER:
 	case LogicalTypeId::BIGINT:
 	case LogicalTypeId::UTINYINT:
 	case LogicalTypeId::USMALLINT:
 	case LogicalTypeId::UINTEGER:
 	case LogicalTypeId::UBIGINT:
 	case LogicalTypeId::DOUBLE:
 	case LogicalTypeId::FLOAT:
-		return true;
 	case LogicalTypeId::DATE:
-		// We can only internally cast YYYY-MM-DD
-		if (format_options.at(LogicalTypeId::DATE).GetValue().format_specifier == "%Y-%m-%d") {
-			return true;
-		} else {
-			return false;
-		}
 	case LogicalTypeId::TIMESTAMP:
-		if (format_options.at(LogicalTypeId::TIMESTAMP).GetValue().format_specifier == "%Y-%m-%d %H:%M:%S") {
-			return true;
-		} else {
-			return false;
-		}
 	case LogicalType::VARCHAR:
 		return true;
 	default:
diff --git a/src/function/scalar/strftime_format.cpp b/src/function/scalar/strftime_format.cpp
index f0ec24920a59..4a8416f6e8e3 100644
--- a/src/function/scalar/strftime_format.cpp
+++ b/src/function/scalar/strftime_format.cpp
@@ -740,8 +740,7 @@ int32_t StrpTimeFormat::TryParseCollection(const char *data, idx_t &pos, idx_t s
 	return -1;
 }
 
-//! Parses a timestamp using the given specifier
-bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const {
+bool StrpTimeFormat::Parse(const char *data, size_t size, ParseResult &result) const {
 	auto &result_data = result.data;
 	auto &error_message = result.error_message;
 	auto &error_position = result.error_position;
@@ -755,15 +754,11 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const {
 	result_data[5] = 0;
 	result_data[6] = 0;
 	result_data[7] = 0;
-
-	auto data = str.GetData();
-	idx_t size = str.GetSize();
 	// skip leading spaces
 	while (StringUtil::CharacterIsSpace(*data)) {
 		data++;
 		size--;
 	}
-
 	// Check for specials
 	// Precheck for alphas for performance.
 	idx_t pos = 0;
@@ -1067,7 +1062,6 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const {
 		case StrTimeSpecifier::YEAR_DECIMAL:
 			// Just validate, don't use
 			break;
-			break;
 		case StrTimeSpecifier::WEEKDAY_DECIMAL:
 			// First offset specifier
 			offset_specifier = specifiers[i];
@@ -1324,6 +1318,13 @@ bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const {
 	return true;
 }
 
+//! Parses a timestamp using the given specifier
+bool StrpTimeFormat::Parse(string_t str, ParseResult &result) const {
+	auto data = str.GetData();
+	idx_t size = str.GetSize();
+	return Parse(data, size, result);
+}
+
 StrpTimeFormat::ParseResult StrpTimeFormat::Parse(const string &format_string, const string &text) {
 	StrpTimeFormat format;
 	format.format_specifier = format_string;
@@ -1413,6 +1414,14 @@ bool StrpTimeFormat::TryParseDate(string_t input, date_t &result, string &error_
 	return parse_result.TryToDate(result);
 }
 
+bool StrpTimeFormat::TryParseDate(const char *data, size_t size, date_t &result, string &error_message) const {
+	ParseResult parse_result;
+	if (!Parse(data, size, parse_result)) {
+		return false;
+	}
+	return parse_result.TryToDate(result);
+}
+
 bool StrpTimeFormat::TryParseTime(string_t input, dtime_t &result, string &error_message) const {
 	ParseResult parse_result;
 	if (!Parse(input, parse_result)) {
diff --git a/src/include/duckdb/function/scalar/strftime_format.hpp b/src/include/duckdb/function/scalar/strftime_format.hpp
index a7b3addde7d2..f65e10131d5e 100644
--- a/src/include/duckdb/function/scalar/strftime_format.hpp
+++ b/src/include/duckdb/function/scalar/strftime_format.hpp
@@ -161,6 +161,10 @@ struct StrpTimeFormat : public StrTimeFormat { // NOLINT: work-around bug in cla
 
 	DUCKDB_API bool Parse(string_t str, ParseResult &result) const;
 
+	DUCKDB_API bool Parse(const char *data, size_t size, ParseResult &result) const;
+
+	DUCKDB_API bool TryParseDate(const char *data, size_t size, date_t &result, string &error_message) const;
+
 	DUCKDB_API bool TryParseDate(string_t str, date_t &result, string &error_message) const;
 	DUCKDB_API bool TryParseTime(string_t str, dtime_t &result, string &error_message) const;
 	DUCKDB_API bool TryParseTimestamp(string_t str, timestamp_t &result, string &error_message) const;
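The new pointer-based overloads let the CSV scanner parse a date straight out of its scan buffer, without first wrapping the bytes in a string_t. A minimal sketch of the intended use, assuming `format` is a StrpTimeFormat whose specifier has already been set up, and `value_ptr`/`size` point into the buffer:

    date_t parsed_date;
    string error_message;
    // parse directly from the buffer slice; no string_t construction needed
    if (!format.TryParseDate(value_ptr, size, parsed_date, error_message)) {
        // fall through to the scanner's error-handling path
    }

Note that at this point in the series the TIMESTAMP branch is still stubbed out (the TryParseTimestamp call is commented), which the follow-up commit completes.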
From d78effa2ae09efb373e1b038f289004a42dd81c3 Mon Sep 17 00:00:00 2001
From: Mark Raasveldt
Date: Fri, 19 Apr 2024 13:32:41 +0200
Subject: [PATCH 194/201] Extension fixes

---
 .github/config/out_of_tree_extensions.cmake |  1 +
 .../postgres_scanner/shared_ptr.patch       | 27 +++++++++++++++++++
 .../sqlite_scanner/binder_update.patch      | 27 +++++++++++++++++++
 3 files changed, 55 insertions(+)
 create mode 100644 .github/patches/extensions/sqlite_scanner/binder_update.patch

diff --git a/.github/config/out_of_tree_extensions.cmake b/.github/config/out_of_tree_extensions.cmake
index 7b489f8c14ff..159ed15ebe6f 100644
--- a/.github/config/out_of_tree_extensions.cmake
+++ b/.github/config/out_of_tree_extensions.cmake
@@ -94,6 +94,7 @@ duckdb_extension_load(sqlite_scanner
     ${STATIC_LINK_SQLITE} LOAD_TESTS
     GIT_URL https://github.com/duckdb/sqlite_scanner
     GIT_TAG 091197efb34579c7195afa43dfb5925023c915c0
+    APPLY_PATCHES
     )
 
 ################# SUBSTRAIT
diff --git a/.github/patches/extensions/postgres_scanner/shared_ptr.patch b/.github/patches/extensions/postgres_scanner/shared_ptr.patch
index 98d351393f23..e94920829735 100644
--- a/.github/patches/extensions/postgres_scanner/shared_ptr.patch
+++ b/.github/patches/extensions/postgres_scanner/shared_ptr.patch
@@ -58,6 +58,19 @@ index e20a803..4fe45f6 100644
 
 namespace duckdb {
 struct DropInfo;
+diff --git a/src/include/storage/postgres_table_entry.hpp b/src/include/storage/postgres_table_entry.hpp
+index d96dfad..529c234 100644
+--- a/src/include/storage/postgres_table_entry.hpp
++++ b/src/include/storage/postgres_table_entry.hpp
+@@ -50,7 +50,7 @@ public:
+ 
+ 	TableStorageInfo GetStorageInfo(ClientContext &context) override;
+ 
+-	void BindUpdateConstraints(LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update,
++	void BindUpdateConstraints(Binder &binder, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update,
+ 	                           ClientContext &context) override;
+ 
+ 	//! Get the copy format (text or binary) that should be used when writing data to this table
 diff --git a/src/postgres_binary_copy.cpp b/src/postgres_binary_copy.cpp
 index f0d86a3..4c89c3f 100644
 --- a/src/postgres_binary_copy.cpp
@@ -212,3 +225,17 @@ index 93c3f28..cd3b46f 100644
 
 namespace duckdb {
 
+diff --git a/src/storage/postgres_table_entry.cpp b/src/storage/postgres_table_entry.cpp
+index d791678..7ba1ad6 100644
+--- a/src/storage/postgres_table_entry.cpp
++++ b/src/storage/postgres_table_entry.cpp
+@@ -31,7 +31,8 @@ unique_ptr PostgresTableEntry::GetStatistics(ClientContext &cont
+ 	return nullptr;
+ }
+ 
+-void PostgresTableEntry::BindUpdateConstraints(LogicalGet &, LogicalProjection &, LogicalUpdate &, ClientContext &) {
++void PostgresTableEntry::BindUpdateConstraints(Binder &binder, LogicalGet &, LogicalProjection &, LogicalUpdate &,
++                                               ClientContext &) {
+ }
+ 
+ TableFunction PostgresTableEntry::GetScanFunction(ClientContext &context, unique_ptr<FunctionData> &bind_data) {
diff --git a/.github/patches/extensions/sqlite_scanner/binder_update.patch b/.github/patches/extensions/sqlite_scanner/binder_update.patch
new file mode 100644
index 000000000000..7973f54281e5
--- /dev/null
+++ b/.github/patches/extensions/sqlite_scanner/binder_update.patch
@@ -0,0 +1,27 @@
+diff --git a/src/include/storage/sqlite_table_entry.hpp b/src/include/storage/sqlite_table_entry.hpp
+index 6e64d55..b08319b 100644
+--- a/src/include/storage/sqlite_table_entry.hpp
++++ b/src/include/storage/sqlite_table_entry.hpp
+@@ -25,7 +25,7 @@ public:
+ 
+ 	TableStorageInfo GetStorageInfo(ClientContext &context) override;
+ 
+-	void BindUpdateConstraints(LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update,
++	void BindUpdateConstraints(Binder &binder, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update,
+ 	                           ClientContext &context) override;
+ };
+ 
+diff --git a/src/storage/sqlite_table_entry.cpp b/src/storage/sqlite_table_entry.cpp
+index fadbb39..47378b0 100644
+--- a/src/storage/sqlite_table_entry.cpp
++++ b/src/storage/sqlite_table_entry.cpp
+@@ -16,7 +16,8 @@ unique_ptr SQLiteTableEntry::GetStatistics(ClientContext &contex
+ 	return nullptr;
+ }
+ 
+-void SQLiteTableEntry::BindUpdateConstraints(LogicalGet &, LogicalProjection &, LogicalUpdate &, ClientContext &) {
++void SQLiteTableEntry::BindUpdateConstraints(Binder &, LogicalGet &, LogicalProjection &, LogicalUpdate &,
++                                             ClientContext &) {
+ }
+ 
+ TableFunction SQLiteTableEntry::GetScanFunction(ClientContext &context, unique_ptr<FunctionData> &bind_data) {
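Both extension patches apply the same mechanical recipe: BindUpdateConstraints gains a leading Binder & parameter. Any out-of-tree catalog entry that overrides it needs the same one-line signature change, roughly as below (MyTableEntry is a hypothetical storage backend, not taken from either patch):

    class MyTableEntry : public TableCatalogEntry {
    public:
        void BindUpdateConstraints(Binder &binder, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update,
                                   ClientContext &context) override {
            // like the postgres/sqlite scanners, this backend needs no extra
            // update-constraint handling, so the override stays empty
        }
    };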
From 09662ad5def7811a1bc5da642545b86178532de7 Mon Sep 17 00:00:00 2001
From: Pedro Holanda
Date: Fri, 19 Apr 2024 13:38:01 +0200
Subject: [PATCH 195/201] Finishing up the work of implicit casting and fixing
 up error messages

---
 .../scanner/string_value_scanner.cpp           | 18 ++++++++----------
 src/function/scalar/strftime_format.cpp        | 10 +++++++++-
 .../duckdb/function/scalar/strftime_format.hpp |  3 ++-
 test/sql/copy/csv/csv_hive_filename_union.test |  2 +-
 .../csv/rejects/csv_rejects_flush_cast.test    |  4 ++--
 .../test_multiple_errors_same_line.test        | 10 +++++++---
 test/sql/copy/csv/timestamp_with_tz.test       |  2 +-
 7 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
index 438fb1b8f0c9..473a07a93c4c 100644
--- a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
+++ b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
@@ -216,11 +216,10 @@ void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size
 		break;
 	case LogicalTypeId::DATE: {
 		if (!state_machine.dialect_options.date_format.find(LogicalTypeId::DATE)->second.GetValue().Empty()) {
-			string error_message;
-			success = state_machine.dialect_options.date_format.find(LogicalTypeId::DATE)
-			              ->second.GetValue()
-			              .TryParseDate(value_ptr, size,
-			                            static_cast<date_t *>(vector_ptr[chunk_col_id])[number_of_rows], error_message);
+			success =
+			    state_machine.dialect_options.date_format.find(LogicalTypeId::DATE)
+			        ->second.GetValue()
+			        .TryParseDate(value_ptr, size, static_cast<date_t *>(vector_ptr[chunk_col_id])[number_of_rows]);
 		} else {
 			idx_t pos;
 			bool special;
@@ -231,11 +230,10 @@ void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size
 	}
 	case LogicalTypeId::TIMESTAMP: {
 		if (!state_machine.dialect_options.date_format.find(LogicalTypeId::TIMESTAMP)->second.GetValue().Empty()) {
-			timestamp_t result;
-			string error_message;
-			// success = state_machine.dialect_options.date_format.find(LogicalTypeId::TIMESTAMP)
-			//               ->second.GetValue()
-			//               .TryParseTimestamp(value, result, error_message);
+			success = state_machine.dialect_options.date_format.find(LogicalTypeId::TIMESTAMP)
+			              ->second.GetValue()
+			              .TryParseTimestamp(value_ptr, size,
+			                                 static_cast<timestamp_t *>(vector_ptr[chunk_col_id])[number_of_rows]);
 		} else {
 			success = Timestamp::TryConvertTimestamp(
 			              value_ptr, size, static_cast<timestamp_t *>(vector_ptr[chunk_col_id])[number_of_rows]) ==
diff --git a/src/function/scalar/strftime_format.cpp b/src/function/scalar/strftime_format.cpp
index 4a8416f6e8e3..f8ff39c37247 100644
--- a/src/function/scalar/strftime_format.cpp
+++ b/src/function/scalar/strftime_format.cpp
@@ -1414,7 +1414,7 @@ bool StrpTimeFormat::TryParseDate(string_t input, date_t &result, string &error_
 	return parse_result.TryToDate(result);
 }
 
-bool StrpTimeFormat::TryParseDate(const char *data, size_t size, date_t &result, string &error_message) const {
+bool StrpTimeFormat::TryParseDate(const char *data, size_t size, date_t &result) const {
 	ParseResult parse_result;
 	if (!Parse(data, size, parse_result)) {
 		return false;
@@ -1440,4 +1440,12 @@ bool StrpTimeFormat::TryParseTimestamp(string_t input, timestamp_t &result, stri
 	return parse_result.TryToTimestamp(result);
 }
 
+bool StrpTimeFormat::TryParseTimestamp(const char *data, size_t size, timestamp_t &result) const {
+	ParseResult parse_result;
+	if (!Parse(data, size, parse_result)) {
+		return false;
+	}
+	return parse_result.TryToTimestamp(result);
+}
+
 } // namespace duckdb
diff --git a/src/include/duckdb/function/scalar/strftime_format.hpp b/src/include/duckdb/function/scalar/strftime_format.hpp
index f65e10131d5e..a538db4255b3 100644
--- a/src/include/duckdb/function/scalar/strftime_format.hpp
+++ b/src/include/duckdb/function/scalar/strftime_format.hpp
@@ -163,7 +163,8 @@ struct StrpTimeFormat : public StrTimeFormat { // NOLINT: work-around bug in cla
 	DUCKDB_API bool Parse(const char *data, size_t size, ParseResult &result) const;
 
-	DUCKDB_API bool TryParseDate(const char *data, size_t size, date_t &result, string &error_message) const;
+	DUCKDB_API bool TryParseDate(const char *data, size_t size, date_t &result) const;
+	DUCKDB_API bool TryParseTimestamp(const char *data, size_t size, timestamp_t &result) const;
 
 	DUCKDB_API bool TryParseDate(string_t str, date_t &result, string &error_message) const;
 	DUCKDB_API bool TryParseTime(string_t str, dtime_t &result, string &error_message) const;
diff --git a/test/sql/copy/csv/csv_hive_filename_union.test b/test/sql/copy/csv/csv_hive_filename_union.test
index 79f84efcdfc0..ff145d929f77 100644
--- a/test/sql/copy/csv/csv_hive_filename_union.test
+++ b/test/sql/copy/csv/csv_hive_filename_union.test
@@ -56,7 +56,7 @@ xxx	42
 statement error
 select * from read_csv_auto(['data/csv/hive-partitioning/mismatching_contents/part=1/test.csv', 'data/csv/hive-partitioning/mismatching_contents/part=2/test.csv']) order by 1
 ----
-date field value out of range
+Error when converting column "c". Could not convert string "world" to 'DATE'
 
 query III
 select a, b, c from read_csv_auto('data/csv/hive-partitioning/mismatching_contents/*/*.csv', UNION_BY_NAME=1) order by 2 NULLS LAST
diff --git a/test/sql/copy/csv/rejects/csv_rejects_flush_cast.test b/test/sql/copy/csv/rejects/csv_rejects_flush_cast.test
index ba48d9fe2a99..09daa735d6eb 100644
--- a/test/sql/copy/csv/rejects/csv_rejects_flush_cast.test
+++ b/test/sql/copy/csv/rejects/csv_rejects_flush_cast.test
@@ -20,6 +20,6 @@ DATE VARCHAR 2811
 query IIIIIIIII
 SELECT * EXCLUDE (scan_id) FROM reject_errors order by all;
 ----
-0 439 6997 NULL 1 a CAST B, bla Error when converting column "a". Could not parse string "B" according to format specifier "%d-%m-%Y"
-0 2813 44972 NULL 1 a CAST c, bla Error when converting column "a". Could not parse string "c" according to format specifier "%d-%m-%Y"
+0 439 6997 6997 1 a CAST B, bla Error when converting column "a". Could not convert string "B" to 'DATE'
+0 2813 44972 44972 1 a CAST c, bla Error when converting column "a". Could not convert string "c" to 'DATE'
diff --git a/test/sql/copy/csv/rejects/test_multiple_errors_same_line.test b/test/sql/copy/csv/rejects/test_multiple_errors_same_line.test
index 6d9f1fcb5baa..85c3140097bc 100644
--- a/test/sql/copy/csv/rejects/test_multiple_errors_same_line.test
+++ b/test/sql/copy/csv/rejects/test_multiple_errors_same_line.test
@@ -61,8 +61,8 @@ oogie boogie 3 2023-01-02 2023-01-03
 query IIIIIIIII rowsort
 SElECT * EXCLUDE (scan_id) FROM reject_errors ORDER BY ALL;
 ----
-0 4 110 NULL 3 current_day CAST oogie boogie,3, bla_2, bla_1 Error when converting column "current_day". date field value out of range: " bla_2", expected format is (YYYY-MM-DD)
-0 4 110 NULL 4 tomorrow CAST oogie boogie,3, bla_2, bla_1 Error when converting column "tomorrow". date field value out of range: " bla_1", expected format is (YYYY-MM-DD)
+0 4 110 125 3 current_day CAST oogie boogie,3, bla_2, bla_1 Error when converting column "current_day". Could not convert string " bla_2" to 'DATE'
+0 4 110 132 4 tomorrow CAST oogie boogie,3, bla_2, bla_1 Error when converting column "tomorrow". Could not convert string " bla_1" to 'DATE'
 
 statement ok
 DROP TABLE reject_errors;
@@ -82,6 +82,7 @@ oogie boogie 3 2023-01-02 5
 query IIIIIIIII rowsort
 SElECT * EXCLUDE (scan_id) FROM reject_errors ORDER BY ALL;
 ----
+0 4 89 104 3 current_day CAST oogie boogie,3, bla_2, bla_1 Error when converting column "current_day". Could not convert string " bla_2" to 'DATE'
 0 4 89 111 4 barks CAST oogie boogie,3, bla_2, bla_1 Error when converting column "barks". Could not convert string " bla_1" to 'INTEGER'
 
 statement ok
@@ -370,10 +371,13 @@ SELECT * EXCLUDE (scan_id) FROM reject_errors ORDER BY ALL;
 ----
 0 4 89 116 4 barks CAST oogie boogie,3, 2023-01-03, bla, 7 Error when converting column "barks". Could not convert string " bla" to 'INTEGER'
 0 4 89 120 5 NULL TOO MANY COLUMNS oogie boogie,3, 2023-01-03, bla, 7 Expected Number of Columns: 4 Found: 5
+0 5 124 139 3 current_day CAST oogie boogie,3, bla, bla, 7 Error when converting column "current_day". Could not convert string " bla" to 'DATE'
 0 5 124 144 4 barks CAST oogie boogie,3, bla, bla, 7 Error when converting column "barks". Could not convert string " bla" to 'INTEGER'
 0 5 124 148 5 NULL TOO MANY COLUMNS oogie boogie,3, bla, bla, 7 Expected Number of Columns: 4 Found: 5
 0 6 152 152 1 name UNQUOTED VALUE "oogie boogie"bla,3, 2023-01-04 Value with unterminated quote found.
 0 6 152 183 3 barks MISSING COLUMNS "oogie boogie"bla,3, 2023-01-04 Expected Number of Columns: 4 Found: 3
+0 7 184 199 3 current_day CAST oogie boogie,3, bla Error when converting column "current_day". Could not convert string " bla" to 'DATE'
 0 7 184 203 3 barks MISSING COLUMNS oogie boogie,3, bla Expected Number of Columns: 4 Found: 3
 0 8 204 204 NULL NULL LINE SIZE OVER MAXIMUM oogie boogieoogie boogieoogie boogieoogie boogieoogie boogieoogie boogieoogie boogie,3, bla Maximum line size of 40 bytes exceeded. Actual Size:92 bytes.
-0 8 204 295 3 barks MISSING COLUMNS oogie boogieoogie boogieoogie boogieoogie boogieoogie boogieoogie boogieoogie boogie,3, bla Expected Number of Columns: 4 Found: 3
+0 8 204 291 3 current_day CAST oogie boogieoogie boogieoogie boogieoogie boogieoogie boogieoogie boogieoogie boogie,3, bla Error when converting column "current_day". Could not convert string " bla" to 'DATE'
+0 8 204 295 3 barks MISSING COLUMNS oogie boogieoogie boogieoogie boogieoogie boogieoogie boogieoogie boogieoogie boogie,3, bla Expected Number of Columns: 4 Found: 3
\ No newline at end of file
diff --git a/test/sql/copy/csv/timestamp_with_tz.test b/test/sql/copy/csv/timestamp_with_tz.test
index 9478957d222e..b8dbd3de4e74 100644
--- a/test/sql/copy/csv/timestamp_with_tz.test
+++ b/test/sql/copy/csv/timestamp_with_tz.test
@@ -9,7 +9,7 @@ CREATE TABLE tbl(id int, ts timestamp);
 statement error
 COPY tbl FROM 'data/csv/timestamp_with_tz.csv' (HEADER)
 ----
-timestamp that is not UTC
+Error when converting column "ts". Could not convert string "2021-05-25 04:55:03.382494 EST" to 'TIMESTAMP'
 
 require icu
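With conversion failures now reported through the scanner's own rejects machinery, the buffer-based overloads no longer thread an error string through, and the per-value hot path shrinks to roughly the following (a sketch, with `fmt`/`ts_fmt` standing in for the dialect's configured formats and `value_ptr`/`size` for the buffer slice):

    date_t d;
    timestamp_t ts;
    // failures surface as the generic "Could not convert string ... to 'DATE'" cast error
    bool ok_date = fmt.TryParseDate(value_ptr, size, d);
    bool ok_ts = ts_fmt.TryParseTimestamp(value_ptr, size, ts);

The updated test expectations above reflect that unified error message, which also replaces the old per-format "Could not parse string ... according to format specifier" wording.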
From 58c44f6d14bd4059cddf2e1f0a828d54a28cf647 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 19 Apr 2024 13:56:39 +0200
Subject: [PATCH 196/201] return properly typed NULL value, fix up python tests
 and behavior

---
 src/core_functions/scalar/map/map.cpp          | 29 ++++++++++-----
 .../src/native/python_conversion.cpp           | 37 +++++++++++++++++--
 tools/pythonpkg/src/pandas/analyzer.cpp        |  4 ++
 .../fast/pandas/test_df_object_resolution.py   |  5 +--
 4 files changed, 60 insertions(+), 15 deletions(-)

diff --git a/src/core_functions/scalar/map/map.cpp b/src/core_functions/scalar/map/map.cpp
index 62f54cdd25d0..664946267cee 100644
--- a/src/core_functions/scalar/map/map.cpp
+++ b/src/core_functions/scalar/map/map.cpp
@@ -21,21 +21,28 @@ static void MapFunctionEmptyInput(Vector &result, const idx_t row_count) {
 	result.Verify(row_count);
 }
 
+static bool MapIsNull(const LogicalType &map) {
+	D_ASSERT(map.id() == LogicalTypeId::MAP);
+	auto &key = MapType::KeyType(map);
+	auto &value = MapType::ValueType(map);
+	return (key.id() == LogicalTypeId::SQLNULL && value.id() == LogicalTypeId::SQLNULL);
+}
+
 static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) {
 	// internal MAP representation
 	// - LIST-vector that contains STRUCTs as child entries
 	// - STRUCTs have exactly two fields, a key-field, and a value-field
 	// - key names are unique
+	D_ASSERT(result.GetType().id() == LogicalTypeId::MAP);
 
-	if (result.GetType().id() == LogicalTypeId::SQLNULL) {
+	if (MapIsNull(result.GetType())) {
 		auto &validity = FlatVector::Validity(result);
 		validity.SetInvalid(0);
 		result.SetVectorType(VectorType::CONSTANT_VECTOR);
 		return;
 	}
 
-	D_ASSERT(result.GetType().id() == LogicalTypeId::MAP);
 	auto row_count = args.size();
 
 	// early-out, if no data
@@ -162,16 +169,20 @@ static unique_ptr<FunctionData> MapBind(ClientContext &, ScalarFunction &bound_f
 		MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS);
 	}
 
-	// bind an empty MAP
+	bool is_null = false;
 	if (arguments.empty()) {
-		bound_function.return_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL);
-		return make_uniq<VariableReturnBindData>(bound_function.return_type);
+		is_null = true;
+	}
+	if (!is_null) {
+		auto key_id = arguments[0]->return_type.id();
+		auto value_id = arguments[1]->return_type.id();
+		if (key_id == LogicalTypeId::SQLNULL || value_id == LogicalTypeId::SQLNULL) {
+			is_null = true;
+		}
 	}
 
-	auto key_id = arguments[0]->return_type.id();
-	auto value_id = arguments[1]->return_type.id();
-	if (key_id == LogicalTypeId::SQLNULL || value_id == LogicalTypeId::SQLNULL) {
-		bound_function.return_type = LogicalTypeId::SQLNULL;
+	if (is_null) {
+		bound_function.return_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL);
 		return make_uniq<VariableReturnBindData>(bound_function.return_type);
 	}
 
diff --git a/tools/pythonpkg/src/native/python_conversion.cpp b/tools/pythonpkg/src/native/python_conversion.cpp
index cbdbb9f57f00..133d3fb768f1 100644
--- a/tools/pythonpkg/src/native/python_conversion.cpp
+++ b/tools/pythonpkg/src/native/python_conversion.cpp
@@ -37,6 +37,20 @@ vector TransformStructKeys(py::handle keys, idx_t size, const LogicalTyp
 	return res;
 }
 
+static bool IsValidMapComponent(const py::handle &component) {
+	// The component is either NULL
+	if (py::none().is(component)) {
+		return true;
+	}
+	if (!py::hasattr(component, "__getitem__")) {
+		return false;
+	}
+	if (!py::hasattr(component, "__len__")) {
+		return false;
+	}
+	return true;
+}
+
 bool DictionaryHasMapFormat(const PyDictionary &dict) {
 	if (dict.len != 2) {
 		return false;
@@ -51,13 +65,19 @@ bool DictionaryHasMapFormat(const PyDictionary &dict) {
 		return false;
 	}
 
-	// Dont check for 'py::list' to allow ducktyping
-	if (!py::hasattr(keys, "__getitem__") || !py::hasattr(keys, "__len__")) {
+	if (!IsValidMapComponent(keys)) {
 		return false;
 	}
-	if (!py::hasattr(values, "__getitem__") || !py::hasattr(values, "__len__")) {
+	if (!IsValidMapComponent(values)) {
 		return false;
 	}
+
+	// If either of the components is NULL, return early
+	if (py::none().is(keys) || py::none().is(values)) {
+		return true;
+	}
+
+	// Verify that both the keys and values are of the same length
 	auto size = py::len(keys);
 	if (size != py::len(values)) {
 		return false;
@@ -91,6 +111,11 @@ Value TransformStructFormatDictionaryToMap(const PyDictionary &dict, const Logic
 	if (target_type.id() != LogicalTypeId::MAP) {
 		throw InvalidInputException("Please provide a valid target type for transform from Python to Value");
 	}
+
+	if (py::none().is(dict.keys) || py::none().is(dict.values)) {
+		return Value(LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL));
+	}
+
 	auto size = py::len(dict.keys);
 	D_ASSERT(size == py::len(dict.values));
 
@@ -130,12 +155,18 @@ Value TransformDictionaryToMap(const PyDictionary &dict, const LogicalType &targ
 	auto keys = dict.values.attr("__getitem__")(0);
 	auto values = dict.values.attr("__getitem__")(1);
 
+	if (py::none().is(keys) || py::none().is(values)) {
+		// Either 'key' or 'value' is None, return early with a NULL value
+		return Value(LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL));
+	}
+
 	auto key_size = py::len(keys);
 	D_ASSERT(key_size == py::len(values));
 	if (key_size == 0) {
 		// dict == { 'key': [], 'value': [] }
 		return EmptyMapValue();
 	}
+	// dict == { 'key': [ ... ], 'value' : [ ... ] }
 	LogicalType key_target = LogicalTypeId::UNKNOWN;
 	LogicalType value_target = LogicalTypeId::UNKNOWN;
diff --git a/tools/pythonpkg/src/pandas/analyzer.cpp b/tools/pythonpkg/src/pandas/analyzer.cpp
index 508270894403..660d1fb2b3d2 100644
--- a/tools/pythonpkg/src/pandas/analyzer.cpp
+++ b/tools/pythonpkg/src/pandas/analyzer.cpp
@@ -331,6 +331,10 @@ LogicalType PandasAnalyzer::DictToMap(const PyDictionary &dict, bool &can_conver
 	auto keys = dict.values.attr("__getitem__")(0);
 	auto values = dict.values.attr("__getitem__")(1);
 
+	if (py::none().is(keys) || py::none().is(values)) {
+		return LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL);
+	}
+
 	auto key_type = GetListType(keys, can_convert);
 	if (!can_convert) {
 		return EmptyMap();
diff --git a/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py b/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py
index 0f20f9fe0309..1a07e47fc8f4 100644
--- a/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py
+++ b/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py
@@ -324,7 +324,7 @@ def test_map_duplicate(self, pandas, duckdb_cursor):
         with pytest.raises(
             duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains duplicates"
         ):
-            converted_col = duckdb_cursor.sql("select * from x").df()
+            duckdb_cursor.sql("select * from x").show()
 
     @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
     def test_map_nullkey(self, pandas, duckdb_cursor):
@@ -337,9 +337,8 @@ def test_map_nullkey(self, pandas, duckdb_cursor):
     @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
     def test_map_nullkeylist(self, pandas, duckdb_cursor):
         x = pandas.DataFrame([[{'key': None, 'value': None}]])
-        # Isn't actually converted to MAP because isinstance(None, list) != True
         converted_col = duckdb_cursor.sql("select * from x").df()
-        duckdb_col = duckdb_cursor.sql("SELECT {key: NULL, value: NULL} as '0'").df()
+        duckdb_col = duckdb_cursor.sql("SELECT MAP(NULL, NULL) as '0'").df()
         pandas.testing.assert_frame_equal(duckdb_col, converted_col)
 
     @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
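The behavioral point of this commit is that a NULL map is now a typed value rather than a bare SQLNULL. Every early-return path above constructs the same value, which can be sketched in isolation as:

    // NULL of type MAP(NULL, NULL) instead of an untyped SQLNULL
    auto null_map = Value(LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL));

This is why the Python test expectation changes from a `{key: NULL, value: NULL}` struct to `MAP(NULL, NULL)`: the dict form `{'key': None, 'value': None}` now round-trips as a typed NULL map.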
From 9f57957c13bbdd798b82d99ce71696d884478cc3 Mon Sep 17 00:00:00 2001
From: Pedro Holanda
Date: Fri, 19 Apr 2024 13:59:05 +0200
Subject: [PATCH 197/201] Removing 4x sleep function

---
 benchmark/micro/csv/time_type.benchmark  |  0
 .../scanner/string_value_scanner.cpp     | 18 ++++++++----------
 .../csv_scanner/string_value_scanner.hpp |  2 +-
 3 files changed, 9 insertions(+), 11 deletions(-)
 create mode 100644 benchmark/micro/csv/time_type.benchmark

diff --git a/benchmark/micro/csv/time_type.benchmark b/benchmark/micro/csv/time_type.benchmark
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
index 473a07a93c4c..49d761ce6ca9 100644
--- a/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
+++ b/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
@@ -97,6 +97,8 @@ StringValueResult::StringValueResult(CSVStates &states, CSVStateMachine &state_m
 		null_str_ptr[i] = state_machine.options.null_str[i].c_str();
 		null_str_size[i] = state_machine.options.null_str[i].size();
 	}
+	date_format = state_machine.options.dialect_options.date_format.at(LogicalTypeId::DATE).GetValue();
+	timestamp_format = state_machine.options.dialect_options.date_format.at(LogicalTypeId::TIMESTAMP).GetValue();
 }
 
 StringValueResult::~StringValueResult() {
@@ -215,11 +217,9 @@ void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size
 		                   false, state_machine.options.decimal_separator[0]);
 		break;
 	case LogicalTypeId::DATE: {
-		if (!state_machine.dialect_options.date_format.find(LogicalTypeId::DATE)->second.GetValue().Empty()) {
-			success =
-			    state_machine.dialect_options.date_format.find(LogicalTypeId::DATE)
-			        ->second.GetValue()
-			        .TryParseDate(value_ptr, size, static_cast<date_t *>(vector_ptr[chunk_col_id])[number_of_rows]);
+		if (!date_format.Empty()) {
+			success = date_format.TryParseDate(value_ptr, size,
+			                                   static_cast<date_t *>(vector_ptr[chunk_col_id])[number_of_rows]);
 		} else {
 			idx_t pos;
 			bool special;
@@ -229,11 +229,9 @@ void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size
 		break;
 	}
 	case LogicalTypeId::TIMESTAMP: {
-		if (!state_machine.dialect_options.date_format.find(LogicalTypeId::TIMESTAMP)->second.GetValue().Empty()) {
-			success = state_machine.dialect_options.date_format.find(LogicalTypeId::TIMESTAMP)
-			              ->second.GetValue()
-			              .TryParseTimestamp(value_ptr, size,
-			                                 static_cast<timestamp_t *>(vector_ptr[chunk_col_id])[number_of_rows]);
+		if (!timestamp_format.Empty()) {
+			success = timestamp_format.TryParseTimestamp(
+			    value_ptr, size, static_cast<timestamp_t *>(vector_ptr[chunk_col_id])[number_of_rows]);
 		} else {
 			success = Timestamp::TryConvertTimestamp(
 			              value_ptr, size, static_cast<timestamp_t *>(vector_ptr[chunk_col_id])[number_of_rows]) ==
diff --git a/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp b/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp
index 9ad42fd16635..a1fa130e4ffd 100644
--- a/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp
+++ b/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp
@@ -145,7 +145,7 @@ class StringValueResult : public ScannerResult {
 	//! Errors happening in the current line (if any)
 	vector current_errors;
-
+	StrpTimeFormat date_format, timestamp_format;
 	bool sniffing;
 
 	//! Specialized code for quoted values, makes sure to remove quotes and escapes
 	static inline void AddQuotedValue(StringValueResult &result, const idx_t buffer_pos);
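The speedup here comes from hoisting the two dialect-option map lookups out of the per-value loop: the formats are resolved once in the StringValueResult constructor and kept as members, so the hot path reduces to roughly the following (a sketch; `date_out` stands in for the output slot in the result vector):

    // per value: one cheap member read instead of two map lookups
    if (!date_format.Empty()) {
        success = date_format.TryParseDate(value_ptr, size, date_out);
    }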
From f9ebe008a3ae0475ef829b98fa9fec5402e62c59 Mon Sep 17 00:00:00 2001
From: Pedro Holanda
Date: Fri, 19 Apr 2024 13:59:42 +0200
Subject: [PATCH 198/201] Add Benchmark

---
 benchmark/micro/csv/time_type.benchmark | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/benchmark/micro/csv/time_type.benchmark b/benchmark/micro/csv/time_type.benchmark
index e69de29bb2d1..98f4a556bb06 100644
--- a/benchmark/micro/csv/time_type.benchmark
+++ b/benchmark/micro/csv/time_type.benchmark
@@ -0,0 +1,14 @@
+# name: benchmark/micro/csv/time_type.benchmark
+# description: Run CSV scan with timestamp and date types
+# group: [csv]
+
+name CSV Read Benchmark with timestamp and date types
+group csv
+
+load
+CREATE TABLE t1 AS select '30/07/1992', '30/07/1992 17:15:30';
+insert into t1 select '30/07/1992', '30/07/1992 17:15:30' from range(0,10000000) tbl(i);
+COPY t1 TO '${BENCHMARK_DIR}/time_timestamp.csv' (FORMAT CSV, HEADER 0);
+
+run
+SELECT * from read_csv('${BENCHMARK_DIR}/time_timestamp.csv',delim= ',', header = 0)

From 788e25787f2083ae855266320b6acc9a70e2f806 Mon Sep 17 00:00:00 2001
From: Pedro Holanda
Date: Fri, 19 Apr 2024 14:02:38 +0200
Subject: [PATCH 199/201] Add it to csv benchmarks

---
 .github/regression/csv.csv            | 3 ++-
 .github/regression/micro_extended.csv | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/.github/regression/csv.csv b/.github/regression/csv.csv
index 272b7ced2a04..ddfe10bad1e8 100644
--- a/.github/regression/csv.csv
+++ b/.github/regression/csv.csv
@@ -4,4 +4,5 @@ benchmark/micro/csv/small_csv.benchmark
 benchmark/micro/csv/null_padding.benchmark
 benchmark/micro/csv/projection_pushdown.benchmark
 benchmark/micro/csv/1_byte_values.benchmark
-benchmark/micro/csv/16_byte_values.benchmark
\ No newline at end of file
+benchmark/micro/csv/16_byte_values.benchmark
+benchmark/micro/csv/time_type.benchmark
\ No newline at end of file
diff --git a/.github/regression/micro_extended.csv b/.github/regression/micro_extended.csv
index 6973785b4c98..365347ad3ad3 100644
--- a/.github/regression/micro_extended.csv
+++ b/.github/regression/micro_extended.csv
@@ -86,6 +86,7 @@ benchmark/micro/csv/read.benchmark
 benchmark/micro/csv/small_csv.benchmark
 benchmark/micro/csv/sniffer.benchmark
 benchmark/micro/csv/sniffer_quotes.benchmark
+benchmark/micro/csv/time_type.benchmark
 benchmark/micro/cte/cte.benchmark
 benchmark/micro/cte/materialized_cte.benchmark
 benchmark/micro/date/extract_month.benchmark

From d54e152722a577ec8406f843318601d091d814d0 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 19 Apr 2024 16:12:30 +0200
Subject: [PATCH 200/201] fix compilation

---
 src/core_functions/scalar/map/map.cpp        | 22 ++++++++++++-----
 .../types/nested/map/test_map_subscript.test |  3 +++
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/src/core_functions/scalar/map/map.cpp b/src/core_functions/scalar/map/map.cpp
index 664946267cee..ab67475d151b 100644
--- a/src/core_functions/scalar/map/map.cpp
+++ b/src/core_functions/scalar/map/map.cpp
@@ -21,11 +21,21 @@ static void MapFunctionEmptyInput(Vector &result, const idx_t row_count) {
 	result.Verify(row_count);
 }
 
-static bool MapIsNull(const LogicalType &map) {
-	D_ASSERT(map.id() == LogicalTypeId::MAP);
-	auto &key = MapType::KeyType(map);
-	auto &value = MapType::ValueType(map);
-	return (key.id() == LogicalTypeId::SQLNULL && value.id() == LogicalTypeId::SQLNULL);
+static bool MapIsNull(DataChunk &chunk) {
+	if (chunk.data.empty()) {
+		return false;
+	}
+	D_ASSERT(chunk.data.size() == 2);
+	auto &keys = chunk.data[0];
+	auto &values = chunk.data[1];
+
+	if (keys.GetType().id() == LogicalTypeId::SQLNULL) {
+		return true;
+	}
+	if (values.GetType().id() == LogicalTypeId::SQLNULL) {
+		return true;
+	}
+	return false;
 }
 
 static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) {
@@ -36,7 +46,7 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) {
 	// - key names are unique
 	D_ASSERT(result.GetType().id() == LogicalTypeId::MAP);
 
-	if (MapIsNull(result.GetType())) {
+	if (MapIsNull(args)) {
 		auto &validity = FlatVector::Validity(result);
 		validity.SetInvalid(0);
 		result.SetVectorType(VectorType::CONSTANT_VECTOR);
 		return;
diff --git a/test/sql/types/nested/map/test_map_subscript.test b/test/sql/types/nested/map/test_map_subscript.test
index f75482857dad..8ad48d29e48b 100644
--- a/test/sql/types/nested/map/test_map_subscript.test
+++ b/test/sql/types/nested/map/test_map_subscript.test
@@ -2,6 +2,9 @@
 # description: Test cardinality function for maps
 # group: [map]
 
+statement ok
+pragma enable_verification
+
 # Single element on map
 query I
 select m[1] from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as m) as T
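The fix works because NULL-ness is now decided from the argument chunk instead of the bound result type, which the newly enabled verification pragma can re-resolve between bind and execution. Condensed, the new check amounts to:

    // derive NULL-ness from the input vectors, not from result.GetType()
    bool is_null = !args.data.empty() && (args.data[0].GetType().id() == LogicalTypeId::SQLNULL ||
                                          args.data[1].GetType().id() == LogicalTypeId::SQLNULL);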
From 58354a26bebd4515aab5615d7275c4e31534b84d Mon Sep 17 00:00:00 2001
From: stephaniewang
Date: Fri, 19 Apr 2024 10:59:20 -0400
Subject: [PATCH 201/201] add patch

---
 .../extensions/aws/0001-update-tests.patch | 53 +++++++++++++++++++
 1 file changed, 53 insertions(+)
 create mode 100644 .github/patches/extensions/aws/0001-update-tests.patch

diff --git a/.github/patches/extensions/aws/0001-update-tests.patch b/.github/patches/extensions/aws/0001-update-tests.patch
new file mode 100644
index 000000000000..8e668b4e0615
--- /dev/null
+++ b/.github/patches/extensions/aws/0001-update-tests.patch
@@ -0,0 +1,53 @@
+From e8e6c286376d97e0a695284fc32b3b67a77e35af Mon Sep 17 00:00:00 2001
+From: stephaniewang
+Date: Thu, 18 Apr 2024 21:41:29 -0400
+Subject: [PATCH] update tests
+
+---
+ test/sql/aws_minio_secret.test | 2 +-
+ test/sql/aws_secret_gcs.test   | 2 +-
+ test/sql/aws_secret_r2.test    | 2 +-
+ 3 files changed, 3 insertions(+), 3 deletions(-)
+
+diff --git a/test/sql/aws_minio_secret.test b/test/sql/aws_minio_secret.test
+index 2ddc29f..34c4c92 100644
+--- a/test/sql/aws_minio_secret.test
++++ b/test/sql/aws_minio_secret.test
+@@ -28,7 +28,7 @@ CREATE SECRET my_aws_secret (
+ );
+ 
+ query I
+-SELECT which_secret('s3://test-bucket/aws_minio_secret/secret1/test.csv', 's3')
++SELECT name FROM which_secret('s3://test-bucket/aws_minio_secret/secret1/test.csv', 's3')
+ ----
+ my_aws_secret
+ 
+diff --git a/test/sql/aws_secret_gcs.test b/test/sql/aws_secret_gcs.test
+index 0b1fd40..cbed048 100644
+--- a/test/sql/aws_secret_gcs.test
++++ b/test/sql/aws_secret_gcs.test
+@@ -18,7 +18,7 @@ CREATE SECRET s1 (
+ );
+ 
+ query I
+-SELECT which_secret('gcs://haha/hoehoe.parkoe', 'gcs')
++SELECT name FROM which_secret('gcs://haha/hoehoe.parkoe', 'gcs')
+ ----
+ s1
+ 
+diff --git a/test/sql/aws_secret_r2.test b/test/sql/aws_secret_r2.test
+index 01be38b..19ebd1e 100644
+--- a/test/sql/aws_secret_r2.test
++++ b/test/sql/aws_secret_r2.test
+@@ -19,7 +19,7 @@ CREATE SECRET s1 (
+ );
+ 
+ query I
+-SELECT which_secret('r2://haha/hoehoe.parkoe', 'r2')
++SELECT name FROM which_secret('r2://haha/hoehoe.parkoe', 'r2')
+ ----
+ s1
+ 
+-- 
+2.39.2 (Apple Git-143)