Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enh/ts 5255/cols func2 #29354

Open
wants to merge 17 commits into
base: 3.0
Choose a base branch
from
Prev Previous commit
Next Next commit
cols function
  • Loading branch information
facetosea committed Dec 25, 2024
commit e26a3567bc0c941f71e656ddf92bd6a874558f67
26 changes: 21 additions & 5 deletions source/libs/parser/src/parTranslater.c
Original file line number Diff line number Diff line change
Expand Up @@ -7438,6 +7438,18 @@ static bool invalidColsAlias(SFunctionNode* pFunc) {
return false;
}

static int32_t getSelectFuncIndex(SNodeList* FuncNodeList, SNode* pSelectFunc) {
SNode* pNode = NULL;
int32_t selectFuncIndex = 0;
FOREACH(pNode, FuncNodeList) {
++selectFuncIndex;
if (nodesEqualNode(pNode, pSelectFunc)) {
return selectFuncIndex;
}
}
return 0;
}

static int32_t rewriteColsFunction(STranslateContext* pCxt, SNodeList** nodeList) {
int32_t code = hasInvalidColsFunction(pCxt, *nodeList);
if (TSDB_CODE_SUCCESS != code) {
Expand Down Expand Up @@ -7469,26 +7481,30 @@ static int32_t rewriteColsFunction(STranslateContext* pCxt, SNodeList** nodeList
}
SNode* pNewNode = NULL;
int32_t nums = 0;
int32_t selectFuncNum = 0;
int32_t selectFuncCount = 0;
FOREACH(pTmpNode, *nodeList) {
if (QUERY_NODE_FUNCTION == nodeType(pTmpNode)) {
SFunctionNode* pFunc = (SFunctionNode*)pTmpNode;
if (strcasecmp(pFunc->functionName, "cols") == 0) {
++selectFuncNum;
SNode* pSelectFunc = nodesListGetNode(pFunc->pParameterList, 0);
if(nodeType(pSelectFunc) != QUERY_NODE_FUNCTION) {
if (nodeType(pSelectFunc) != QUERY_NODE_FUNCTION) {
code = TSDB_CODE_PAR_INVALID_COLS_FUNCTION;
parserError("Invalid cols function, the first parameter must be a select function");
goto _end;
}
nodesListMakeStrictAppend(&tmpFuncNodeList, pSelectFunc);
int32_t selectFuncIndex = getSelectFuncIndex(tmpFuncNodeList, pSelectFunc);
if (selectFuncIndex == 0) {
++selectFuncCount;
selectFuncIndex = selectFuncCount;
nodesListMakeStrictAppend(&tmpFuncNodeList, pSelectFunc);
}
// start from index 1, because the first parameter is select function which needn't to output.
for (int i = 1; i < pFunc->pParameterList->length; ++i) {
SNode* pExpr = nodesListGetNode(pFunc->pParameterList, i);

code = nodesCloneNode(pExpr, &pNewNode);
if (nodesIsExprNode(pNewNode)) {
SBindTupleFuncCxt pCxt = {selectFuncNum};
SBindTupleFuncCxt pCxt = {selectFuncIndex};
nodesRewriteExpr(&pNewNode, pushDownBindSelectFunc, &pCxt);
} else {
code = TSDB_CODE_PAR_INVALID_COLS_FUNCTION;
Expand Down
32 changes: 25 additions & 7 deletions tests/system-test/2-query/cols_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def funcNestTest(self):
tdSql.checkRows(1)
#tdSql.checkCols(4)
tdSql.checkData(0, 0, 1734574930000)
tdSql.checkData(0, 1, 2.2)
tdSql.checkData(0, 1, 'bbbbbbbbb')
tdSql.checkData(0, 2, 1734574929000)
tdSql.checkData(0, 3, 1.1)
tdSql.checkData(0, 3, 'a')
tdSql.query(f'select cols(last(c0), ts, c1, c2, c3), cols(first(c0), ts, c1, c2, c3) from db.d1')
tdSql.checkRows(1)
#tdSql.checkCols(6)
Expand All @@ -103,19 +103,37 @@ def funcNestTest(self):
tdSql.query(f'select cols(last(ts), c1), cols(first(ts), c1) from db.d1')
tdSql.checkRows(1)
#tdSql.checkCols(6)
tdSql.checkData(0, 0, 2)
tdSql.checkData(0, 1, 1)
tdSql.checkData(0, 0, 2.2)
tdSql.checkData(0, 1, 1.1)

tdSql.query(f'select cols(first(ts), c1), cols(first(ts), c1) from db.d1')
tdSql.query(f'select cols(first(ts), c0, c1), cols(first(ts), c0, c1) from db.d1')
tdSql.checkRows(1)
#tdSql.checkCols(6)
tdSql.checkData(0, 0, 1)
tdSql.checkData(0, 1, 1)
tdSql.checkData(0, 1, 1.1)
tdSql.checkData(0, 2, 1)
tdSql.checkData(0, 3, 1.1)

tdSql.query(f'select cols(first(c0), ts, length(c2)), cols(last(c0), ts, length(c2)) from db.d1')
tdSql.query(f'select cols(first(ts), c0, c1), cols(first(ts+1), c0, c1) from db.d1')
tdSql.checkRows(1)
#tdSql.checkCols(6)
tdSql.checkData(0, 0, 1)
tdSql.checkData(0, 1, 1.1)
tdSql.checkData(0, 2, 1)
tdSql.checkData(0, 3, 1.1)

tdSql.query(f'select cols(first(ts), c0, c1), cols(first(ts), c0+1, c1+2) from db.d1')
tdSql.checkRows(1)
#tdSql.checkCols(6)
tdSql.checkData(0, 0, 1)
tdSql.checkData(0, 1, 1.1)
tdSql.checkData(0, 2, 2)
tdSql.checkData(0, 3, 3.1)

tdSql.query(f'select cols(first(c0), ts, length(c2)), cols(last(c0), ts, length(c2)) from db.d1')
tdSql.checkRows(1)
#tdSql.checkCols(6)
tdSql.checkData(0, 0, 1734574929000)
tdSql.checkData(0, 1, 1)
tdSql.query(f'select cols(first(c0), ts, length(c2)), cols(last(c0), ts, length(c2)) from db.d1')
tdSql.checkRows(1)
Expand Down