From 898e784578ae232b1eec7ee630a673c13794a163 Mon Sep 17 00:00:00 2001 From: kostas Date: Tue, 7 Jan 2025 17:23:35 +0200 Subject: [PATCH] chore: split geo and zset families Signed-off-by: kostas --- src/server/CMakeLists.txt | 2 +- src/server/geo_family.cc | 792 ++++++++++++++++++++++ src/server/geo_family.h | 34 + src/server/main_service.cc | 2 + src/server/zset_family.cc | 1149 ++++++-------------------------- src/server/zset_family.h | 50 +- src/server/zset_family_test.cc | 1 + 7 files changed, 1074 insertions(+), 956 deletions(-) create mode 100644 src/server/geo_family.cc create mode 100644 src/server/geo_family.h diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index e51f7033df75..0187b684bb23 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -55,7 +55,7 @@ add_library(dragonfly_lib bloom_family.cc detail/save_stages_controller.cc detail/snapshot_storage.cc set_family.cc stream_family.cc string_family.cc - zset_family.cc version.cc bitops_family.cc container_utils.cc + zset_family.cc geo_family.cc version.cc bitops_family.cc container_utils.cc top_keys.cc multi_command_squasher.cc hll_family.cc ${DF_SEARCH_SRCS} ${DF_LINUX_SRCS} diff --git a/src/server/geo_family.cc b/src/server/geo_family.cc new file mode 100644 index 000000000000..62c41bd3513a --- /dev/null +++ b/src/server/geo_family.cc @@ -0,0 +1,792 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/geo_family.h" + +#include "server/acl/acl_commands_def.h" +#include "server/zset_family.h" + +extern "C" { +#include "redis/geo.h" +#include "redis/geohash.h" +#include "redis/geohash_helper.h" +#include "redis/redis_aux.h" +#include "redis/util.h" +#include "redis/zmalloc.h" +#include "redis/zset.h" +} + +#include "base/logging.h" +#include "facade/error.h" +#include "server/command_registry.h" +#include "server/conn_context.h" +#include "server/engine_shard_set.h" +#include "server/error.h" +#include "server/family_utils.h" +#include "server/transaction.h" + +namespace dfly { + +using namespace std; +using namespace facade; +using absl::SimpleAtoi; +namespace { + +using CI = CommandId; + +static const char kNxXxErr[] = "XX and NX options at the same time are not compatible"; +static const char kFromMemberLonglatErr[] = + "FROMMEMBER and FROMLONLAT options at the same time are not compatible"; +static const char kByRadiusBoxErr[] = + "BYRADIUS and BYBOX options at the same time are not compatible"; +static const char kAscDescErr[] = "ASC and DESC options at the same time are not compatible"; +static const char kStoreTypeErr[] = + "STORE and STOREDIST options at the same time are not compatible"; +static const char kStoreCompatErr[] = + "STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options"; +static const char kMemberNotFound[] = "could not decode requested zset member"; +constexpr string_view kGeoAlphabet = "0123456789bcdefghjkmnpqrstuvwxyz"sv; + +using MScoreResponse = std::vector>; + +using ScoredMember = std::pair; +using ScoredArray = std::vector; +using ScoredMemberView = std::pair; +using ScoredMemberSpan = absl::Span; + +struct GeoPoint { + double longitude; + double latitude; + double dist; + double score; + std::string member; + GeoPoint() : longitude(0.0), latitude(0.0), dist(0.0), score(0.0){}; + GeoPoint(double _longitude, double _latitude, double _dist, double _score, + const std::string& _member) + : longitude(_longitude), latitude(_latitude), dist(_dist), score(_score), member(_member){}; +}; +using GeoArray = std::vector; + +enum class Sorting { kUnsorted, kAsc, kDesc }; +enum class GeoStoreType { kNoStore, kStoreHash, kStoreDist }; +struct GeoSearchOpts { + double conversion = 0; + uint64_t count = 0; + Sorting sorting = Sorting::kUnsorted; + bool any = 0; + bool withdist = 0; + bool withcoord = 0; + bool withhash = 0; + GeoStoreType store = GeoStoreType::kNoStore; + string_view store_key; + + bool HasWithStatement() const { + return withdist || withcoord || withhash; + } +}; + +bool ParseLongLat(string_view lon, string_view lat, std::pair* res) { + if (!ParseDouble(lon, &res->first)) + return false; + + if (!ParseDouble(lat, &res->second)) + return false; + + if (res->first < GEO_LONG_MIN || res->first > GEO_LONG_MAX || res->second < GEO_LAT_MIN || + res->second > GEO_LAT_MAX) { + return false; + } + return true; +} + +bool ScoreToLongLat(const std::optional& val, double* xy) { + if (!val.has_value()) + return false; + + double score = *val; + + GeoHashBits hash = {.bits = (uint64_t)score, .step = GEO_STEP_MAX}; + + return geohashDecodeToLongLatType(hash, xy) == 1; +} + +bool ToAsciiGeoHash(const std::optional& val, array* buf) { + if (!val.has_value()) + return false; + + double score = *val; + + GeoHashBits hash = {.bits = (uint64_t)score, .step = GEO_STEP_MAX}; + + double xy[2]; + if (!geohashDecodeToLongLatType(hash, xy)) { + return false; + } + + /* Re-encode */ + GeoHashRange r[2]; + r[0].min = -180; + r[0].max = 180; + r[1].min = -90; + r[1].max = 90; + + geohashEncode(&r[0], &r[1], xy[0], xy[1], 26, &hash); + + for (int i = 0; i < 11; i++) { + int idx; + if (i == 10) { + /* We have just 52 bits, but the API used to output + * an 11 bytes geohash. For compatibility we assume + * zero. */ + idx = 0; + } else { + idx = (hash.bits >> (52 - ((i + 1) * 5))) % kGeoAlphabet.size(); + } + (*buf)[i] = kGeoAlphabet[idx]; + } + (*buf)[11] = '\0'; + + return true; +} + +double ExtractUnit(std::string_view arg) { + const string unit = absl::AsciiStrToUpper(arg); + if (unit == "M") { + return 1; + } else if (unit == "KM") { + return 1000; + } else if (unit == "FT") { + return 0.3048; + } else if (unit == "MI") { + return 1609.34; + } else { + return -1; + } +} + +} // namespace + +void GeoFamily::GeoAdd(CmdArgList args, const CommandContext& cmd_cntx) { + string_view key = ArgS(args, 0); + + ZSetFamily::ZParams zparams; + size_t i = 1; + for (; i < args.size(); ++i) { + string cur_arg = absl::AsciiStrToUpper(ArgS(args, i)); + + if (cur_arg == "XX") { + zparams.flags |= ZADD_IN_XX; // update only + } else if (cur_arg == "NX") { + zparams.flags |= ZADD_IN_NX; // add new only. + } else if (cur_arg == "CH") { + zparams.ch = true; + } else { + break; + } + } + + auto* builder = cmd_cntx.rb; + args.remove_prefix(i); + if (args.empty() || args.size() % 3 != 0) { + builder->SendError(kSyntaxErr); + return; + } + + if ((zparams.flags & ZADD_IN_NX) && (zparams.flags & ZADD_IN_XX)) { + builder->SendError(kNxXxErr); + return; + } + + absl::InlinedVector members; + for (i = 0; i < args.size(); i += 3) { + string_view longitude = ArgS(args, i); + string_view latitude = ArgS(args, i + 1); + string_view member = ArgS(args, i + 2); + + pair longlat; + + if (!ParseLongLat(longitude, latitude, &longlat)) { + string err = absl::StrCat("-ERR invalid longitude,latitude pair ", longitude, ",", latitude, + ",", member); + + return builder->SendError(err, kSyntaxErrType); + } + + /* Turn the coordinates into the score of the element. */ + GeoHashBits hash; + geohashEncodeWGS84(longlat.first, longlat.second, GEO_STEP_MAX, &hash); + GeoHashFix52Bits bits = geohashAlign52Bits(hash); + + members.emplace_back(bits, member); + } + DCHECK(cmd_cntx.tx); + + absl::Span memb_sp{members.data(), members.size()}; + ZSetFamily::ZAddGeneric(key, zparams, memb_sp, cmd_cntx.tx, builder); +} + +void GeoFamily::GeoHash(CmdArgList args, const CommandContext& cmd_cntx) { + auto* rb = static_cast(cmd_cntx.rb); + + OpResult result = ZSetFamily::ZGetMembers(args, cmd_cntx.tx, rb); + + if (result.status() == OpStatus::WRONG_TYPE) { + return rb->SendError(kWrongTypeErr); + } + + rb->StartArray(result->size()); // Array return type. + const MScoreResponse& arr = result.value(); + + array buf; + for (const auto& p : arr) { + if (ToAsciiGeoHash(p, &buf)) { + rb->SendBulkString(string_view{buf.data(), buf.size() - 1}); + } else { + rb->SendNull(); + } + } +} + +void GeoFamily::GeoPos(CmdArgList args, const CommandContext& cmd_cntx) { + auto* rb = static_cast(cmd_cntx.rb); + + OpResult result = ZSetFamily::ZGetMembers(args, cmd_cntx.tx, rb); + + if (result.status() != OpStatus::OK) { + return rb->SendError(result.status()); + } + + rb->StartArray(result->size()); // Array return type. + const MScoreResponse& arr = result.value(); + + double xy[2]; + for (const auto& p : arr) { + if (ScoreToLongLat(p, xy)) { + rb->StartArray(2); + rb->SendDouble(xy[0]); + rb->SendDouble(xy[1]); + } else { + rb->SendNull(); + } + } +} + +void GeoFamily::GeoDist(CmdArgList args, const CommandContext& cmd_cntx) { + double distance_multiplier = 1; + auto* rb = static_cast(cmd_cntx.rb); + + if (args.size() == 4) { + string_view unit = ArgS(args, 3); + distance_multiplier = ExtractUnit(unit); + args.remove_suffix(1); + if (distance_multiplier < 0) { + return rb->SendError("unsupported unit provided. please use M, KM, FT, MI"); + } + } else if (args.size() != 3) { + return rb->SendError(kSyntaxErr); + } + + OpResult result = ZSetFamily::ZGetMembers(args, cmd_cntx.tx, rb); + + if (result.status() != OpStatus::OK) { + return rb->SendError(result.status()); + } + + const MScoreResponse& arr = result.value(); + + if (arr.size() != 2) { + return rb->SendError(kSyntaxErr); + } + + double xyxy[4]; // 2 pairs of score holding 2 locations + for (size_t i = 0; i < arr.size(); i++) { + if (!ScoreToLongLat(arr[i], xyxy + (i * 2))) { + return rb->SendNull(); + } + } + + return rb->SendDouble(geohashGetDistance(xyxy[0], xyxy[1], xyxy[2], xyxy[3]) / + distance_multiplier); +} + +namespace { +std::vector GetGeoRangeSpec(const GeoHashRadius& n) { + array neighbors; + unsigned int last_processed = 0; + + neighbors[0] = n.hash; + neighbors[1] = n.neighbors.north; + neighbors[2] = n.neighbors.south; + neighbors[3] = n.neighbors.east; + neighbors[4] = n.neighbors.west; + neighbors[5] = n.neighbors.north_east; + neighbors[6] = n.neighbors.north_west; + neighbors[7] = n.neighbors.south_east; + neighbors[8] = n.neighbors.south_west; + + // Get range_specs for neighbors (*and* our own hashbox) + std::vector range_specs; + for (unsigned int i = 0; i < neighbors.size(); i++) { + if (HASHISZERO(neighbors[i])) { + continue; + } + + // When a huge Radius (in the 5000 km range or more) is used, + // adjacent neighbors can be the same, leading to duplicated + // elements. Skip every range which is the same as the one + // processed previously. + if (last_processed && neighbors[i].bits == neighbors[last_processed].bits && + neighbors[i].step == neighbors[last_processed].step) { + continue; + } + + GeoHashFix52Bits min, max; + scoresOfGeoHashBox(neighbors[i], &min, &max); + + ZSetFamily::ScoreInterval si; + si.first = ZSetFamily::Bound{static_cast(min), false}; + si.second = ZSetFamily::Bound{static_cast(max), true}; + + ZSetFamily::RangeParams range_params; + range_params.interval_type = ZSetFamily::RangeParams::IntervalType::SCORE; + range_params.with_scores = true; + range_specs.emplace_back(si, range_params); + + last_processed = i; + } + return range_specs; +} + +void SortIfNeeded(GeoArray* ga, Sorting sorting, uint64_t count) { + if (sorting == Sorting::kUnsorted) + return; + + auto comparator = [&](const GeoPoint& a, const GeoPoint& b) { + if (sorting == Sorting::kAsc) { + return a.dist < b.dist; + } else { + DCHECK(sorting == Sorting::kDesc); + return a.dist > b.dist; + } + }; + + if (count > 0) { + std::partial_sort(ga->begin(), ga->begin() + count, ga->end(), comparator); + ga->resize(count); + } else { + std::sort(ga->begin(), ga->end(), comparator); + } +} + +void GeoSearchStoreGeneric(Transaction* tx, facade::SinkReplyBuilder* builder, + const GeoShape& shape_ref, string_view key, string_view member, + const GeoSearchOpts& geo_ops) { + GeoShape* shape = &(const_cast(shape_ref)); + auto* rb = static_cast(builder); + + ShardId from_shard = Shard(key, shard_set->size()); + + if (!member.empty()) { + // get shape.xy from member + OpResult member_score; + auto cb = [&](Transaction* t, EngineShard* shard) { + if (shard->shard_id() == from_shard) { + member_score = ZSetFamily::OpScore(t->GetOpArgs(shard), key, member); + } + return OpStatus::OK; + }; + tx->Execute(std::move(cb), false); + auto member_sts = member_score.status(); + if (member_sts != OpStatus::OK) { + tx->Conclude(); + switch (member_sts) { + case OpStatus::WRONG_TYPE: + return builder->SendError(kWrongTypeErr); + case OpStatus::KEY_NOTFOUND: + return rb->StartArray(0); + case OpStatus::MEMBER_NOTFOUND: + return builder->SendError(kMemberNotFound); + default: + return builder->SendError(member_sts); + } + } + ScoreToLongLat(*member_score, shape->xy); + } else { + // verify key is valid + OpResult result; + auto cb = [&](Transaction* t, EngineShard* shard) { + if (shard->shard_id() == from_shard) { + result = ZSetFamily::OpKeyExisted(t->GetOpArgs(shard), key); + } + return OpStatus::OK; + }; + tx->Execute(std::move(cb), false); + auto result_sts = result.status(); + if (result_sts != OpStatus::OK) { + tx->Conclude(); + switch (result_sts) { + case OpStatus::WRONG_TYPE: + return builder->SendError(kWrongTypeErr); + case OpStatus::KEY_NOTFOUND: + return rb->StartArray(0); + default: + return builder->SendError(result_sts); + } + } + } + DCHECK(shape->xy[0] >= -180.0 && shape->xy[0] <= 180.0); + DCHECK(shape->xy[1] >= -90.0 && shape->xy[1] <= 90.0); + + // query + GeoHashRadius georadius = geohashCalculateAreasByShapeWGS84(shape); + GeoArray ga; + auto range_specs = GetGeoRangeSpec(georadius); + // get all the matching members and add them to the potential result list + vector>> result_arrays; + auto cb = [&](Transaction* t, EngineShard* shard) { + auto res_it = ZSetFamily::OpRanges(range_specs, t->GetOpArgs(shard), key); + if (res_it) { + result_arrays.emplace_back(res_it); + } + return OpStatus::OK; + }; + + tx->Execute(std::move(cb), geo_ops.store == GeoStoreType::kNoStore); + + // filter potential result list + double xy[2]; + double distance; + unsigned long limit = geo_ops.any ? geo_ops.count : 0; + for (auto& result_array : result_arrays) { + for (auto& arr : *result_array) { + for (auto& p : arr) { + if (geoWithinShape(shape, p.second, xy, &distance) == 0) { + ga.emplace_back(xy[0], xy[1], distance, p.second, p.first); + if (limit > 0 && ga.size() >= limit) + break; + } + } + } + } + + // sort and trim by count + SortIfNeeded(&ga, geo_ops.sorting, geo_ops.count); + + if (geo_ops.store == GeoStoreType::kNoStore) { + // case 1: read mode + // case 2: write mode, kNoStore + // generate reply array withdist, withcoords, withhash + int record_size = 1; + if (geo_ops.withdist) { + record_size++; + } + if (geo_ops.withhash) { + record_size++; + } + if (geo_ops.withcoord) { + record_size++; + } + rb->StartArray(ga.size()); + for (const auto& p : ga) { + // [member, dist, x, y, hash] + if (geo_ops.HasWithStatement()) { + rb->StartArray(record_size); + } + rb->SendBulkString(p.member); + if (geo_ops.withdist) { + rb->SendDouble(p.dist / geo_ops.conversion); + } + if (geo_ops.withhash) { + rb->SendDouble(p.score); + } + if (geo_ops.withcoord) { + rb->StartArray(2); + rb->SendDouble(p.longitude); + rb->SendDouble(p.latitude); + } + } + } else { + // case 3: write mode, !kNoStore + DCHECK(geo_ops.store == GeoStoreType::kStoreDist || geo_ops.store == GeoStoreType::kStoreHash); + ShardId dest_shard = Shard(geo_ops.store_key, shard_set->size()); + DVLOG(1) << "store shard:" << dest_shard << ", key " << geo_ops.store_key; + ZSetFamily::AddResult add_result; + vector smvec; + for (const auto& p : ga) { + if (geo_ops.store == GeoStoreType::kStoreDist) { + smvec.emplace_back(p.dist / geo_ops.conversion, p.member); + } else { + DCHECK(geo_ops.store == GeoStoreType::kStoreHash); + smvec.emplace_back(p.score, p.member); + } + } + + auto store_cb = [&](Transaction* t, EngineShard* shard) { + if (shard->shard_id() == dest_shard) { + ZSetFamily::ZParams zparams; + zparams.override = true; + add_result = ZSetFamily::OpAdd(t->GetOpArgs(shard), zparams, geo_ops.store_key, + ScoredMemberSpan{smvec}) + .value(); + } + return OpStatus::OK; + }; + tx->Execute(std::move(store_cb), true); + + rb->SendLong(smvec.size()); + } +} +} // namespace + +void GeoFamily::GeoSearch(CmdArgList args, const CommandContext& cmd_cntx) { + // parse arguments + string_view key = ArgS(args, 0); + GeoShape shape = {}; + GeoSearchOpts geo_ops; + string_view member; + + // FROMMEMBER or FROMLONLAT is set + bool from_set = false; + // BYRADIUS or BYBOX is set + bool by_set = false; + auto* builder = cmd_cntx.rb; + + for (size_t i = 1; i < args.size(); ++i) { + string cur_arg = absl::AsciiStrToUpper(ArgS(args, i)); + + if (cur_arg == "FROMMEMBER") { + if (from_set) { + return builder->SendError(kFromMemberLonglatErr); + } else if (i + 1 < args.size()) { + member = ArgS(args, i + 1); + from_set = true; + i++; + } else { + return builder->SendError(kSyntaxErr); + } + } else if (cur_arg == "FROMLONLAT") { + if (from_set) { + return builder->SendError(kFromMemberLonglatErr); + } else if (i + 2 < args.size()) { + string_view longitude_str = ArgS(args, i + 1); + string_view latitude_str = ArgS(args, i + 2); + pair longlat; + if (!ParseLongLat(longitude_str, latitude_str, &longlat)) { + string err = absl::StrCat("-ERR invalid longitude,latitude pair ", longitude_str, ",", + latitude_str); + return builder->SendError(err, kSyntaxErrType); + } + shape.xy[0] = longlat.first; + shape.xy[1] = longlat.second; + from_set = true; + i += 2; + } else { + return builder->SendError(kSyntaxErr); + } + } else if (cur_arg == "BYRADIUS") { + if (by_set) { + return builder->SendError(kByRadiusBoxErr); + } else if (i + 2 < args.size()) { + if (!ParseDouble(ArgS(args, i + 1), &shape.t.radius)) { + return builder->SendError(kInvalidFloatErr); + } + string_view unit = ArgS(args, i + 2); + shape.conversion = ExtractUnit(unit); + geo_ops.conversion = shape.conversion; + if (shape.conversion == -1) { + return builder->SendError("unsupported unit provided. please use M, KM, FT, MI"); + } + shape.type = CIRCULAR_TYPE; + by_set = true; + i += 2; + } else { + return builder->SendError(kSyntaxErr); + } + } else if (cur_arg == "BYBOX") { + if (by_set) { + return builder->SendError(kByRadiusBoxErr); + } else if (i + 3 < args.size()) { + if (!ParseDouble(ArgS(args, i + 1), &shape.t.r.width)) { + return builder->SendError(kInvalidFloatErr); + } + if (!ParseDouble(ArgS(args, i + 2), &shape.t.r.height)) { + return builder->SendError(kInvalidFloatErr); + } + string_view unit = ArgS(args, i + 3); + shape.conversion = ExtractUnit(unit); + geo_ops.conversion = shape.conversion; + if (shape.conversion == -1) { + return builder->SendError("unsupported unit provided. please use M, KM, FT, MI"); + } + shape.type = RECTANGLE_TYPE; + by_set = true; + i += 3; + } else { + return builder->SendError(kSyntaxErr); + } + } else if (cur_arg == "ASC") { + if (geo_ops.sorting != Sorting::kUnsorted) { + return builder->SendError(kAscDescErr); + } else { + geo_ops.sorting = Sorting::kAsc; + } + } else if (cur_arg == "DESC") { + if (geo_ops.sorting != Sorting::kUnsorted) { + return builder->SendError(kAscDescErr); + } else { + geo_ops.sorting = Sorting::kDesc; + } + } else if (cur_arg == "COUNT") { + if (i + 1 < args.size() && absl::SimpleAtoi(ArgS(args, i + 1), &geo_ops.count)) { + i++; + } else { + return builder->SendError(kSyntaxErr); + } + if (i + 1 < args.size() && ArgS(args, i + 1) == "ANY") { + geo_ops.any = true; + i++; + } + } else if (cur_arg == "WITHCOORD") { + geo_ops.withcoord = true; + } else if (cur_arg == "WITHDIST") { + geo_ops.withdist = true; + } else if (cur_arg == "WITHHASH") { + geo_ops.withhash = true; + } else { + return builder->SendError(kSyntaxErr); + } + } + + // check mandatory options + if (!from_set) { + return builder->SendError(kSyntaxErr); + } + if (!by_set) { + return builder->SendError(kSyntaxErr); + } + // parsing completed + + GeoSearchStoreGeneric(cmd_cntx.tx, builder, shape, key, member, geo_ops); +} + +void GeoFamily::GeoRadiusByMember(CmdArgList args, const CommandContext& cmd_cntx) { + GeoShape shape = {}; + GeoSearchOpts geo_ops; + // parse arguments + string_view key = ArgS(args, 0); + // member to latlong, set shape.xy + string_view member = ArgS(args, 1); + + auto* builder = cmd_cntx.rb; + if (!ParseDouble(ArgS(args, 2), &shape.t.radius)) { + return builder->SendError(kInvalidFloatErr); + } + string_view unit = ArgS(args, 3); + shape.conversion = ExtractUnit(unit); + geo_ops.conversion = shape.conversion; + if (shape.conversion == -1) { + return builder->SendError("unsupported unit provided. please use M, KM, FT, MI"); + } + shape.type = CIRCULAR_TYPE; + + for (size_t i = 4; i < args.size(); ++i) { + string cur_arg = absl::AsciiStrToUpper(ArgS(args, i)); + + if (cur_arg == "ASC") { + if (geo_ops.sorting != Sorting::kUnsorted) { + return builder->SendError(kAscDescErr); + } else { + geo_ops.sorting = Sorting::kAsc; + } + } else if (cur_arg == "DESC") { + if (geo_ops.sorting != Sorting::kUnsorted) { + return builder->SendError(kAscDescErr); + } else { + geo_ops.sorting = Sorting::kDesc; + } + } else if (cur_arg == "COUNT") { + if (i + 1 < args.size() && absl::SimpleAtoi(ArgS(args, i + 1), &geo_ops.count)) { + i++; + } else { + return builder->SendError(kSyntaxErr); + } + if (i + 1 < args.size() && ArgS(args, i + 1) == "ANY") { + geo_ops.any = true; + i++; + } + } else if (cur_arg == "WITHCOORD") { + if (geo_ops.store != GeoStoreType::kNoStore) { + return builder->SendError(kStoreCompatErr); + } + geo_ops.withcoord = true; + } else if (cur_arg == "WITHDIST") { + if (geo_ops.store != GeoStoreType::kNoStore) { + return builder->SendError(kStoreCompatErr); + } + geo_ops.withdist = true; + } else if (cur_arg == "WITHHASH") { + if (geo_ops.store != GeoStoreType::kNoStore) { + return builder->SendError(kStoreCompatErr); + } + geo_ops.withhash = true; + } else if (cur_arg == "STORE") { + if (geo_ops.store != GeoStoreType::kNoStore) { + return builder->SendError(kStoreTypeErr); + } else if (geo_ops.withcoord || geo_ops.withdist || geo_ops.withhash) { + return builder->SendError(kStoreCompatErr); + } + if (i + 1 < args.size()) { + geo_ops.store_key = ArgS(args, i + 1); + geo_ops.store = GeoStoreType::kStoreHash; + i++; + } else { + return builder->SendError(kSyntaxErr); + } + } else if (cur_arg == "STOREDIST") { + if (geo_ops.store != GeoStoreType::kNoStore) { + return builder->SendError(kStoreTypeErr); + } else if (geo_ops.withcoord || geo_ops.withdist || geo_ops.withhash) { + return builder->SendError(kStoreCompatErr); + } + if (i + 1 < args.size()) { + geo_ops.store_key = ArgS(args, i + 1); + geo_ops.store = GeoStoreType::kStoreDist; + i++; + } else { + return builder->SendError(kSyntaxErr); + } + } else { + return builder->SendError(kSyntaxErr); + } + } + // parsing completed + + GeoSearchStoreGeneric(cmd_cntx.tx, builder, shape, key, member, geo_ops); +} + +#define HFUNC(x) SetHandler(&GeoFamily::x) + +namespace acl { +constexpr uint32_t kGeoAdd = WRITE | GEO | SLOW; +constexpr uint32_t kGeoHash = READ | GEO | SLOW; +constexpr uint32_t kGeoPos = READ | GEO | SLOW; +constexpr uint32_t kGeoDist = READ | GEO | SLOW; +constexpr uint32_t kGeoSearch = READ | GEO | SLOW; +constexpr uint32_t kGeoRadiusByMember = WRITE | GEO | SLOW; +} // namespace acl + +void GeoFamily::Register(CommandRegistry* registry) { + registry->StartFamily(); + *registry << CI{"GEOADD", CO::FAST | CO::WRITE | CO::DENYOOM, -5, 1, 1, acl::kGeoAdd}.HFUNC( + GeoAdd) + << CI{"GEOHASH", CO::FAST | CO::READONLY, -2, 1, 1, acl::kGeoHash}.HFUNC(GeoHash) + << CI{"GEOPOS", CO::FAST | CO::READONLY, -2, 1, 1, acl::kGeoPos}.HFUNC(GeoPos) + << CI{"GEODIST", CO::READONLY, -4, 1, 1, acl::kGeoDist}.HFUNC(GeoDist) + << CI{"GEOSEARCH", CO::READONLY, -4, 1, 1, acl::kGeoSearch}.HFUNC(GeoSearch) + << CI{"GEORADIUSBYMEMBER", CO::WRITE | CO::STORE_LAST_KEY, -4, 1, 1, + acl::kGeoRadiusByMember} + .HFUNC(GeoRadiusByMember); +} + +} // namespace dfly diff --git a/src/server/geo_family.h b/src/server/geo_family.h new file mode 100644 index 000000000000..e07027b418c1 --- /dev/null +++ b/src/server/geo_family.h @@ -0,0 +1,34 @@ +// Copyright 2025, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include + +#include "server/common.h" + +namespace facade { +class SinkReplyBuilder; +} // namespace facade + +namespace dfly { + +class CommandRegistry; +struct CommandContext; + +class GeoFamily { + public: + static void Register(CommandRegistry* registry); + using SinkReplyBuilder = facade::SinkReplyBuilder; + + private: + static void GeoAdd(CmdArgList args, const CommandContext& cmd_cntx); + static void GeoHash(CmdArgList args, const CommandContext& cmd_cntx); + static void GeoPos(CmdArgList args, const CommandContext& cmd_cntx); + static void GeoDist(CmdArgList args, const CommandContext& cmd_cntx); + static void GeoSearch(CmdArgList args, const CommandContext& cmd_cntx); + static void GeoRadiusByMember(CmdArgList args, const CommandContext& cmd_cntx); +}; + +} // namespace dfly diff --git a/src/server/main_service.cc b/src/server/main_service.cc index a33a86739af9..456e13c5fdca 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -43,6 +43,7 @@ extern "C" { #include "server/conn_context.h" #include "server/error.h" #include "server/generic_family.h" +#include "server/geo_family.h" #include "server/hll_family.h" #include "server/hset_family.h" #include "server/http_api.h" @@ -2672,6 +2673,7 @@ void Service::RegisterCommands() { SetFamily::Register(®istry_); HSetFamily::Register(®istry_); ZSetFamily::Register(®istry_); + GeoFamily::Register(®istry_); JsonFamily::Register(®istry_); BitOpsFamily::Register(®istry_); HllFamily::Register(®istry_); diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index bf08072884ab..80275ba4333a 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -41,56 +41,15 @@ namespace { using CI = CommandId; static const char kNxXxErr[] = "XX and NX options at the same time are not compatible"; -static const char kFromMemberLonglatErr[] = - "FROMMEMBER and FROMLONLAT options at the same time are not compatible"; -static const char kByRadiusBoxErr[] = - "BYRADIUS and BYBOX options at the same time are not compatible"; -static const char kAscDescErr[] = "ASC and DESC options at the same time are not compatible"; -static const char kStoreTypeErr[] = - "STORE and STOREDIST options at the same time are not compatible"; -static const char kScoreNaN[] = "resulting score is not a number (NaN)"; -static const char kFloatRangeErr[] = "min or max is not a float"; static const char kLexRangeErr[] = "min or max not valid string range item"; -static const char kStoreCompatErr[] = - "STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options"; -static const char kMemberNotFound[] = "could not decode requested zset member"; -constexpr string_view kGeoAlphabet = "0123456789bcdefghjkmnpqrstuvwxyz"sv; +static const char kFloatRangeErr[] = "min or max is not a float"; +static const char kScoreNaN[] = "resulting score is not a number (NaN)"; using MScoreResponse = std::vector>; - -using ScoredMember = std::pair; -using ScoredArray = std::vector; - -struct GeoPoint { - double longitude; - double latitude; - double dist; - double score; - std::string member; - GeoPoint() : longitude(0.0), latitude(0.0), dist(0.0), score(0.0){}; - GeoPoint(double _longitude, double _latitude, double _dist, double _score, - const std::string& _member) - : longitude(_longitude), latitude(_latitude), dist(_dist), score(_score), member(_member){}; -}; -using GeoArray = std::vector; - -enum class Sorting { kUnsorted, kAsc, kDesc }; -enum class GeoStoreType { kNoStore, kStoreHash, kStoreDist }; -struct GeoSearchOpts { - double conversion = 0; - uint64_t count = 0; - Sorting sorting = Sorting::kUnsorted; - bool any = 0; - bool withdist = 0; - bool withcoord = 0; - bool withhash = 0; - GeoStoreType store = GeoStoreType::kNoStore; - string_view store_key; - - bool HasWithStatement() const { - return withdist || withcoord || withhash; - } -}; +using ScoredMember = ZSetFamily::ScoredMember; +using ScoredArray = ZSetFamily::ScoredArray; +using ScoredMemberView = ZSetFamily::ScoredMemberView; +using ScoredMemberSpan = ZSetFamily::ScoredMemberSpan; inline zrangespec GetZrangeSpec(bool reverse, const ZSetFamily::ScoreInterval& si) { auto interval = si; @@ -177,12 +136,6 @@ std::optional GetZsetScore(const detail::RobjWrapper* robj_wrapper, sds return 0; } -struct ZParams { - unsigned flags = 0; // mask of ZADD_IN_ macros. - bool ch = false; // Corresponds to CH option. - bool override = false; -}; - void OutputScoredArrayResult(const OpResult& result, SinkReplyBuilder* builder) { if (result.status() == OpStatus::WRONG_TYPE) { return builder->SendError(kWrongTypeErr); @@ -194,8 +147,9 @@ void OutputScoredArrayResult(const OpResult& result, SinkReplyBuild rb->SendScoredArray(result.value(), true /* with scores */); } -OpResult FindZEntry(const ZParams& zparams, const OpArgs& op_args, - string_view key, size_t member_len) { +OpResult FindZEntry(const ZSetFamily::ZParams& zparams, + const OpArgs& op_args, string_view key, + size_t member_len) { auto& db_slice = op_args.GetDbSlice(); if (zparams.flags & ZADD_IN_XX) { return db_slice.FindMutable(op_args.db_cntx, key, OBJ_ZSET); @@ -231,56 +185,6 @@ OpResult FindZEntry(const ZParams& zparams, const OpArgs& return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)}; } -bool ScoreToLongLat(const std::optional& val, double* xy) { - if (!val.has_value()) - return false; - - double score = *val; - - GeoHashBits hash = {.bits = (uint64_t)score, .step = GEO_STEP_MAX}; - - return geohashDecodeToLongLatType(hash, xy) == 1; -} - -bool ToAsciiGeoHash(const std::optional& val, array* buf) { - if (!val.has_value()) - return false; - - double score = *val; - - GeoHashBits hash = {.bits = (uint64_t)score, .step = GEO_STEP_MAX}; - - double xy[2]; - if (!geohashDecodeToLongLatType(hash, xy)) { - return false; - } - - /* Re-encode */ - GeoHashRange r[2]; - r[0].min = -180; - r[0].max = 180; - r[1].min = -90; - r[1].max = 90; - - geohashEncode(&r[0], &r[1], xy[0], xy[1], 26, &hash); - - for (int i = 0; i < 11; i++) { - int idx; - if (i == 10) { - /* We have just 52 bits, but the API used to output - * an 11 bytes geohash. For compatibility we assume - * zero. */ - idx = 0; - } else { - idx = (hash.bits >> (52 - ((i + 1) * 5))) % kGeoAlphabet.size(); - } - (*buf)[i] = kGeoAlphabet[idx]; - } - (*buf)[11] = '\0'; - - return true; -} - enum class Action { RANGE = 0, REMOVE = 1, POP = 2 }; class IntervalVisitor { @@ -680,20 +584,6 @@ bool ParseBound(string_view src, ZSetFamily::Bound* bound) { return ParseDouble(src, &bound->val); } -bool ParseLongLat(string_view lon, string_view lat, std::pair* res) { - if (!ParseDouble(lon, &res->first)) - return false; - - if (!ParseDouble(lat, &res->second)) - return false; - - if (res->first < GEO_LONG_MIN || res->first > GEO_LONG_MAX || res->second < GEO_LAT_MIN || - res->second > GEO_LAT_MAX) { - return false; - } - return true; -} - bool ParseLexBound(string_view src, ZSetFamily::LexBound* bound) { if (src.empty()) return false; @@ -943,16 +833,6 @@ OpResult OpInter(EngineShard* shard, Transaction* t, string_view dest return result; } -using ScoredMemberView = std::pair; -using ScoredMemberSpan = absl::Span; - -struct AddResult { - double new_score = 0; - unsigned num_updated = 0; - - bool is_nan = false; -}; - size_t EstimateListpackMinBytes(ScoredMemberSpan members) { size_t bytes = members.size() * 2; // at least 2 bytes per score; for (const auto& member : members) { @@ -961,97 +841,6 @@ size_t EstimateListpackMinBytes(ScoredMemberSpan members) { return bytes; } -OpResult OpAdd(const OpArgs& op_args, const ZParams& zparams, string_view key, - ScoredMemberSpan members) { - DCHECK(!members.empty() || zparams.override); - auto& db_slice = op_args.GetDbSlice(); - - if (zparams.override && members.empty()) { - auto it = db_slice.FindMutable(op_args.db_cntx, key).it; // post_updater will run immediately - if (IsValid(it)) { - db_slice.Del(op_args.db_cntx, it); - } - return OpStatus::OK; - } - - // When we have too many members to add, make sure field_len is large enough to use - // skiplist encoding. - size_t field_len = members.size() > server.zset_max_listpack_entries - ? UINT32_MAX - : members.front().second.size(); - auto res_it = FindZEntry(zparams, op_args, key, field_len); - - if (!res_it) - return res_it.status(); - - unsigned added = 0; - unsigned updated = 0; - - double new_score = 0; - int retflags = 0; - - OpStatus op_status = OpStatus::OK; - AddResult aresult; - detail::RobjWrapper* robj_wrapper = res_it->it->second.GetRobjWrapper(); - bool is_list_pack = IsListPack(robj_wrapper); - - // opportunistically reserve space if multiple entries are about to be added. - if ((zparams.flags & ZADD_IN_XX) == 0 && members.size() > 2) { - if (is_list_pack) { - uint8_t* zl = (uint8_t*)robj_wrapper->inner_obj(); - size_t malloc_reserved = zmalloc_size(zl); - size_t min_sz = EstimateListpackMinBytes(members); - if (min_sz > malloc_reserved) { - zl = (uint8_t*)zrealloc(zl, min_sz); - robj_wrapper->set_inner_obj(zl); - } - } else { - detail::SortedMap* sm = (detail::SortedMap*)robj_wrapper->inner_obj(); - sm->Reserve(members.size()); - } - } - - for (size_t j = 0; j < members.size(); j++) { - const auto& m = members[j]; - int retval = - robj_wrapper->ZsetAdd(m.first, WrapSds(m.second), zparams.flags, &retflags, &new_score); - - if (zparams.flags & ZADD_IN_INCR) { - if (retval == 0) { - CHECK_EQ(1u, members.size()); - - aresult.is_nan = true; - break; - } - - if (retflags & ZADD_OUT_NOP) { - op_status = OpStatus::SKIPPED; - } - } - - if (retflags & ZADD_OUT_ADDED) - added++; - if (retflags & ZADD_OUT_UPDATED) - updated++; - } - - // if we migrated to skip_list - update listpack stats. - if (is_list_pack && !IsListPack(robj_wrapper)) { - DbTableStats* stats = db_slice.MutableStats(op_args.db_cntx.db_index); - --stats->listpack_blob_cnt; - } - - if (zparams.flags & ZADD_IN_INCR) { - aresult.new_score = new_score; - } else { - aresult.num_updated = zparams.ch ? added + updated : added; - } - - if (op_status != OpStatus::OK) - return op_status; - return aresult; -} - struct SetOpArgs { AggType agg_type = AggType::SUM; unsigned num_keys; @@ -1357,24 +1146,6 @@ auto OpRange(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args, st return iv.PopResult(); } -auto OpRanges(const std::vector& range_specs, const OpArgs& op_args, - string_view key) -> OpResult> { - auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); - if (!res_it) - return res_it.status(); - - // Action::RANGE is read-only, but requires mutable pointer, thus const_cast - PrimeValue& pv = const_cast(res_it.value()->second); - vector result_arrays; - for (auto& range_spec : range_specs) { - IntervalVisitor iv{Action::RANGE, range_spec.params, &pv}; - std::visit(iv, range_spec.interval); - result_arrays.push_back(iv.PopResult()); - } - - return result_arrays; -} - OpResult OpRemRange(const OpArgs& op_args, string_view key, const ZSetFamily::ZRangeSpec& range_spec) { auto& db_slice = op_args.GetDbSlice(); @@ -1575,25 +1346,6 @@ OpResult OpRem(const OpArgs& op_args, string_view key, facade::ArgRang return deleted; } -OpResult OpKeyExisted(const OpArgs& op_args, string_view key) { - auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); - return res_it.status(); -} - -OpResult OpScore(const OpArgs& op_args, string_view key, string_view member) { - auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); - if (!res_it) - return res_it.status(); - - const PrimeValue& pv = res_it.value()->second; - const detail::RobjWrapper* robj_wrapper = pv.GetRobjWrapper(); - auto res = GetZsetScore(robj_wrapper, WrapSds(member)); - if (!res) { - return OpStatus::MEMBER_NOTFOUND; - } - return *res; -} - OpResult OpMScore(const OpArgs& op_args, string_view key, facade::ArgRange members) { auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); @@ -1705,52 +1457,6 @@ OpResult OpRandMember(int count, const ZSetFamily::RangeParams& par return result; } -void ZAddGeneric(string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp, Transaction* tx, - SinkReplyBuilder* builder) { - auto cb = [&](Transaction* t, EngineShard* shard) { - return OpAdd(t->GetOpArgs(shard), zparams, key, memb_sp); - }; - - OpResult add_result = tx->ScheduleSingleHopT(std::move(cb)); - if (base::_in(add_result.status(), {OpStatus::WRONG_TYPE, OpStatus::OUT_OF_MEMORY})) { - return builder->SendError(add_result.status()); - } - - auto* rb = static_cast(builder); - // KEY_NOTFOUND may happen in case of XX flag. - if (add_result.status() == OpStatus::KEY_NOTFOUND) { - if (zparams.flags & ZADD_IN_INCR) - rb->SendNull(); - else - rb->SendLong(0); - } else if (add_result.status() == OpStatus::SKIPPED) { - rb->SendNull(); - } else if (add_result->is_nan) { - builder->SendError(kScoreNaN); - } else { - if (zparams.flags & ZADD_IN_INCR) { - rb->SendDouble(add_result->new_score); - } else { - rb->SendLong(add_result->num_updated); - } - } -} - -double ExtractUnit(std::string_view arg) { - const string unit = absl::AsciiStrToUpper(arg); - if (unit == "M") { - return 1; - } else if (unit == "KM") { - return 1000; - } else if (unit == "FT") { - return 0.3048; - } else if (unit == "MI") { - return 1609.34; - } else { - return -1; - } -} - // Boolean operation: union or intersection, optionally storing output to destination key void ZBooleanOperation(CmdArgList args, string_view cmd, bool is_union, bool store, Transaction* tx, SinkReplyBuilder* builder) { @@ -1805,7 +1511,8 @@ void ZBooleanOperation(CmdArgList args, string_view cmd, bool is_union, bool sto auto store_cb = [&, dest_shard = Shard(dest_key, maps.size())](Transaction* t, EngineShard* shard) { if (shard->shard_id() == dest_shard) - OpAdd(t->GetOpArgs(shard), ZParams{.override = true}, dest_key, smvec); + ZSetFamily::OpAdd(t->GetOpArgs(shard), ZSetFamily::ZParams{.override = true}, dest_key, + smvec); return OpStatus::OK; }; tx->Execute(store_cb, true); @@ -1869,16 +1576,6 @@ void ZPopMinMaxFromArgs(CmdArgList args, bool reverse, Transaction* tx, SinkRepl OutputScoredArrayResult(ZPopMinMaxInternal(key, FilterShards::NO, count, reverse, tx), builder); } -OpResult ZGetMembers(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) { - string_view key = ArgS(args, 0); - auto members = args.subspan(1); - auto cb = [key, members](Transaction* t, EngineShard* shard) { - return OpMScore(t->GetOpArgs(shard), key, members); - }; - - return tx->ScheduleSingleHopT(std::move(cb)); -} - void ZRangeInternal(CmdArgList args, ZSetFamily::RangeParams range_params, Transaction* tx, SinkReplyBuilder* builder) { string_view key = ArgS(args, 0); @@ -1945,7 +1642,7 @@ void ZRangeInternal(CmdArgList args, ZSetFamily::RangeParams range_params, Trans return; } - OpResult add_result; + OpResult add_result; ShardId dest_shard = Shard(*range_params.store_key, shard_set->size()); auto add_cb = [&](Transaction* t, EngineShard* shard) { if (shard->shard_id() != dest_shard) { @@ -1959,8 +1656,8 @@ void ZRangeInternal(CmdArgList args, ZSetFamily::RangeParams range_params, Trans mvec[i++] = {score, str}; } - add_result = - OpAdd(t->GetOpArgs(shard), ZParams{.override = true}, *range_params.store_key, mvec); + add_result = ZSetFamily::OpAdd(t->GetOpArgs(shard), ZSetFamily::ZParams{.override = true}, + *range_params.store_key, mvec); return OpStatus::OK; }; @@ -2148,28 +1845,199 @@ bool ValidateZMPopCommand(CmdArgList args, uint32* num_keys, bool* is_max, int* } // namespace -void ZSetFamily::BZPopMin(CmdArgList args, const CommandContext& cmd_cntx) { - BZPopMinMax(args, cmd_cntx.tx, cmd_cntx.rb, cmd_cntx.conn_cntx, false); -} +void ZSetFamily::ZAddGeneric(string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp, + Transaction* tx, SinkReplyBuilder* builder) { + auto cb = [&](Transaction* t, EngineShard* shard) { + return ZSetFamily::OpAdd(t->GetOpArgs(shard), zparams, key, memb_sp); + }; -void ZSetFamily::BZPopMax(CmdArgList args, const CommandContext& cmd_cntx) { - BZPopMinMax(args, cmd_cntx.tx, cmd_cntx.rb, cmd_cntx.conn_cntx, true); + OpResult add_result = tx->ScheduleSingleHopT(std::move(cb)); + if (base::_in(add_result.status(), {OpStatus::WRONG_TYPE, OpStatus::OUT_OF_MEMORY})) { + return builder->SendError(add_result.status()); + } + + auto* rb = static_cast(builder); + // KEY_NOTFOUND may happen in case of XX flag. + if (add_result.status() == OpStatus::KEY_NOTFOUND) { + if (zparams.flags & ZADD_IN_INCR) + rb->SendNull(); + else + rb->SendLong(0); + } else if (add_result.status() == OpStatus::SKIPPED) { + rb->SendNull(); + } else if (add_result->is_nan) { + builder->SendError(kScoreNaN); + } else { + if (zparams.flags & ZADD_IN_INCR) { + rb->SendDouble(add_result->new_score); + } else { + rb->SendLong(add_result->num_updated); + } + } } -void ZSetFamily::ZAdd(CmdArgList args, const CommandContext& cmd_cntx) { +OpResult ZSetFamily::ZGetMembers(CmdArgList args, Transaction* tx, + SinkReplyBuilder* builder) { string_view key = ArgS(args, 0); + auto members = args.subspan(1); + auto cb = [key, members](Transaction* t, EngineShard* shard) { + return OpMScore(t->GetOpArgs(shard), key, members); + }; - ZParams zparams; - size_t i = 1; - for (; i < args.size() - 1; ++i) { - string cur_arg = absl::AsciiStrToUpper(ArgS(args, i)); + return tx->ScheduleSingleHopT(std::move(cb)); +} - if (cur_arg == "XX") { - zparams.flags |= ZADD_IN_XX; // update only - } else if (cur_arg == "NX") { - zparams.flags |= ZADD_IN_NX; // add new only. - } else if (cur_arg == "GT") { - zparams.flags |= ZADD_IN_GT; +auto ZSetFamily::OpRanges(const std::vector& range_specs, + const OpArgs& op_args, string_view key) -> OpResult> { + auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); + if (!res_it) + return res_it.status(); + + // Action::RANGE is read-only, but requires mutable pointer, thus const_cast + PrimeValue& pv = const_cast(res_it.value()->second); + vector result_arrays; + for (auto& range_spec : range_specs) { + IntervalVisitor iv{Action::RANGE, range_spec.params, &pv}; + std::visit(iv, range_spec.interval); + result_arrays.push_back(iv.PopResult()); + } + + return result_arrays; +} + +OpResult ZSetFamily::OpAdd(const OpArgs& op_args, + const ZSetFamily::ZParams& zparams, + string_view key, ScoredMemberSpan members) { + DCHECK(!members.empty() || zparams.override); + auto& db_slice = op_args.GetDbSlice(); + + if (zparams.override && members.empty()) { + auto it = db_slice.FindMutable(op_args.db_cntx, key).it; // post_updater will run immediately + if (IsValid(it)) { + db_slice.Del(op_args.db_cntx, it); + } + return OpStatus::OK; + } + + // When we have too many members to add, make sure field_len is large enough to use + // skiplist encoding. + size_t field_len = members.size() > server.zset_max_listpack_entries + ? UINT32_MAX + : members.front().second.size(); + auto res_it = FindZEntry(zparams, op_args, key, field_len); + + if (!res_it) + return res_it.status(); + + unsigned added = 0; + unsigned updated = 0; + + double new_score = 0; + int retflags = 0; + + OpStatus op_status = OpStatus::OK; + AddResult aresult; + detail::RobjWrapper* robj_wrapper = res_it->it->second.GetRobjWrapper(); + bool is_list_pack = IsListPack(robj_wrapper); + + // opportunistically reserve space if multiple entries are about to be added. + if ((zparams.flags & ZADD_IN_XX) == 0 && members.size() > 2) { + if (is_list_pack) { + uint8_t* zl = (uint8_t*)robj_wrapper->inner_obj(); + size_t malloc_reserved = zmalloc_size(zl); + size_t min_sz = EstimateListpackMinBytes(members); + if (min_sz > malloc_reserved) { + zl = (uint8_t*)zrealloc(zl, min_sz); + robj_wrapper->set_inner_obj(zl); + } + } else { + detail::SortedMap* sm = (detail::SortedMap*)robj_wrapper->inner_obj(); + sm->Reserve(members.size()); + } + } + + for (size_t j = 0; j < members.size(); j++) { + const auto& m = members[j]; + int retval = + robj_wrapper->ZsetAdd(m.first, WrapSds(m.second), zparams.flags, &retflags, &new_score); + + if (zparams.flags & ZADD_IN_INCR) { + if (retval == 0) { + CHECK_EQ(1u, members.size()); + + aresult.is_nan = true; + break; + } + + if (retflags & ZADD_OUT_NOP) { + op_status = OpStatus::SKIPPED; + } + } + + if (retflags & ZADD_OUT_ADDED) + added++; + if (retflags & ZADD_OUT_UPDATED) + updated++; + } + + // if we migrated to skip_list - update listpack stats. + if (is_list_pack && !IsListPack(robj_wrapper)) { + DbTableStats* stats = db_slice.MutableStats(op_args.db_cntx.db_index); + --stats->listpack_blob_cnt; + } + + if (zparams.flags & ZADD_IN_INCR) { + aresult.new_score = new_score; + } else { + aresult.num_updated = zparams.ch ? added + updated : added; + } + + if (op_status != OpStatus::OK) + return op_status; + return aresult; +} + +OpResult ZSetFamily::OpKeyExisted(const OpArgs& op_args, string_view key) { + auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); + return res_it.status(); +} + +OpResult ZSetFamily::OpScore(const OpArgs& op_args, string_view key, string_view member) { + auto res_it = op_args.GetDbSlice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); + if (!res_it) + return res_it.status(); + + const PrimeValue& pv = res_it.value()->second; + const detail::RobjWrapper* robj_wrapper = pv.GetRobjWrapper(); + auto res = GetZsetScore(robj_wrapper, WrapSds(member)); + if (!res) { + return OpStatus::MEMBER_NOTFOUND; + } + return *res; +} + +void ZSetFamily::BZPopMin(CmdArgList args, const CommandContext& cmd_cntx) { + BZPopMinMax(args, cmd_cntx.tx, cmd_cntx.rb, cmd_cntx.conn_cntx, false); +} + +void ZSetFamily::BZPopMax(CmdArgList args, const CommandContext& cmd_cntx) { + BZPopMinMax(args, cmd_cntx.tx, cmd_cntx.rb, cmd_cntx.conn_cntx, true); +} + +void ZSetFamily::ZAdd(CmdArgList args, const CommandContext& cmd_cntx) { + string_view key = ArgS(args, 0); + + ZSetFamily::ZParams zparams; + size_t i = 1; + for (; i < args.size() - 1; ++i) { + string cur_arg = absl::AsciiStrToUpper(ArgS(args, i)); + + if (cur_arg == "XX") { + zparams.flags |= ZADD_IN_XX; // update only + } else if (cur_arg == "NX") { + zparams.flags |= ZADD_IN_NX; // add new only. + } else if (cur_arg == "GT") { + zparams.flags |= ZADD_IN_GT; } else if (cur_arg == "LT") { zparams.flags |= ZADD_IN_LT; } else if (cur_arg == "CH") { @@ -2374,7 +2242,7 @@ void ZSetFamily::ZIncrBy(CmdArgList args, const CommandContext& cmd_cntx) { return rb->SendError(kScoreNaN); } - ZParams zparams; + ZSetFamily::ZParams zparams; zparams.flags = ZADD_IN_INCR; auto cb = [&](Transaction* t, EngineShard* shard) { @@ -2761,602 +2629,6 @@ void ZSetFamily::ZUnionStore(CmdArgList args, const CommandContext& cmd_cntx) { ZBooleanOperation(args, "zunionstore", true, true, cmd_cntx.tx, cmd_cntx.rb); } -void ZSetFamily::GeoAdd(CmdArgList args, const CommandContext& cmd_cntx) { - string_view key = ArgS(args, 0); - - ZParams zparams; - size_t i = 1; - for (; i < args.size(); ++i) { - string cur_arg = absl::AsciiStrToUpper(ArgS(args, i)); - - if (cur_arg == "XX") { - zparams.flags |= ZADD_IN_XX; // update only - } else if (cur_arg == "NX") { - zparams.flags |= ZADD_IN_NX; // add new only. - } else if (cur_arg == "CH") { - zparams.ch = true; - } else { - break; - } - } - - auto* builder = cmd_cntx.rb; - args.remove_prefix(i); - if (args.empty() || args.size() % 3 != 0) { - builder->SendError(kSyntaxErr); - return; - } - - if ((zparams.flags & ZADD_IN_NX) && (zparams.flags & ZADD_IN_XX)) { - builder->SendError(kNxXxErr); - return; - } - - absl::InlinedVector members; - for (i = 0; i < args.size(); i += 3) { - string_view longitude = ArgS(args, i); - string_view latitude = ArgS(args, i + 1); - string_view member = ArgS(args, i + 2); - - pair longlat; - - if (!ParseLongLat(longitude, latitude, &longlat)) { - string err = absl::StrCat("-ERR invalid longitude,latitude pair ", longitude, ",", latitude, - ",", member); - - return builder->SendError(err, kSyntaxErrType); - } - - /* Turn the coordinates into the score of the element. */ - GeoHashBits hash; - geohashEncodeWGS84(longlat.first, longlat.second, GEO_STEP_MAX, &hash); - GeoHashFix52Bits bits = geohashAlign52Bits(hash); - - members.emplace_back(bits, member); - } - DCHECK(cmd_cntx.tx); - - absl::Span memb_sp{members.data(), members.size()}; - ZAddGeneric(key, zparams, memb_sp, cmd_cntx.tx, builder); -} - -void ZSetFamily::GeoHash(CmdArgList args, const CommandContext& cmd_cntx) { - auto* rb = static_cast(cmd_cntx.rb); - - OpResult result = ZGetMembers(args, cmd_cntx.tx, rb); - - if (result.status() == OpStatus::WRONG_TYPE) { - return rb->SendError(kWrongTypeErr); - } - - rb->StartArray(result->size()); // Array return type. - const MScoreResponse& arr = result.value(); - - array buf; - for (const auto& p : arr) { - if (ToAsciiGeoHash(p, &buf)) { - rb->SendBulkString(string_view{buf.data(), buf.size() - 1}); - } else { - rb->SendNull(); - } - } -} - -void ZSetFamily::GeoPos(CmdArgList args, const CommandContext& cmd_cntx) { - auto* rb = static_cast(cmd_cntx.rb); - - OpResult result = ZGetMembers(args, cmd_cntx.tx, rb); - - if (result.status() != OpStatus::OK) { - return rb->SendError(result.status()); - } - - rb->StartArray(result->size()); // Array return type. - const MScoreResponse& arr = result.value(); - - double xy[2]; - for (const auto& p : arr) { - if (ScoreToLongLat(p, xy)) { - rb->StartArray(2); - rb->SendDouble(xy[0]); - rb->SendDouble(xy[1]); - } else { - rb->SendNull(); - } - } -} - -void ZSetFamily::GeoDist(CmdArgList args, const CommandContext& cmd_cntx) { - double distance_multiplier = 1; - auto* rb = static_cast(cmd_cntx.rb); - - if (args.size() == 4) { - string_view unit = ArgS(args, 3); - distance_multiplier = ExtractUnit(unit); - args.remove_suffix(1); - if (distance_multiplier < 0) { - return rb->SendError("unsupported unit provided. please use M, KM, FT, MI"); - } - } else if (args.size() != 3) { - return rb->SendError(kSyntaxErr); - } - - OpResult result = ZGetMembers(args, cmd_cntx.tx, rb); - - if (result.status() != OpStatus::OK) { - return rb->SendError(result.status()); - } - - const MScoreResponse& arr = result.value(); - - if (arr.size() != 2) { - return rb->SendError(kSyntaxErr); - } - - double xyxy[4]; // 2 pairs of score holding 2 locations - for (size_t i = 0; i < arr.size(); i++) { - if (!ScoreToLongLat(arr[i], xyxy + (i * 2))) { - return rb->SendNull(); - } - } - - return rb->SendDouble(geohashGetDistance(xyxy[0], xyxy[1], xyxy[2], xyxy[3]) / - distance_multiplier); -} - -namespace { -std::vector GetGeoRangeSpec(const GeoHashRadius& n) { - array neighbors; - unsigned int last_processed = 0; - - neighbors[0] = n.hash; - neighbors[1] = n.neighbors.north; - neighbors[2] = n.neighbors.south; - neighbors[3] = n.neighbors.east; - neighbors[4] = n.neighbors.west; - neighbors[5] = n.neighbors.north_east; - neighbors[6] = n.neighbors.north_west; - neighbors[7] = n.neighbors.south_east; - neighbors[8] = n.neighbors.south_west; - - // Get range_specs for neighbors (*and* our own hashbox) - std::vector range_specs; - for (unsigned int i = 0; i < neighbors.size(); i++) { - if (HASHISZERO(neighbors[i])) { - continue; - } - - // When a huge Radius (in the 5000 km range or more) is used, - // adjacent neighbors can be the same, leading to duplicated - // elements. Skip every range which is the same as the one - // processed previously. - if (last_processed && neighbors[i].bits == neighbors[last_processed].bits && - neighbors[i].step == neighbors[last_processed].step) { - continue; - } - - GeoHashFix52Bits min, max; - scoresOfGeoHashBox(neighbors[i], &min, &max); - - ZSetFamily::ScoreInterval si; - si.first = ZSetFamily::Bound{static_cast(min), false}; - si.second = ZSetFamily::Bound{static_cast(max), true}; - - ZSetFamily::RangeParams range_params; - range_params.interval_type = ZSetFamily::RangeParams::IntervalType::SCORE; - range_params.with_scores = true; - range_specs.emplace_back(si, range_params); - - last_processed = i; - } - return range_specs; -} - -void SortIfNeeded(GeoArray* ga, Sorting sorting, uint64_t count) { - if (sorting == Sorting::kUnsorted) - return; - - auto comparator = [&](const GeoPoint& a, const GeoPoint& b) { - if (sorting == Sorting::kAsc) { - return a.dist < b.dist; - } else { - DCHECK(sorting == Sorting::kDesc); - return a.dist > b.dist; - } - }; - - if (count > 0) { - std::partial_sort(ga->begin(), ga->begin() + count, ga->end(), comparator); - ga->resize(count); - } else { - std::sort(ga->begin(), ga->end(), comparator); - } -} - -void GeoSearchStoreGeneric(Transaction* tx, SinkReplyBuilder* builder, const GeoShape& shape_ref, - string_view key, string_view member, const GeoSearchOpts& geo_ops) { - GeoShape* shape = &(const_cast(shape_ref)); - auto* rb = static_cast(builder); - - ShardId from_shard = Shard(key, shard_set->size()); - - if (!member.empty()) { - // get shape.xy from member - OpResult member_score; - auto cb = [&](Transaction* t, EngineShard* shard) { - if (shard->shard_id() == from_shard) { - member_score = OpScore(t->GetOpArgs(shard), key, member); - } - return OpStatus::OK; - }; - tx->Execute(std::move(cb), false); - auto member_sts = member_score.status(); - if (member_sts != OpStatus::OK) { - tx->Conclude(); - switch (member_sts) { - case OpStatus::WRONG_TYPE: - return builder->SendError(kWrongTypeErr); - case OpStatus::KEY_NOTFOUND: - return rb->StartArray(0); - case OpStatus::MEMBER_NOTFOUND: - return builder->SendError(kMemberNotFound); - default: - return builder->SendError(member_sts); - } - } - ScoreToLongLat(*member_score, shape->xy); - } else { - // verify key is valid - OpResult result; - auto cb = [&](Transaction* t, EngineShard* shard) { - if (shard->shard_id() == from_shard) { - result = OpKeyExisted(t->GetOpArgs(shard), key); - } - return OpStatus::OK; - }; - tx->Execute(std::move(cb), false); - auto result_sts = result.status(); - if (result_sts != OpStatus::OK) { - tx->Conclude(); - switch (result_sts) { - case OpStatus::WRONG_TYPE: - return builder->SendError(kWrongTypeErr); - case OpStatus::KEY_NOTFOUND: - return rb->StartArray(0); - default: - return builder->SendError(result_sts); - } - } - } - DCHECK(shape->xy[0] >= -180.0 && shape->xy[0] <= 180.0); - DCHECK(shape->xy[1] >= -90.0 && shape->xy[1] <= 90.0); - - // query - GeoHashRadius georadius = geohashCalculateAreasByShapeWGS84(shape); - GeoArray ga; - auto range_specs = GetGeoRangeSpec(georadius); - // get all the matching members and add them to the potential result list - vector>> result_arrays; - auto cb = [&](Transaction* t, EngineShard* shard) { - auto res_it = OpRanges(range_specs, t->GetOpArgs(shard), key); - if (res_it) { - result_arrays.emplace_back(res_it); - } - return OpStatus::OK; - }; - - tx->Execute(std::move(cb), geo_ops.store == GeoStoreType::kNoStore); - - // filter potential result list - double xy[2]; - double distance; - unsigned long limit = geo_ops.any ? geo_ops.count : 0; - for (auto& result_array : result_arrays) { - for (auto& arr : *result_array) { - for (auto& p : arr) { - if (geoWithinShape(shape, p.second, xy, &distance) == 0) { - ga.emplace_back(xy[0], xy[1], distance, p.second, p.first); - if (limit > 0 && ga.size() >= limit) - break; - } - } - } - } - - // sort and trim by count - SortIfNeeded(&ga, geo_ops.sorting, geo_ops.count); - - if (geo_ops.store == GeoStoreType::kNoStore) { - // case 1: read mode - // case 2: write mode, kNoStore - // generate reply array withdist, withcoords, withhash - int record_size = 1; - if (geo_ops.withdist) { - record_size++; - } - if (geo_ops.withhash) { - record_size++; - } - if (geo_ops.withcoord) { - record_size++; - } - rb->StartArray(ga.size()); - for (const auto& p : ga) { - // [member, dist, x, y, hash] - if (geo_ops.HasWithStatement()) { - rb->StartArray(record_size); - } - rb->SendBulkString(p.member); - if (geo_ops.withdist) { - rb->SendDouble(p.dist / geo_ops.conversion); - } - if (geo_ops.withhash) { - rb->SendDouble(p.score); - } - if (geo_ops.withcoord) { - rb->StartArray(2); - rb->SendDouble(p.longitude); - rb->SendDouble(p.latitude); - } - } - } else { - // case 3: write mode, !kNoStore - DCHECK(geo_ops.store == GeoStoreType::kStoreDist || geo_ops.store == GeoStoreType::kStoreHash); - ShardId dest_shard = Shard(geo_ops.store_key, shard_set->size()); - DVLOG(1) << "store shard:" << dest_shard << ", key " << geo_ops.store_key; - AddResult add_result; - vector smvec; - for (const auto& p : ga) { - if (geo_ops.store == GeoStoreType::kStoreDist) { - smvec.emplace_back(p.dist / geo_ops.conversion, p.member); - } else { - DCHECK(geo_ops.store == GeoStoreType::kStoreHash); - smvec.emplace_back(p.score, p.member); - } - } - - auto store_cb = [&](Transaction* t, EngineShard* shard) { - if (shard->shard_id() == dest_shard) { - ZParams zparams; - zparams.override = true; - add_result = - OpAdd(t->GetOpArgs(shard), zparams, geo_ops.store_key, ScoredMemberSpan{smvec}).value(); - } - return OpStatus::OK; - }; - tx->Execute(std::move(store_cb), true); - - rb->SendLong(smvec.size()); - } -} -} // namespace - -void ZSetFamily::GeoSearch(CmdArgList args, const CommandContext& cmd_cntx) { - // parse arguments - string_view key = ArgS(args, 0); - GeoShape shape = {}; - GeoSearchOpts geo_ops; - string_view member; - - // FROMMEMBER or FROMLONLAT is set - bool from_set = false; - // BYRADIUS or BYBOX is set - bool by_set = false; - auto* builder = cmd_cntx.rb; - - for (size_t i = 1; i < args.size(); ++i) { - string cur_arg = absl::AsciiStrToUpper(ArgS(args, i)); - - if (cur_arg == "FROMMEMBER") { - if (from_set) { - return builder->SendError(kFromMemberLonglatErr); - } else if (i + 1 < args.size()) { - member = ArgS(args, i + 1); - from_set = true; - i++; - } else { - return builder->SendError(kSyntaxErr); - } - } else if (cur_arg == "FROMLONLAT") { - if (from_set) { - return builder->SendError(kFromMemberLonglatErr); - } else if (i + 2 < args.size()) { - string_view longitude_str = ArgS(args, i + 1); - string_view latitude_str = ArgS(args, i + 2); - pair longlat; - if (!ParseLongLat(longitude_str, latitude_str, &longlat)) { - string err = absl::StrCat("-ERR invalid longitude,latitude pair ", longitude_str, ",", - latitude_str); - return builder->SendError(err, kSyntaxErrType); - } - shape.xy[0] = longlat.first; - shape.xy[1] = longlat.second; - from_set = true; - i += 2; - } else { - return builder->SendError(kSyntaxErr); - } - } else if (cur_arg == "BYRADIUS") { - if (by_set) { - return builder->SendError(kByRadiusBoxErr); - } else if (i + 2 < args.size()) { - if (!ParseDouble(ArgS(args, i + 1), &shape.t.radius)) { - return builder->SendError(kInvalidFloatErr); - } - string_view unit = ArgS(args, i + 2); - shape.conversion = ExtractUnit(unit); - geo_ops.conversion = shape.conversion; - if (shape.conversion == -1) { - return builder->SendError("unsupported unit provided. please use M, KM, FT, MI"); - } - shape.type = CIRCULAR_TYPE; - by_set = true; - i += 2; - } else { - return builder->SendError(kSyntaxErr); - } - } else if (cur_arg == "BYBOX") { - if (by_set) { - return builder->SendError(kByRadiusBoxErr); - } else if (i + 3 < args.size()) { - if (!ParseDouble(ArgS(args, i + 1), &shape.t.r.width)) { - return builder->SendError(kInvalidFloatErr); - } - if (!ParseDouble(ArgS(args, i + 2), &shape.t.r.height)) { - return builder->SendError(kInvalidFloatErr); - } - string_view unit = ArgS(args, i + 3); - shape.conversion = ExtractUnit(unit); - geo_ops.conversion = shape.conversion; - if (shape.conversion == -1) { - return builder->SendError("unsupported unit provided. please use M, KM, FT, MI"); - } - shape.type = RECTANGLE_TYPE; - by_set = true; - i += 3; - } else { - return builder->SendError(kSyntaxErr); - } - } else if (cur_arg == "ASC") { - if (geo_ops.sorting != Sorting::kUnsorted) { - return builder->SendError(kAscDescErr); - } else { - geo_ops.sorting = Sorting::kAsc; - } - } else if (cur_arg == "DESC") { - if (geo_ops.sorting != Sorting::kUnsorted) { - return builder->SendError(kAscDescErr); - } else { - geo_ops.sorting = Sorting::kDesc; - } - } else if (cur_arg == "COUNT") { - if (i + 1 < args.size() && absl::SimpleAtoi(ArgS(args, i + 1), &geo_ops.count)) { - i++; - } else { - return builder->SendError(kSyntaxErr); - } - if (i + 1 < args.size() && ArgS(args, i + 1) == "ANY") { - geo_ops.any = true; - i++; - } - } else if (cur_arg == "WITHCOORD") { - geo_ops.withcoord = true; - } else if (cur_arg == "WITHDIST") { - geo_ops.withdist = true; - } else if (cur_arg == "WITHHASH") { - geo_ops.withhash = true; - } else { - return builder->SendError(kSyntaxErr); - } - } - - // check mandatory options - if (!from_set) { - return builder->SendError(kSyntaxErr); - } - if (!by_set) { - return builder->SendError(kSyntaxErr); - } - // parsing completed - - GeoSearchStoreGeneric(cmd_cntx.tx, builder, shape, key, member, geo_ops); -} - -void ZSetFamily::GeoRadiusByMember(CmdArgList args, const CommandContext& cmd_cntx) { - GeoShape shape = {}; - GeoSearchOpts geo_ops; - // parse arguments - string_view key = ArgS(args, 0); - // member to latlong, set shape.xy - string_view member = ArgS(args, 1); - - auto* builder = cmd_cntx.rb; - if (!ParseDouble(ArgS(args, 2), &shape.t.radius)) { - return builder->SendError(kInvalidFloatErr); - } - string_view unit = ArgS(args, 3); - shape.conversion = ExtractUnit(unit); - geo_ops.conversion = shape.conversion; - if (shape.conversion == -1) { - return builder->SendError("unsupported unit provided. please use M, KM, FT, MI"); - } - shape.type = CIRCULAR_TYPE; - - for (size_t i = 4; i < args.size(); ++i) { - string cur_arg = absl::AsciiStrToUpper(ArgS(args, i)); - - if (cur_arg == "ASC") { - if (geo_ops.sorting != Sorting::kUnsorted) { - return builder->SendError(kAscDescErr); - } else { - geo_ops.sorting = Sorting::kAsc; - } - } else if (cur_arg == "DESC") { - if (geo_ops.sorting != Sorting::kUnsorted) { - return builder->SendError(kAscDescErr); - } else { - geo_ops.sorting = Sorting::kDesc; - } - } else if (cur_arg == "COUNT") { - if (i + 1 < args.size() && absl::SimpleAtoi(ArgS(args, i + 1), &geo_ops.count)) { - i++; - } else { - return builder->SendError(kSyntaxErr); - } - if (i + 1 < args.size() && ArgS(args, i + 1) == "ANY") { - geo_ops.any = true; - i++; - } - } else if (cur_arg == "WITHCOORD") { - if (geo_ops.store != GeoStoreType::kNoStore) { - return builder->SendError(kStoreCompatErr); - } - geo_ops.withcoord = true; - } else if (cur_arg == "WITHDIST") { - if (geo_ops.store != GeoStoreType::kNoStore) { - return builder->SendError(kStoreCompatErr); - } - geo_ops.withdist = true; - } else if (cur_arg == "WITHHASH") { - if (geo_ops.store != GeoStoreType::kNoStore) { - return builder->SendError(kStoreCompatErr); - } - geo_ops.withhash = true; - } else if (cur_arg == "STORE") { - if (geo_ops.store != GeoStoreType::kNoStore) { - return builder->SendError(kStoreTypeErr); - } else if (geo_ops.withcoord || geo_ops.withdist || geo_ops.withhash) { - return builder->SendError(kStoreCompatErr); - } - if (i + 1 < args.size()) { - geo_ops.store_key = ArgS(args, i + 1); - geo_ops.store = GeoStoreType::kStoreHash; - i++; - } else { - return builder->SendError(kSyntaxErr); - } - } else if (cur_arg == "STOREDIST") { - if (geo_ops.store != GeoStoreType::kNoStore) { - return builder->SendError(kStoreTypeErr); - } else if (geo_ops.withcoord || geo_ops.withdist || geo_ops.withhash) { - return builder->SendError(kStoreCompatErr); - } - if (i + 1 < args.size()) { - geo_ops.store_key = ArgS(args, i + 1); - geo_ops.store = GeoStoreType::kStoreDist; - i++; - } else { - return builder->SendError(kSyntaxErr); - } - } else { - return builder->SendError(kSyntaxErr); - } - } - // parsing completed - - GeoSearchStoreGeneric(cmd_cntx.tx, builder, shape, key, member, geo_ops); -} - #define HFUNC(x) SetHandler(&ZSetFamily::x) namespace acl { @@ -3393,12 +2665,6 @@ constexpr uint32_t kZRevRank = READ | SORTEDSET | FAST; constexpr uint32_t kZScan = READ | SORTEDSET | SLOW; constexpr uint32_t kZUnion = READ | SORTEDSET | SLOW; constexpr uint32_t kZUnionStore = WRITE | SORTEDSET | SLOW; -constexpr uint32_t kGeoAdd = WRITE | GEO | SLOW; -constexpr uint32_t kGeoHash = READ | GEO | SLOW; -constexpr uint32_t kGeoPos = READ | GEO | SLOW; -constexpr uint32_t kGeoDist = READ | GEO | SLOW; -constexpr uint32_t kGeoSearch = READ | GEO | SLOW; -constexpr uint32_t kGeoRadiusByMember = WRITE | GEO | SLOW; } // namespace acl void ZSetFamily::Register(CommandRegistry* registry) { @@ -3445,16 +2711,7 @@ void ZSetFamily::Register(CommandRegistry* registry) { << CI{"ZREVRANK", CO::READONLY | CO::FAST, -3, 1, 1, acl::kZRevRank}.HFUNC(ZRevRank) << CI{"ZSCAN", CO::READONLY, -3, 1, 1, acl::kZScan}.HFUNC(ZScan) << CI{"ZUNION", CO::READONLY | CO::VARIADIC_KEYS, -3, 2, 2, acl::kZUnion}.HFUNC(ZUnion) - << CI{"ZUNIONSTORE", kStoreMask, -4, 3, 3, acl::kZUnionStore}.HFUNC(ZUnionStore) - - // GEO functions - << CI{"GEOADD", CO::FAST | CO::WRITE | CO::DENYOOM, -5, 1, 1, acl::kGeoAdd}.HFUNC(GeoAdd) - << CI{"GEOHASH", CO::FAST | CO::READONLY, -2, 1, 1, acl::kGeoHash}.HFUNC(GeoHash) - << CI{"GEOPOS", CO::FAST | CO::READONLY, -2, 1, 1, acl::kGeoPos}.HFUNC(GeoPos) - << CI{"GEODIST", CO::READONLY, -4, 1, 1, acl::kGeoDist}.HFUNC(GeoDist) - << CI{"GEOSEARCH", CO::READONLY, -4, 1, 1, acl::kGeoSearch}.HFUNC(GeoSearch) - << CI{"GEORADIUSBYMEMBER", CO::WRITE | CO::STORE_LAST_KEY, -4, 1, 1, acl::kGeoRadiusByMember} - .HFUNC(GeoRadiusByMember); + << CI{"ZUNIONSTORE", kStoreMask, -4, 3, 3, acl::kZUnionStore}.HFUNC(ZUnionStore); } } // namespace dfly diff --git a/src/server/zset_family.h b/src/server/zset_family.h index 17d4eceb24ad..5d30ef0deadb 100644 --- a/src/server/zset_family.h +++ b/src/server/zset_family.h @@ -1,9 +1,10 @@ -// Copyright 2022, DragonflyDB authors. All rights reserved. +// Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once +#include #include #include "facade/op_status.h" @@ -17,6 +18,8 @@ namespace dfly { class CommandRegistry; struct CommandContext; +class Transaction; +class OpArgs; class ZSetFamily { public: @@ -57,10 +60,45 @@ class ZSetFamily { ZRangeSpec(const ScoreInterval& si, const RangeParams& rp) : interval(si), params(rp){}; }; - private: - template using OpResult = facade::OpResult; + struct ZParams { + unsigned flags = 0; // mask of ZADD_IN_ macros. + bool ch = false; // Corresponds to CH option. + bool override = false; + }; + + using ScoredMember = std::pair; + using ScoredArray = std::vector; + using ScoredMemberView = std::pair; + using ScoredMemberSpan = absl::Span; + using SinkReplyBuilder = facade::SinkReplyBuilder; + template using OpResult = facade::OpResult; + + static void ZAddGeneric(std::string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp, + Transaction* tx, SinkReplyBuilder* builder); + + static OpResult ZGetMembers(CmdArgList args, Transaction* tx, + SinkReplyBuilder* builder); + static OpResult> OpRanges(const std::vector& range_specs, + const OpArgs& op_args, std::string_view key); + + struct AddResult { + double new_score = 0; + unsigned num_updated = 0; + + bool is_nan = false; + }; + + static OpResult OpAdd(const OpArgs& op_args, const ZParams& zparams, + std::string_view key, ScoredMemberSpan members); + + static OpResult OpKeyExisted(const OpArgs& op_args, std::string_view key); + + static OpResult OpScore(const OpArgs& op_args, std::string_view key, + std::string_view member); + + private: static void BZPopMin(CmdArgList args, const CommandContext& cmd_cntx); static void BZPopMax(CmdArgList args, const CommandContext& cmd_cntx); static void ZAdd(CmdArgList args, const CommandContext& cmd_cntx); @@ -94,12 +132,6 @@ class ZSetFamily { static void ZScan(CmdArgList args, const CommandContext& cmd_cntx); static void ZUnion(CmdArgList args, const CommandContext& cmd_cntx); static void ZUnionStore(CmdArgList args, const CommandContext& cmd_cntx); - static void GeoAdd(CmdArgList args, const CommandContext& cmd_cntx); - static void GeoHash(CmdArgList args, const CommandContext& cmd_cntx); - static void GeoPos(CmdArgList args, const CommandContext& cmd_cntx); - static void GeoDist(CmdArgList args, const CommandContext& cmd_cntx); - static void GeoSearch(CmdArgList args, const CommandContext& cmd_cntx); - static void GeoRadiusByMember(CmdArgList args, const CommandContext& cmd_cntx); }; } // namespace dfly diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index a08b07242787..cc22c8071800 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -8,6 +8,7 @@ #include "base/logging.h" #include "facade/facade_test.h" #include "server/command_registry.h" +#include "server/geo_family.h" #include "server/test_utils.h" using namespace testing;