Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/commands/cmd_tdigest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,57 @@ class CommandTDigestMerge : public Commander {
TDigestMergeOptions options_;
};

// Implements `TDIGEST.TRIMMED_MEAN key low_cut_quantile high_cut_quantile`.
// Replies with the mean of values between the two cut quantiles, or NaN when
// the digest holds no observations.
class CommandTDigestTrimmedMean : public Commander {
 public:
  // Parses and validates all arguments up front: argument-range checks belong
  // in the parse step, not in command execution.
  Status Parse(const std::vector<std::string> &args) override {
    if (args.size() != 4) {
      return {Status::RedisParseErr, errWrongNumOfArguments};
    }

    key_name_ = args[1];

    auto low_cut_quantile = ParseFloat(args[2]);
    if (!low_cut_quantile) {
      return {Status::RedisParseErr, errValueIsNotFloat};
    }
    low_cut_quantile_ = *low_cut_quantile;

    auto high_cut_quantile = ParseFloat(args[3]);
    if (!high_cut_quantile) {
      return {Status::RedisParseErr, errValueIsNotFloat};
    }
    high_cut_quantile_ = *high_cut_quantile;

    // Reject invalid quantile ranges at parse time so execution never starts
    // with bad arguments. The message strings match the core algorithm's.
    if (low_cut_quantile_ < 0.0 || low_cut_quantile_ > 1.0) {
      return {Status::RedisParseErr, "low cut quantile must be between 0 and 1"};
    }
    if (high_cut_quantile_ < 0.0 || high_cut_quantile_ > 1.0) {
      return {Status::RedisParseErr, "high cut quantile must be between 0 and 1"};
    }
    if (low_cut_quantile_ >= high_cut_quantile_) {
      return {Status::RedisParseErr, "low cut quantile must be less than high cut quantile"};
    }

    return Status::OK();
  }

  // Runs the trimmed-mean query and formats the reply. An empty digest is
  // reported as NaN (result.mean unset); a missing key is an error.
  Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
    TDigest tdigest(srv->storage, conn->GetNamespace());
    TDigestTrimmedMeanResult result;

    auto s = tdigest.TrimmedMean(ctx, key_name_, low_cut_quantile_, high_cut_quantile_, &result);
    if (!s.ok()) {
      if (s.IsNotFound()) {
        return {Status::RedisExecErr, errKeyNotFound};
      }
      return {Status::RedisExecErr, s.ToString()};
    }

    if (!result.mean.has_value()) {
      *output = redis::BulkString(kNan);
    } else {
      *output = redis::BulkString(util::Float2String(*result.mean));
    }

    return Status::OK();
  }

 private:
  std::string key_name_;
  // Brace-initialized so the command object never carries indeterminate values.
  double low_cut_quantile_ = 0.0;
  double high_cut_quantile_ = 0.0;
};

std::vector<CommandKeyRange> GetMergeKeyRange(const std::vector<std::string> &args) {
auto numkeys = ParseInt<int>(args[2], 10).ValueOr(0);
return {{1, 1, 1}, {3, 2 + numkeys, 1}};
Expand All @@ -507,6 +558,7 @@ REDIS_REGISTER_COMMANDS(TDigest, MakeCmdAttr<CommandTDigestCreate>("tdigest.crea
MakeCmdAttr<CommandTDigestByRevRank>("tdigest.byrevrank", -3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandTDigestByRank>("tdigest.byrank", -3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandTDigestQuantile>("tdigest.quantile", -3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandTDigestTrimmedMean>("tdigest.trimmed_mean", 4, "read-only", 1, 1, 1),
MakeCmdAttr<CommandTDigestReset>("tdigest.reset", 2, "write", 1, 1, 1),
MakeCmdAttr<CommandTDigestMerge>("tdigest.merge", -4, "write", GetMergeKeyRange));
} // namespace redis
35 changes: 35 additions & 0 deletions src/types/redis_tdigest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,41 @@ rocksdb::Status TDigest::applyNewCentroids(ObserverOrUniquePtr<rocksdb::WriteBat
return rocksdb::Status::OK();
}

// Computes the mean of the observations lying between low_cut_quantile and
// high_cut_quantile of the named digest.
// On success, *result holds the trimmed mean; result->mean is left unset when
// the digest exists but has no observations (caller replies NaN).
// Invalid quantile arguments surface as rocksdb InvalidArgument.
rocksdb::Status TDigest::TrimmedMean(engine::Context& ctx, const Slice& digest_name, double low_cut_quantile,
                                     double high_cut_quantile, TDigestTrimmedMeanResult* result) {
  auto ns_key = AppendNamespacePrefix(digest_name);
  TDigestMetadata metadata;

  {
    // Hold the key lock only while reading metadata and compacting buffered
    // nodes; the centroid dump below reads the merged snapshot.
    LockGuard guard(storage_->GetLockManager(), ns_key);
    if (auto status = getMetaDataByNsKey(ctx, ns_key, &metadata); !status.ok()) {
      return status;
    }

    // Empty digest: success with result->mean left unset.
    if (metadata.total_observations == 0) {
      return rocksdb::Status::OK();
    }

    // Fold any unmerged buffered values into centroids before querying.
    if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) {
      return status;
    }
  }

  // Dump centroids and create DummyCentroids wrapper for TDigest algorithm
  std::vector<Centroid> centroids;
  if (auto status = dumpCentroids(ctx, ns_key, metadata, &centroids); !status.ok()) {
    return status;
  }
  auto dump_centroids = DummyCentroids<false>(metadata, centroids);
  auto trimmed_mean_result = TDigestTrimmedMean(dump_centroids, low_cut_quantile, high_cut_quantile);
  if (!trimmed_mean_result) {
    // Bad quantile arguments (or an empty snapshot) come back as a Status;
    // translate it into the rocksdb error the command layer expects.
    return rocksdb::Status::InvalidArgument(trimmed_mean_result.Msg());
  }

  result->mean = *trimmed_mean_result;
  return rocksdb::Status::OK();
}

std::string TDigest::internalSegmentGuardPrefixKey(const TDigestMetadata& metadata, const std::string& ns_key,
SegmentType seg) const {
std::string prefix_key;
Expand Down
6 changes: 6 additions & 0 deletions src/types/redis_tdigest.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ struct TDigestQuantitleResult {
std::optional<std::vector<double>> quantiles;
};

// Result of TDIGEST.TRIMMED_MEAN. `mean` is left empty when the digest holds
// no observations; the command layer then replies with NaN.
struct TDigestTrimmedMeanResult {
  std::optional<double> mean;
};

class TDigest : public SubKeyScanner {
public:
using Slice = rocksdb::Slice;
Expand Down Expand Up @@ -85,6 +89,8 @@ class TDigest : public SubKeyScanner {
std::vector<double>* result);
rocksdb::Status ByRank(engine::Context& ctx, const Slice& digest_name, const std::vector<int>& inputs,
std::vector<double>* result);
rocksdb::Status TrimmedMean(engine::Context& ctx, const Slice& digest_name, double low_cut_quantile,
double high_cut_quantile, TDigestTrimmedMeanResult* result);
rocksdb::Status GetMetaData(engine::Context& context, const Slice& digest_name, TDigestMetadata* metadata);

private:
Expand Down
62 changes: 62 additions & 0 deletions src/types/tdigest.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,65 @@ inline Status TDigestRank(TD&& td, const std::vector<double>& inputs, std::vecto
}
return Status::OK();
}

// Computes the mean of the values falling between low_cut_quantile and
// high_cut_quantile of the digest.
//
// Rather than estimating value boundaries with TDigestQuantile and then
// filtering whole centroids by their mean — which can spuriously exclude every
// centroid (returning NaN even though the trimmed range carries weight) and
// needs fragile `== 0.0` / `== 1.0` double comparisons for the edge cases —
// this works directly in weight-rank space. The trimmed region is
// [low_cut_quantile * W, high_cut_quantile * W) where W is the total weight,
// and each centroid contributes the fraction of its weight overlapping that
// region. Centroids are scanned in order, so the loop can stop early once it
// passes the region.
//
// Returns InvalidArgument for an empty digest or out-of-range/misordered
// quantiles; otherwise the trimmed mean (NaN only in the degenerate case of a
// digest whose centroids all carry zero weight).
template <typename TD>
inline StatusOr<double> TDigestTrimmedMean(TD&& td, double low_cut_quantile, double high_cut_quantile) {
  if (td.Size() == 0) {
    return Status{Status::InvalidArgument, "empty tdigest"};
  }

  // The command layer validates these at parse time; keep the guard for other
  // callers of this header.
  if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) {
    return Status{Status::InvalidArgument, "low cut quantile must be between 0 and 1"};
  }
  if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) {
    return Status{Status::InvalidArgument, "high cut quantile must be between 0 and 1"};
  }
  if (low_cut_quantile >= high_cut_quantile) {
    return Status{Status::InvalidArgument, "low cut quantile must be less than high cut quantile"};
  }

  // First pass: total weight of the digest.
  double total_weight = 0;
  for (auto iter = td.Begin(); iter->Valid(); iter->Next()) {
    auto centroid = GET_OR_RET(iter->GetCentroid());
    total_weight += centroid.weight;
  }
  if (total_weight == 0) {
    return std::numeric_limits<double>::quiet_NaN();
  }

  // Trimmed region expressed as weight ranks. q == 0 and q == 1 fall out
  // naturally (rank 0 and rank W), so no special-casing is needed.
  const double low_rank = low_cut_quantile * total_weight;
  const double high_rank = high_cut_quantile * total_weight;

  // Second pass: accumulate each centroid's weight overlap with the region.
  double cumulative_weight = 0;
  double weight_in_range = 0;
  double weighted_sum = 0;
  for (auto iter = td.Begin(); iter->Valid(); iter->Next()) {
    auto centroid = GET_OR_RET(iter->GetCentroid());
    const double start_rank = cumulative_weight;
    const double end_rank = cumulative_weight + centroid.weight;
    cumulative_weight = end_rank;

    if (end_rank <= low_rank) {
      continue;  // entirely below the trimmed region
    }
    if (start_rank >= high_rank) {
      break;  // centroids are ordered; nothing further can overlap
    }

    const double overlap_start = start_rank < low_rank ? low_rank : start_rank;
    const double overlap_end = end_rank > high_rank ? high_rank : end_rank;
    const double overlap = overlap_end - overlap_start;
    if (overlap > 0) {
      weight_in_range += overlap;
      weighted_sum += centroid.mean * overlap;
    }
  }

  // Defensive: with validated quantiles and positive total weight the region
  // always captures some weight, but never divide by zero.
  if (weight_in_range == 0) {
    return std::numeric_limits<double>::quiet_NaN();
  }

  return weighted_sum / weight_in_range;
}
29 changes: 29 additions & 0 deletions tests/cppunit/types/tdigest_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,32 @@ TEST_F(RedisTDigestTest, ByRank_And_ByRevRank) {
EXPECT_EQ(result[0], 1.0) << "Rank 0 should be minimum";
EXPECT_TRUE(std::isinf(result[3])) << "Rank >= total_weight should be infinity";
}

// Covers TDIGEST.TRIMMED_MEAN over a uniform data set: symmetric trims, the
// untrimmed full range, and (per review feedback) rejection of invalid
// quantile arguments.
TEST_F(RedisTDigestTest, TrimmedMean) {
  std::string test_digest_name = "test_digest_trimmed_mean" + std::to_string(util::GetTimeStampMS());
  bool exists = false;
  auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists);
  ASSERT_FALSE(exists);
  ASSERT_TRUE(status.ok());

  std::vector<double> values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  status = tdigest_->Add(*ctx_, test_digest_name, values);
  ASSERT_TRUE(status.ok()) << status.ToString();

  // Symmetric trim of a symmetric data set keeps the mean at 5.5.
  redis::TDigestTrimmedMeanResult result;
  status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.1, 0.9, &result);
  ASSERT_TRUE(status.ok()) << status.ToString();
  ASSERT_TRUE(result.mean.has_value());
  EXPECT_NEAR(*result.mean, 5.5, 1.0) << "Trimmed mean should be approximately 5.5";

  // No trimming at all: exactly the plain mean.
  status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.0, 1.0, &result);
  ASSERT_TRUE(status.ok()) << status.ToString();
  ASSERT_TRUE(result.mean.has_value());
  EXPECT_NEAR(*result.mean, 5.5, 0.1) << "Full range should equal complete mean";

  status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.25, 0.75, &result);
  ASSERT_TRUE(status.ok()) << status.ToString();
  ASSERT_TRUE(result.mean.has_value());
  EXPECT_GT(*result.mean, 3.0) << "Trimmed mean should be greater than 3.0";
  EXPECT_LT(*result.mean, 8.0) << "Trimmed mean should be less than 8.0";

  // Invalid quantile arguments must be rejected, not silently computed.
  status = tdigest_->TrimmedMean(*ctx_, test_digest_name, -0.1, 0.9, &result);
  EXPECT_FALSE(status.ok()) << "negative low cut quantile should be rejected";
  status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.1, 1.1, &result);
  EXPECT_FALSE(status.ok()) << "high cut quantile above 1 should be rejected";
  status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.9, 0.1, &result);
  EXPECT_FALSE(status.ok()) << "low cut quantile >= high cut quantile should be rejected";
  status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.5, 0.5, &result);
  EXPECT_FALSE(status.ok()) << "equal cut quantiles should be rejected";
}
95 changes: 95 additions & 0 deletions tests/gocase/unit/type/tdigest/tdigest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,101 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) {
require.EqualValues(t, expected[i], rank, "REVRANK mismatch at index %d", i)
}
})

t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) {
require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)
})

t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) {
emptyKey := "tdigest_empty"
require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, "compression", "100").Err())

result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", "0.9")
require.NoError(t, result.Err())
require.Equal(t, "nan", result.Val())
})

t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) {
key := "tdigest_basic"
require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err())
require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err())

result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
require.NoError(t, result.Err())
mean, err := strconv.ParseFloat(result.Val().(string), 64)
require.NoError(t, err)
require.InDelta(t, 5.5, mean, 1.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the delta 1.0 too large for this test case?

})

t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) {
key := "tdigest_no_trim"
require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err())
require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err())

result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1")
require.NoError(t, result.Err())
mean, err := strconv.ParseFloat(result.Val().(string), 64)
require.NoError(t, err)
require.InDelta(t, 5.5, mean, 0.1)
})

t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) {
key := "tdigest_skewed"
require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err())
require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1", "1", "1", "1", "10", "100").Err())

result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8")
require.NoError(t, result.Err())
mean, err := strconv.ParseFloat(result.Val().(string), 64)
require.NoError(t, err)
require.Less(t, mean, 50.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we use Less rather than a precise result?

})

t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t *testing.T) {
require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg)
require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "key").Err(), errMsgWrongNumberArg)
require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "key", "0.1").Err(), errMsgWrongNumberArg)
require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "key", "0.1", "0.9", "extra").Err(), errMsgWrongNumberArg)
})

t.Run("TDIGEST.TRIMMED_MEAN invalid quantile ranges", func(t *testing.T) {
key := "tdigest_invalid"
require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err())
require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5").Err())

require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1")
require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1")
require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.9", "0.1").Err(), "low cut quantile must be less than high cut quantile")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error message could be constant string to reduce duplication.

require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.5", "0.5").Err(), "low cut quantile must be less than high cut quantile")
})

t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) {
key := "tdigest_single"
require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err())
require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42").Err())

result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
require.NoError(t, result.Err())
mean, err := strconv.ParseFloat(result.Val().(string), 64)
require.NoError(t, err)
require.InDelta(t, 42.0, mean, 0.001)
Copy link
Member

@LindaSummer LindaSummer Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could use a stable precision for delta in all cases?

})

t.Run("TDIGEST.TRIMMED_MEAN with extreme trimming", func(t *testing.T) {
key := "tdigest_extreme"
require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err())
require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err())

result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.4", "0.6")
require.NoError(t, result.Err())
meanStr := result.Val().(string)
if meanStr == "nan" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result should not be nan.

return
}
mean, err := strconv.ParseFloat(meanStr, 64)
require.NoError(t, err)
require.Greater(t, mean, 0.0)
Comment on lines +808 to +813
Copy link

Copilot AI Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test allows "nan" and returns early, which can mask real correctness issues (a non-empty digest with low_cut < high_cut should always have some weight in the trimmed range). It would be better to assert the result is not NaN for this dataset and verify it’s within an expected numeric range/value.

Suggested change
if meanStr == "nan" {
return
}
mean, err := strconv.ParseFloat(meanStr, 64)
require.NoError(t, err)
require.Greater(t, mean, 0.0)
mean, err := strconv.ParseFloat(meanStr, 64)
require.NoError(t, err)
require.False(t, math.IsNaN(mean))
require.Greater(t, mean, 4.0)
require.Less(t, mean, 7.0)

Copilot uses AI. Check for mistakes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use precise value for test cases for stable and correction.

})
}

func TestTDigestByRankAndByRevRank(t *testing.T) {
Expand Down
Loading