
[Enhancement] Reduce the memory usage of TableStatistic #50316


Merged 1 commit on Aug 28, 2024
@@ -52,7 +52,7 @@ public class CachedStatisticStorage implements StatisticStorage {
private final Executor statsCacheRefresherExecutor = Executors.newFixedThreadPool(Config.statistic_cache_thread_pool_size,
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("stats-cache-refresher-%d").build());

- AsyncLoadingCache<TableStatsCacheKey, Optional<TableStatistic>> tableStatsCache = Caffeine.newBuilder()
+ AsyncLoadingCache<TableStatsCacheKey, Optional<Long>> tableStatsCache = Caffeine.newBuilder()
.expireAfterWrite(Config.statistic_update_interval_sec * 2, TimeUnit.SECONDS)
.refreshAfterWrite(Config.statistic_update_interval_sec, TimeUnit.SECONDS)
.maximumSize(Config.statistic_cache_columns)
@@ -88,30 +88,29 @@ public class CachedStatisticStorage implements StatisticStorage {
.buildAsync(new ConnectorHistogramColumnStatsCacheLoader());

@Override
- public Map<Long, TableStatistic> getTableStatistics(Long tableId, Collection<Partition> partitions) {
+ public Map<Long, Optional<Long>> getTableStatistics(Long tableId, Collection<Partition> partitions) {
// get Statistics Table column info, just return default column statistics
if (StatisticUtils.statisticTableBlackListCheck(tableId)) {
- return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> TableStatistic.unknown()));
+ return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> Optional.empty()));
}

List<TableStatsCacheKey> keys = partitions.stream().map(p -> new TableStatsCacheKey(tableId, p.getId()))
.collect(Collectors.toList());

try {
- CompletableFuture<Map<TableStatsCacheKey, Optional<TableStatistic>>> result = tableStatsCache.getAll(keys);
+ CompletableFuture<Map<TableStatsCacheKey, Optional<Long>>> result = tableStatsCache.getAll(keys);
if (result.isDone()) {
- Map<TableStatsCacheKey, Optional<TableStatistic>> data = result.get();
- return keys.stream().collect(Collectors.toMap(
- TableStatsCacheKey::getPartitionId,
- k -> data.getOrDefault(k, Optional.empty()).orElse(TableStatistic.unknown())));
+ Map<TableStatsCacheKey, Optional<Long>> data = result.get();
+ return keys.stream().collect(Collectors.toMap(TableStatsCacheKey::getPartitionId,
+ k -> data.getOrDefault(k, Optional.empty())));
}
} catch (InterruptedException e) {
LOG.warn("Failed to execute tableStatsCache.getAll", e);
Thread.currentThread().interrupt();
} catch (Exception e) {
LOG.warn("Faied to execute tableStatsCache.getAll", e);
}
- return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> TableStatistic.unknown()));
+ return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> Optional.empty()));
}

@Override
@@ -122,7 +121,7 @@ public void refreshTableStatistic(Table table) {
}

try {
- CompletableFuture<Map<TableStatsCacheKey, Optional<TableStatistic>>> completableFuture
+ CompletableFuture<Map<TableStatsCacheKey, Optional<Long>>> completableFuture
= tableStatsCache.getAll(statsCacheKeyList);
if (completableFuture.isDone()) {
completableFuture.get();
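This is the heart of the memory reduction: the per-partition cache value shrinks from a full `TableStatistic` object (tableId, partitionId, rowCount) to just an `Optional<Long>` row count, with `Optional.empty()` taking over the role of the `TableStatistic.unknown()` sentinel. A minimal sketch of the resulting cache shape, using illustrative sizes and a stub loader (not the PR's `TableStatsCacheLoader`):

```java
import com.github.benmanes.caffeine.cache.AsyncLoadingCache;
import com.github.benmanes.caffeine.cache.Caffeine;

import java.util.Optional;
import java.util.concurrent.TimeUnit;

public class RowCountCacheSketch {
    public static void main(String[] args) throws Exception {
        // Keyed by partition id alone for brevity; the real cache keys on
        // TableStatsCacheKey (tableId + partitionId).
        AsyncLoadingCache<Long, Optional<Long>> rowCounts = Caffeine.newBuilder()
                .expireAfterWrite(120, TimeUnit.SECONDS) // illustrative TTL
                .maximumSize(100_000)                    // illustrative bound
                // Stub loader: Optional.empty() now means "no stats collected",
                // replacing the old TableStatistic.unknown() sentinel object.
                .buildAsync(partitionId -> Optional.empty());

        Optional<Long> rows = rowCounts.get(42L).get();
        System.out.println(rows.orElse(0L)); // prints 0: nothing collected yet
    }
}
```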
@@ -23,12 +23,13 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
+ import java.util.Optional;
import java.util.stream.Collectors;

public interface StatisticStorage {
- // partitionId: TableStatistic
- default Map<Long, TableStatistic> getTableStatistics(Long tableId, Collection<Partition> partitions) {
- return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> TableStatistic.unknown()));
+ // partitionId: RowCount
+ default Map<Long, Optional<Long>> getTableStatistics(Long tableId, Collection<Partition> partitions) {
+ return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> Optional.empty()));
}

default void refreshTableStatistic(Table table) {
@@ -34,6 +34,7 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
+ import java.util.Optional;
import java.util.stream.Collectors;

public class StatisticsCalcUtils {
@@ -121,20 +122,20 @@ public static long getTableRowCount(Table table, Operator node, OptimizerContext
// For example, a large amount of data LOAD may cause the number of rows to change greatly.
// This leads to very inaccurate row counts.
long deltaRows = deltaRows(table, basicStatsMeta.getUpdateRows());
- Map<Long, TableStatistic> tableStatisticMap = GlobalStateMgr.getCurrentState().getStatisticStorage()
+ Map<Long, Optional<Long>> tableStatisticMap = GlobalStateMgr.getCurrentState().getStatisticStorage()
.getTableStatistics(table.getId(), selectedPartitions);
for (Partition partition : selectedPartitions) {
long partitionRowCount;
- TableStatistic tableStatistic =
- tableStatisticMap.getOrDefault(partition.getId(), TableStatistic.unknown());
+ Optional<Long> tableStatistic =
+ tableStatisticMap.getOrDefault(partition.getId(), Optional.empty());
LocalDateTime updateDatetime = StatisticUtils.getPartitionLastUpdateTime(partition);
- if (tableStatistic.equals(TableStatistic.unknown())) {
+ if (tableStatistic.isEmpty()) {
partitionRowCount = partition.getRowCount();
if (updateDatetime.isAfter(lastWorkTimestamp)) {
partitionRowCount += deltaRows;
}
} else {
- partitionRowCount = tableStatistic.getRowCount();
+ partitionRowCount = tableStatistic.get();
if (updateDatetime.isAfter(basicStatsMeta.getUpdateTime())) {
partitionRowCount += deltaRows;
}
@@ -184,17 +185,13 @@ private static void updateQueryDumpInfo(OptimizerContext optimizerContext, Table

private static long deltaRows(Table table, long totalRowCount) {
long tblRowCount = 0L;
- Map<Long, TableStatistic> tableStatisticMap = GlobalStateMgr.getCurrentState().getStatisticStorage()
+ Map<Long, Optional<Long>> tableStatisticMap = GlobalStateMgr.getCurrentState().getStatisticStorage()
.getTableStatistics(table.getId(), table.getPartitions());

for (Partition partition : table.getPartitions()) {
long partitionRowCount;
- TableStatistic statistic = tableStatisticMap.getOrDefault(partition.getId(), TableStatistic.unknown());
- if (statistic.equals(TableStatistic.unknown())) {
- partitionRowCount = partition.getRowCount();
- } else {
- partitionRowCount = statistic.getRowCount();
- }
+ Optional<Long> statistic = tableStatisticMap.getOrDefault(partition.getId(), Optional.empty());
+ partitionRowCount = statistic.orElseGet(partition::getRowCount);
tblRowCount += partitionRowCount;
}
if (tblRowCount < totalRowCount) {
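The rewritten `deltaRows` loop shows the payoff at the call sites: the old if/else comparison against `TableStatistic.unknown()` collapses into a single `orElseGet`, falling back to the partition's own row count when no statistic is cached. A self-contained sketch of that pattern (the `Partition` stand-in here is hypothetical, not the StarRocks class):

```java
import java.util.Map;
import java.util.Optional;

public class RowCountFallbackSketch {
    // Hypothetical stand-in for com.starrocks.catalog.Partition.
    static final class Partition {
        final long id;
        final long ownRowCount;
        Partition(long id, long ownRowCount) { this.id = id; this.ownRowCount = ownRowCount; }
        long getRowCount() { return ownRowCount; }
    }

    public static void main(String[] args) {
        Partition p = new Partition(1L, 500L);
        // Statistics map as returned by getTableStatistics: empty = unknown.
        Map<Long, Optional<Long>> stats = Map.of(1L, Optional.empty());

        // One expression replaces the old branch on the unknown() sentinel.
        long rows = stats.getOrDefault(p.id, Optional.empty())
                .orElseGet(p::getRowCount);
        System.out.println(rows); // 500: fell back to the partition itself
    }
}
```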

This file was deleted.

@@ -15,6 +15,8 @@
package com.starrocks.sql.optimizer.statistics;

import com.github.benmanes.caffeine.cache.AsyncCacheLoader;
+ import com.google.api.client.util.Lists;
+ import com.starrocks.common.Config;
import com.starrocks.qe.ConnectContext;
import com.starrocks.statistic.StatisticExecutor;
import com.starrocks.statistic.StatisticUtils;
@@ -29,22 +31,22 @@
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;

- public class TableStatsCacheLoader implements AsyncCacheLoader<TableStatsCacheKey, Optional<TableStatistic>> {
+ public class TableStatsCacheLoader implements AsyncCacheLoader<TableStatsCacheKey, Optional<Long>> {
private final StatisticExecutor statisticExecutor = new StatisticExecutor();

@Override
- public @NonNull CompletableFuture<Optional<TableStatistic>> asyncLoad(@NonNull TableStatsCacheKey cacheKey, @
+ public @NonNull CompletableFuture<Optional<Long>> asyncLoad(@NonNull TableStatsCacheKey cacheKey, @
NonNull Executor executor) {
return CompletableFuture.supplyAsync(() -> {
try {
ConnectContext connectContext = StatisticUtils.buildConnectContext();
connectContext.setThreadLocalInfo();
- List<TStatisticData> statisticData = queryStatisticsData(connectContext, cacheKey.tableId, cacheKey.partitionId);
- if (statisticData.size() == 0) {
- return Optional.of(new TableStatistic(cacheKey.getTableId(), cacheKey.getPartitionId(), 0L));
+ List<TStatisticData> statisticData =
+ statisticExecutor.queryTableStats(connectContext, cacheKey.tableId, cacheKey.partitionId);
+ if (statisticData.isEmpty()) {
+ return Optional.empty();
} else {
- return Optional.of(new TableStatistic(cacheKey.getTableId(), cacheKey.getPartitionId(),
- statisticData.get(0).rowCount));
+ return Optional.of(statisticData.get(0).rowCount);
}
} catch (RuntimeException e) {
throw e;
@@ -57,31 +59,40 @@ public class TableStatsCacheLoader implements AsyncCacheLoader<TableStatsCacheKe
}

@Override
- public @NonNull CompletableFuture<Map<@NonNull TableStatsCacheKey, @NonNull Optional<TableStatistic>>> asyncLoadAll(
+ public @NonNull CompletableFuture<Map<@NonNull TableStatsCacheKey, @NonNull Optional<Long>>> asyncLoadAll(
@NonNull Iterable<? extends @NonNull TableStatsCacheKey> cacheKey,
@NonNull Executor executor) {
return CompletableFuture.supplyAsync(() -> {
try {
- TableStatsCacheKey tableStatsCacheKey = cacheKey.iterator().next();
- long tableId = tableStatsCacheKey.getTableId();
-
ConnectContext connectContext = StatisticUtils.buildConnectContext();
connectContext.setThreadLocalInfo();
- List<TStatisticData> statisticData = queryStatisticsData(connectContext, tableStatsCacheKey.getTableId());

- Map<TableStatsCacheKey, Optional<TableStatistic>> result = new HashMap<>();
- for (TStatisticData tStatisticData : statisticData) {
- result.put(new TableStatsCacheKey(tableId, tStatisticData.partitionId),
- Optional.of(new TableStatistic(tableId, tStatisticData.partitionId, tStatisticData.rowCount)));
+ Map<TableStatsCacheKey, Optional<Long>> result = new HashMap<>();
+ List<Long> pids = Lists.newArrayList();
+ long tableId = -1;
+ for (TableStatsCacheKey statsCacheKey : cacheKey) {
+ pids.add(statsCacheKey.getPartitionId());
+ tableId = statsCacheKey.getTableId();
+ if (pids.size() > Config.expr_children_limit / 2) {
+ List<TStatisticData> statisticData =
+ statisticExecutor.queryTableStats(connectContext, statsCacheKey.getTableId(), pids);
+
+ statisticData.forEach(tStatisticData -> result.put(
+ new TableStatsCacheKey(statsCacheKey.getTableId(), tStatisticData.partitionId),
+ Optional.of(tStatisticData.rowCount)));
+ pids.clear();
+ }
+ }
+ List<TStatisticData> statisticData = statisticExecutor.queryTableStats(connectContext, tableId, pids);
+ for (TStatisticData data : statisticData) {
+ result.put(new TableStatsCacheKey(tableId, data.partitionId), Optional.of(data.rowCount));
+ }
+ for (TableStatsCacheKey key : cacheKey) {
+ if (!result.containsKey(key)) {
+ result.put(key, Optional.empty());
+ }
+ }

return result;

} catch (RuntimeException e) {
throw e;
} catch (Exception e) {
@@ -91,12 +102,4 @@ public class TableStatsCacheLoader implements AsyncCacheLoader<TableStatsCacheKe
}
}, executor);
}

- private List<TStatisticData> queryStatisticsData(ConnectContext context, long tableId) {
- return statisticExecutor.queryTableStats(context, tableId);
- }
-
- private List<TStatisticData> queryStatisticsData(ConnectContext context, long tableId, long partitionId) {
- return statisticExecutor.queryTableStats(context, tableId, partitionId);
- }
}
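Besides the type change, `asyncLoadAll` now batches its statistics queries: it accumulates partition ids and flushes a query whenever the pending list would exceed `Config.expr_children_limit / 2` entries, so the generated `IN (...)` predicate never outgrows the expression-children limit; a final query covers the remainder, and any key still missing afterwards maps to `Optional.empty()`. A generic sketch of that chunk-and-flush pattern, with an assumed `queryRowCounts` callback standing in for `statisticExecutor.queryTableStats`:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class ChunkedLookupSketch {
    // Looks up row counts in bounded batches so no single query grows too large.
    static Map<Long, Long> lookupInChunks(Iterable<Long> partitionIds, int limit,
            Function<List<Long>, Map<Long, Long>> queryRowCounts) {
        Map<Long, Long> result = new HashMap<>();
        List<Long> batch = new ArrayList<>();
        for (Long id : partitionIds) {
            batch.add(id);
            if (batch.size() > limit) { // mirrors pids.size() > expr_children_limit / 2
                result.putAll(queryRowCounts.apply(batch));
                batch.clear();
            }
        }
        if (!batch.isEmpty()) { // flush the final partial batch
            result.putAll(queryRowCounts.apply(batch));
        }
        return result;
    }

    public static void main(String[] args) {
        Map<Long, Long> rows = lookupInChunks(List.of(1L, 2L, 3L, 4L, 5L), 2,
                batch -> { // stub query: pretend every partition has 100 rows
                    Map<Long, Long> m = new HashMap<>();
                    batch.forEach(id -> m.put(id, 100L));
                    return m;
                });
        System.out.println(rows); // {1=100, 2=100, 3=100, 4=100, 5=100}
    }
}
```

One subtlety: in the diff itself, the trailing `queryTableStats(connectContext, tableId, pids)` call can run with an empty `pids` list when the loop flushed exactly at the boundary, and an empty list makes the SQL builder fall back to the whole-table predicate; the sketch above guards that final flush instead.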
@@ -22,7 +22,6 @@
import com.starrocks.common.io.Writable;
import com.starrocks.persist.gson.GsonUtils;
import com.starrocks.server.GlobalStateMgr;
- import com.starrocks.sql.optimizer.statistics.TableStatistic;
import org.apache.commons.collections4.MapUtils;

import java.io.DataInput;
Expand All @@ -32,6 +31,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
+ import java.util.Optional;

public class BasicStatsMeta implements Writable {
@SerializedName("dbId")
@@ -132,15 +132,13 @@ public double getHealthy() {
long updatePartitionRowCount = 0L;
long updatePartitionCount = 0L;

- Map<Long, TableStatistic> tableStatistics = GlobalStateMgr.getCurrentState().getStatisticStorage()
+ Map<Long, Optional<Long>> tableStatistics = GlobalStateMgr.getCurrentState().getStatisticStorage()
.getTableStatistics(table.getId(), table.getPartitions());

for (Partition partition : table.getPartitions()) {
tableRowCount += partition.getRowCount();
- TableStatistic statistic = tableStatistics.getOrDefault(partition.getId(), TableStatistic.unknown());
- if (!statistic.equals(TableStatistic.unknown())) {
- cachedTableRowCount += statistic.getRowCount();
- }
+ Optional<Long> statistic = tableStatistics.getOrDefault(partition.getId(), Optional.empty());
+ cachedTableRowCount += statistic.orElse(0L);
LocalDateTime loadTime = StatisticUtils.getPartitionLastUpdateTime(partition);

if (partition.hasData() && !isUpdatedAfterLoad(loadTime)) {
@@ -220,8 +220,8 @@ public static Pair<List<TStatisticData>, Status> queryDictSync(Long dbId, Long t
}
}

- public List<TStatisticData> queryTableStats(ConnectContext context, Long tableId) {
- String sql = StatisticSQLBuilder.buildQueryTableStatisticsSQL(tableId);
+ public List<TStatisticData> queryTableStats(ConnectContext context, Long tableId, List<Long> partitions) {
+ String sql = StatisticSQLBuilder.buildQueryTableStatisticsSQL(tableId, partitions);
return executeStatisticDQL(context, sql);
}

Expand All @@ -233,7 +233,7 @@ public List<TStatisticData> queryTableStats(ConnectContext context, Long tableId
private static List<TStatisticData> deserializerStatisticData(List<TResultBatch> sqlResult) throws TException {
List<TStatisticData> statistics = Lists.newArrayList();

- if (sqlResult.size() < 1) {
+ if (sqlResult.isEmpty()) {
return statistics;
}

@@ -87,9 +87,13 @@ public class StatisticSQLBuilder {
DEFAULT_VELOCITY_ENGINE.setProperty(VelocityEngine.RUNTIME_LOG_REFERENCE_LOG_INVALID, false);
}

- public static String buildQueryTableStatisticsSQL(Long tableId) {
+ public static String buildQueryTableStatisticsSQL(Long tableId, List<Long> partitionIds) {
VelocityContext context = new VelocityContext();
context.put("predicate", "table_id = " + tableId);
+ if (!partitionIds.isEmpty()) {
+ context.put("predicate", "table_id = " + tableId + " and partition_id in (" +
+ partitionIds.stream().map(String::valueOf).collect(Collectors.joining(", ")) + ")");
+ }
return build(context, QUERY_TABLE_STATISTIC_TEMPLATE);
}
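With the partition list threaded through `queryTableStats`, the builder can narrow the statistics scan from the whole table to exactly the requested partitions. A sketch of the predicate string this logic assembles (the ids are illustrative, and the surrounding SELECT template is not part of this diff):

```java
import java.util.List;
import java.util.stream.Collectors;

public class PredicateSketch {
    public static void main(String[] args) {
        long tableId = 10001L;                      // illustrative table id
        List<Long> partitionIds = List.of(1L, 2L, 3L);

        String predicate = "table_id = " + tableId; // default: whole table
        if (!partitionIds.isEmpty()) {              // narrow to requested partitions
            predicate = "table_id = " + tableId + " and partition_id in ("
                    + partitionIds.stream().map(String::valueOf)
                            .collect(Collectors.joining(", ")) + ")";
        }
        System.out.println(predicate);
        // table_id = 10001 and partition_id in (1, 2, 3)
    }
}
```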
