Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion ydb/core/kqp/executer_actor/kqp_data_executer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1967,7 +1967,12 @@ class TKqpDataExecuter : public TKqpExecuterBase<TKqpDataExecuter, EExecType::Da
}
}

size_t sourceScanPartitionsCount = TasksGraph.BuildAllTasks({}, ResourcesSnapshot, Stats.get(), &ShardsWithEffects);
size_t sourceScanPartitionsCount = 0;

if (!graphRestored) {
sourceScanPartitionsCount = TasksGraph.BuildAllTasks({}, ResourcesSnapshot, Stats.get(), &ShardsWithEffects);
}

OnEmptyResult();

TIssue validateIssue;
Expand Down
2 changes: 1 addition & 1 deletion ydb/core/kqp/executer_actor/kqp_executer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ class TKqpExecuterBase : public TActor<TDerived> {

bool RestoreTasksGraph() {
if (Request.QueryPhysicalGraph) {
TasksGraph.RestoreTasksGraphInfo(*Request.QueryPhysicalGraph);
TasksGraph.RestoreTasksGraphInfo(ResourcesSnapshot, *Request.QueryPhysicalGraph);
}

return TasksGraph.GetMeta().IsRestored;
Expand Down
151 changes: 94 additions & 57 deletions ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,49 @@ using namespace NYql;
using namespace NYql::NDq;
using namespace NYql::NNodes;

namespace {

// Per-stage scheduling record produced by ScheduleByCost.
// Kept as a plain aggregate so it can be built with designated initializers.
struct TStageScheduleInfo {
    // Cost of the stage, copied from the stage's GetStageCost().
    double StageCost{0.0};
    // Number of tasks assigned to the stage; stays 0 until scheduling fills it in.
    ui32 TaskCount{0};
};

// Distributes a task budget between the stages of `tx` proportionally to their cost.
//
// Only stages whose first source is an external source, that report a positive
// stage cost and that have no explicit task count (GetTaskCount() == 0) take part.
// Returns a map from stage index to its cost and assigned task count; the map is
// empty when `resourceSnapshot` is empty or no stage qualifies.
std::map<ui32, TStageScheduleInfo> ScheduleByCost(const IKqpGateway::TPhysicalTxData& tx, const TVector<NKikimrKqp::TKqpNodeResources>& resourceSnapshot) {
    std::map<ui32, TStageScheduleInfo> schedule;

    // The snapshot size is used as the node count; without it nothing can be scheduled.
    if (resourceSnapshot.empty()) {
        return schedule;
    }

    // Pass 1: pick the eligible stages and accumulate their total cost.
    double costSum = 0.0;
    for (ui32 idx = 0; idx < tx.Body->StagesSize(); ++idx) {
        auto& stage = tx.Body->GetStages(idx);

        if (stage.SourcesSize() == 0 || stage.GetSources(0).GetTypeCase() != NKqpProto::TKqpSource::kExternalSource) {
            continue;
        }
        if (!(stage.GetStageCost() > 0.0) || stage.GetTaskCount() != 0) {
            continue;
        }

        costSum += stage.GetStageCost();
        schedule.emplace(idx, TStageScheduleInfo{.StageCost = stage.GetStageCost()});
    }

    if (schedule.empty()) {
        return schedule;
    }

    // Pass 2: turn cost shares into task counts.
    // A single stage may take up to 2/3 of the usable threads (rounded up) ...
    ui32 perStageCap = (TStagePredictor::GetUsableThreads() * 2 + 2) / 3;
    // ... while the budget shared by all stages is twice that cap.
    ui32 sharedBudget = perStageCap * 2;

    for (auto& entry : schedule) {
        TStageScheduleInfo& info = entry.second;

        // Share of the budget proportional to the stage cost (truncated towards zero).
        const ui32 proportional = static_cast<ui32>(sharedBudget * info.StageCost / costSum);
        // Clamp to [1, perStageCap]; the lower bound wins if the cap is 0.
        const ui32 clamped = std::max<ui32>(std::min(proportional, perStageCap), 1);

        // Same count for every snapshot entry, so tasks spread evenly between nodes.
        info.TaskCount = clamped * resourceSnapshot.size();
    }

    return schedule;
}

} // anonymous namespace

struct TShardRangesWithShardId {
TMaybe<ui64> ShardId;
const TShardKeyRanges* Ranges;
Expand Down Expand Up @@ -546,6 +589,12 @@ void TKqpTasksGraph::BuildDqSourceStreamLookupChannels(const TStageInfo& stageIn
streamLookupSource.SetProviderName(compiledSource.GetType());
*streamLookupSource.MutableLookupSource() = compiledSource.GetSettings();

TString structuredToken;
const auto& sourceName = compiledSource.GetSourceName();
if (sourceName) {
structuredToken = NYql::CreateStructuredTokenParser(compiledSource.GetAuthInfo()).ToBuilder().ReplaceReferences(GetMeta().SecureParams).ToJson();
}

TTransform dqSourceStreamLookupTransform = {
.Type = "StreamLookupInputTransform",
.InputType = dqSourceStreamLookup.GetInputStageRowType(),
Expand All @@ -554,7 +603,12 @@ void TKqpTasksGraph::BuildDqSourceStreamLookupChannels(const TStageInfo& stageIn
YQL_ENSURE(dqSourceStreamLookupTransform.Settings.PackFrom(*settings));

for (const auto taskId : stageInfo.Tasks) {
GetTask(taskId).Inputs[inputIndex].Transform = dqSourceStreamLookupTransform;
auto& task = GetTask(taskId);
task.Inputs[inputIndex].Transform = dqSourceStreamLookupTransform;

if (structuredToken) {
task.Meta.SecureParams.emplace(sourceName, structuredToken);
}
}

BuildUnionAllChannels(*this, stageInfo, inputIndex, inputStageInfo, outputIndex, /* enableSpilling */ false, logFunc);
Expand Down Expand Up @@ -1505,9 +1559,10 @@ void TKqpTasksGraph::PersistTasksGraphInfo(NKikimrKqp::TQueryPhysicalGraph& resu
}
}

// A restored graph only needs its authentication secrets updated
// and its existing tasks reassigned between the actual nodes.
void TKqpTasksGraph::RestoreTasksGraphInfo(const NKikimrKqp::TQueryPhysicalGraph& graphInfo) {
void TKqpTasksGraph::RestoreTasksGraphInfo(const TVector<NKikimrKqp::TKqpNodeResources>& resourcesSnapshot, const NKikimrKqp::TQueryPhysicalGraph& graphInfo) {
GetMeta().IsRestored = true;
GetMeta().AllowWithSpilling = false;

const auto restoreDqTransform = [](const auto& protoInfo) -> TMaybe<TTransform> {
if (!protoInfo.HasTransform()) {
return Nothing();
Expand Down Expand Up @@ -1543,12 +1598,14 @@ void TKqpTasksGraph::RestoreTasksGraphInfo(const NKikimrKqp::TQueryPhysicalGraph
const auto& task = graphInfo.GetTasks(taskIdx);
const auto txId = task.GetTxId();
const auto& taskInfo = task.GetDqTask();
const NYql::NDq::TStageId stageId(txId, taskInfo.GetStageId());

auto& stageInfo = GetStageInfo({txId, taskInfo.GetStageId()});
auto& stageInfo = GetStageInfo(stageId);
auto& newTask = AddTask(stageInfo, TTaskType::RESTORED);
YQL_ENSURE(taskInfo.GetId() == newTask.Id);
newTask.Meta.TaskParams.insert(taskInfo.GetTaskParams().begin(), taskInfo.GetTaskParams().end());
newTask.Meta.ReadRanges.assign(taskInfo.GetReadRanges().begin(), taskInfo.GetReadRanges().end());
newTask.Meta.Type = TTaskMeta::TTaskType::Compute;

for (size_t inputIdx = 0; inputIdx < taskInfo.InputsSize(); ++inputIdx) {
const auto& inputInfo = taskInfo.GetInputs(inputIdx);
Expand Down Expand Up @@ -1663,6 +1720,23 @@ void TKqpTasksGraph::RestoreTasksGraphInfo(const NKikimrKqp::TQueryPhysicalGraph
restoreDqChannel(txId, channelInfo).SrcOutputIndex = outputIdx;
}
}

const auto& stage = stageInfo.Meta.GetStage(stageId);
FillSecureParamsFromStage(newTask.Meta.SecureParams, stage);
BuildSinks(stage, stageInfo, newTask);

for (const auto& input : stage.GetInputs()) {
if (input.GetTypeCase() != NKqpProto::TKqpPhyConnection::kDqSourceStreamLookup) {
continue;
}

if (const auto& compiledSource = input.GetDqSourceStreamLookup().GetLookupSource(); const auto& sourceName = compiledSource.GetSourceName()) {
newTask.Meta.SecureParams.emplace(
sourceName,
NYql::CreateStructuredTokenParser(compiledSource.GetAuthInfo()).ToBuilder().ReplaceReferences(GetMeta().SecureParams).ToJson()
);
}
}
}

for (const auto& [id, channel] : channels) {
Expand All @@ -1671,7 +1745,20 @@ void TKqpTasksGraph::RestoreTasksGraphInfo(const NKikimrKqp::TQueryPhysicalGraph
YQL_ENSURE(id == newChannel.Id);
}

GetMeta().IsRestored = true;
for (ui64 txIdx = 0; txIdx < Transactions.size(); ++txIdx) {
const auto& tx = Transactions.at(txIdx);
const auto scheduledTaskCount = ScheduleByCost(tx, resourcesSnapshot);

for (ui64 stageIdx = 0; stageIdx < tx.Body->StagesSize(); ++stageIdx) {
const auto& stage = tx.Body->GetStages(stageIdx);
auto& stageInfo = GetStageInfo({txIdx, stageIdx});

if (const auto& sources = stage.GetSources(); !sources.empty() && sources[0].GetTypeCase() == NKqpProto::TKqpSource::kExternalSource) {
const auto it = scheduledTaskCount.find(stageIdx);
BuildReadTasksFromSource(stageInfo, resourcesSnapshot, it != scheduledTaskCount.end() ? it->second.TaskCount : 0);
}
}
}
}

void TKqpTasksGraph::BuildSysViewScanTasks(TStageInfo& stageInfo) {
Expand Down Expand Up @@ -1745,15 +1832,6 @@ std::pair<ui32, TKqpTasksGraph::TTaskType::ECreateReason> TKqpTasksGraph::GetMax
bool TKqpTasksGraph::BuildComputeTasks(TStageInfo& stageInfo, const ui32 nodesCount) {
auto& stage = stageInfo.Meta.GetStage(stageInfo.Id);

// TODO: move outside
if (GetMeta().IsRestored) {
for (const auto taskId : stageInfo.Tasks) {
auto& task = GetTask(taskId);
task.Meta.Type = TTaskMeta::TTaskType::Compute;
}
return false;
}

TTaskType::ECreateReason tasksReason = TTaskType::MINIMUM_COMPUTE;
bool unknownAffectedShardCount = false;
ui32 partitionsCount = 1;
Expand Down Expand Up @@ -2715,45 +2793,6 @@ TMaybe<size_t> TKqpTasksGraph::BuildScanTasksFromSource(TStageInfo& stageInfo, b
}
}

// Scheduling data tracked for a single stage by ScheduleByCost.
struct TStageScheduleInfo {
    double StageCost{0.0}; // stage cost as reported by GetStageCost()
    ui32 TaskCount{0};     // assigned task count; 0 means not assigned yet
};

// Computes cost-proportional task counts for the external-source stages of `tx`.
//
// A stage is considered when its first source is an external source, its stage
// cost is positive and it has no explicit task count. The result maps the stage
// index to its cost and task count. It is empty if `resourceSnapshot` is empty
// or no stage is considered.
static std::map<ui32, TStageScheduleInfo> ScheduleByCost(const IKqpGateway::TPhysicalTxData& tx, const TVector<NKikimrKqp::TKqpNodeResources>& resourceSnapshot) {
    std::map<ui32, TStageScheduleInfo> perStage;
    const bool haveNodes = !resourceSnapshot.empty(); // can't schedule without a node count

    if (haveNodes) {
        // Gather costs of the stages that take part in scheduling.
        double accumulatedCost = 0.0;
        const ui32 stageCount = tx.Body->StagesSize();
        for (ui32 i = 0; i < stageCount; ++i) {
            auto& candidate = tx.Body->GetStages(i);
            const bool readsExternalSource =
                candidate.SourcesSize() > 0
                && candidate.GetSources(0).GetTypeCase() == NKqpProto::TKqpSource::kExternalSource;
            if (readsExternalSource && candidate.GetStageCost() > 0.0 && candidate.GetTaskCount() == 0) {
                accumulatedCost += candidate.GetStageCost();
                perStage.emplace(i, TStageScheduleInfo{.StageCost = candidate.GetStageCost()});
            }
        }

        // Convert the collected costs into task counts.
        if (!perStage.empty()) {
            // One stage may use at most 2/3 of the usable threads (rounded up).
            ui32 stageLimit = (TStagePredictor::GetUsableThreads() * 2 + 2) / 3;
            // The total shared between stages is twice the single-stage limit.
            ui32 totalLimit = stageLimit * 2;
            for (auto it = perStage.begin(); it != perStage.end(); ++it) {
                auto& slot = it->second;
                const ui32 byCost = static_cast<ui32>(totalLimit * slot.StageCost / accumulatedCost);
                // At least one task, at most stageLimit, repeated for every snapshot entry.
                slot.TaskCount = std::max<ui32>(std::min(byCost, stageLimit), 1) * resourceSnapshot.size();
            }
        }
    }

    return perStage;
}

void TKqpTasksGraph::FillSecureParamsFromStage(THashMap<TString, TString>& secureParams, const NKqpProto::TKqpPhyStage& stage) const {
for (const auto& [secretName, authInfo] : stage.GetSecureParams()) {
const auto& structuredToken = NYql::CreateStructuredTokenParser(authInfo).ToBuilder().ReplaceReferences(GetMeta().SecureParams).ToJson();
Expand Down Expand Up @@ -2943,9 +2982,7 @@ size_t TKqpTasksGraph::BuildAllTasks(std::optional<TLlvmSettings> llvmSettings,

// Not task-related
GetMeta().AllowWithSpilling |= stage.GetAllowWithSpilling();
if (!GetMeta().IsRestored) {
BuildKqpStageChannels(stageInfo, GetMeta().TxId, GetMeta().AllowWithSpilling, tx.Body->EnableShuffleElimination());
}
BuildKqpStageChannels(stageInfo, GetMeta().TxId, GetMeta().AllowWithSpilling, tx.Body->EnableShuffleElimination());
}

// Not task-related
Expand Down
2 changes: 1 addition & 1 deletion ydb/core/kqp/executer_actor/kqp_tasks_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ class TKqpTasksGraph : public NYql::NDq::TDqTasksGraph<TGraphMeta, TStageInfoMet

NYql::NDqProto::TDqTask* ArenaSerializeTaskToProto(const TTask& task, bool serializeAsyncIoSettings);
void PersistTasksGraphInfo(NKikimrKqp::TQueryPhysicalGraph& result) const;
void RestoreTasksGraphInfo(const NKikimrKqp::TQueryPhysicalGraph& graphInfo);
void RestoreTasksGraphInfo(const TVector<NKikimrKqp::TKqpNodeResources>& resourcesSnapshot, const NKikimrKqp::TQueryPhysicalGraph& graphInfo);

// TODO: public used by TKqpPlanner
void FillChannelDesc(NYql::NDqProto::TChannel& channelDesc, const NYql::NDq::TChannel& channel,
Expand Down
13 changes: 9 additions & 4 deletions ydb/core/kqp/query_compiler/kqp_query_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,21 +471,21 @@ void FillOlapProgram(const T& node, const NKikimr::NMiniKQL::TType* miniKqlResul
CompileOlapProgram(node.Process(), tableMeta, readProto, resultColNames, ctx);
}

THashMap<TString, TString> FindSecureParams(const TExprNode::TPtr& node, const TTypeAnnotationContext& typesCtx, TSet<TString>& SecretNames) {
THashMap<TString, TString> FindSecureParams(const TExprNode::TPtr& node, const TTypeAnnotationContext& typesCtx, TSet<TString>& secretNames) {
THashMap<TString, TString> secureParams;
NYql::NCommon::FillSecureParams(node, typesCtx, secureParams);

for (auto& [secretName, structuredToken] : secureParams) {
const auto& tokenParser = CreateStructuredTokenParser(structuredToken);
tokenParser.ListReferences(SecretNames);
tokenParser.ListReferences(secretNames);
structuredToken = tokenParser.ToBuilder().RemoveSecrets().ToJson();
}

return secureParams;
}

std::optional<std::pair<TString, TString>> FindOneSecureParam(const TExprNode::TPtr& node, const TTypeAnnotationContext& typesCtx, const TString& nodeName, TSet<TString>& SecretNames) {
const auto& secureParams = FindSecureParams(node, typesCtx, SecretNames);
std::optional<std::pair<TString, TString>> FindOneSecureParam(const TExprNode::TPtr& node, const TTypeAnnotationContext& typesCtx, const TString& nodeName, TSet<TString>& secretNames) {
const auto& secureParams = FindSecureParams(node, typesCtx, secretNames);
if (secureParams.empty()) {
return std::nullopt;
}
Expand Down Expand Up @@ -1849,6 +1849,11 @@ class TKqpQueryCompiler : public IKqpQueryCompiler {
YQL_ENSURE(!lookupSourceSettings.type_url().empty(), "Data source provider \"" << dataSourceCategory << "\" did't fill dq source settings for its dq source node");
YQL_ENSURE(lookupSourceType, "Data source provider \"" << dataSourceCategory << "\" did't fill dq source settings type for its dq source node");

if (const auto& secureParams = FindOneSecureParam(lookupSourceWrap.Ptr(), TypesCtx, "streamLookupSource", SecretNames)) {
lookupSource.SetSourceName(secureParams->first);
lookupSource.SetAuthInfo(secureParams->second);
}

const auto& streamLookupOutput = streamLookup.Output();
const auto connectionInputRowType = GetSeqItemType(streamLookupOutput.Ref().GetTypeAnn());
YQL_ENSURE(connectionInputRowType->GetKind() == ETypeAnnotationKind::Struct);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ class TStreamingTestFixture : public NUnitTest::TBaseFixture {

auto listSplitsBuilder = mockClient->ExpectListSplits();
auto fillListSplitExpectation = listSplitsBuilder
.ValidateArgs(settings.ValidateListSplitsArgs)
.ValidateArgs(settings.ValidateListSplitsArgs ? TConnectorClientMock::EArgsValidation::Strict : TConnectorClientMock::EArgsValidation::DataSourceInstance)
.Select()
.DataSourceInstance(GetMockConnectorSourceInstance())
.Table(settings.TableName)
Expand Down Expand Up @@ -689,7 +689,7 @@ class TStreamingTestFixture : public NUnitTest::TBaseFixture {
{
auto columnsBuilder = readSplitsBuilder
.Filtering(TReadSplitsRequest::FILTERING_OPTIONAL)
.ValidateArgs(settings.ValidateReadSplitsArgs)
.ValidateArgs(settings.ValidateReadSplitsArgs ? TConnectorClientMock::EArgsValidation::Strict : TConnectorClientMock::EArgsValidation::DataSourceInstance)
.Split()
.Description("some binary description")
.Select()
Expand Down
Loading
Loading