Fix passing str variable to readFrame/readMatrix #598

Merged
6 changes: 4 additions & 2 deletions src/compiler/lowering/SpecializeGenericFunctionsPass.cpp
@@ -188,6 +188,7 @@ namespace {
std::multimap<std::string, func::FuncOp> specializedVersions;
std::set<func::FuncOp> visited;
std::set<func::FuncOp> called;
+std::set<func::FuncOp> templateFunctions;
Collaborator

This extra set for templateFunctions (and its unconditional use in line 299) is not clear to me right away (I did not test-run the code, maybe it would become clear then). Could you explain why this is needed?

Contributor Author

Initially, I thought the check isFunctionTemplate would already do the trick, but apparently it does something else. I basically exclude the "template function" that gets specialized from consideration in later passes (i.e., doing type inference on such a function doesn't make much sense imho). Template functions are the ones that get specialized, which is why I add them unconditionally. According to the test suite this should be fine, but I feel this will require more thorough testing (independent of the reported bug).

In general, there are quite a few question marks in this pass. For example, constant parameter propagation leads to X functions with identical code (except for the constant parameter, of course) when the UDF is invoked X times.

Collaborator

I've played around with the code a bit now and see what you mean. The generation of a separate function for each invocation (even if it is the same parameter) is a bit weird and could be improved.
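
One possible direction for that improvement, sketched in isolation: cache specializations under a key built from the function name and its propagated constant arguments, so repeated invocations with the same constant share one specialization. This is not the pass's code. `SpecializationCache`, `getOrCreate`, the string-based key, and the naming scheme are all hypothetical stand-ins for `func::FuncOp` and MLIR attributes.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the pass's bookkeeping: specializations are keyed
// by the template function's name plus its constant arguments, so invoking a
// UDF several times with the same constant reuses one specialization.
class SpecializationCache {
public:
    using Key = std::pair<std::string, std::vector<std::string>>;

    // Returns the name of the specialization for (funcName, constArgs),
    // creating a new one only if this combination has not been seen before.
    const std::string &getOrCreate(const std::string &funcName,
                                   const std::vector<std::string> &constArgs) {
        Key key{funcName, constArgs};
        auto it = cache.find(key);
        if (it == cache.end()) {
            // Assumed naming scheme, for illustration only.
            std::string specialized = funcName + "-" + std::to_string(cache.size() + 1);
            it = cache.emplace(std::move(key), std::move(specialized)).first;
        }
        return it->second;
    }

    // Number of distinct specializations created so far.
    std::size_t size() const { return cache.size(); }

private:
    std::map<Key, std::string> cache;
};
```

With such a cache, two calls with the constant `"a.csv"` would map to the same specialization, while a call with `"b.csv"` would create a second one.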


const DaphneUserConfig& userConfig;
std::shared_ptr<spdlog::logger> logger;
@@ -304,6 +305,7 @@ namespace {
calledFunction->getLoc().print(stream);
logger->debug("calledFunction\n\tname: {}\n\tlocation: {}", calledFunction.getSymName().str(), s);
}
+templateFunctions.insert(calledFunction);
return specializedFunc;
}

@@ -417,7 +419,7 @@ void SpecializeGenericFunctionsPass::runOnOperation() {
entryFunctions.push_back(entry.second);
}
for(const auto &function : entryFunctions) {
-if(isFunctionTemplate(function) || visited.count(function))
+if(isFunctionTemplate(function) || visited.count(function) || templateFunctions.count(function))
Collaborator

Yes, the invocations of count() were there before, so doing it the same way is quite alright, and tbh I'm unsure whether using empty() is just a matter of taste or would yield a (minor?) performance benefit.

Contributor Author

count has logarithmic complexity whereas empty has constant complexity, so there should be a slight advantage here. In general, I tried to change as little as possible although the entire pass probably needs a bigger rewrite anyways.

Collaborator

The count actually makes sense here, as we're checking for that specific function in the set and not if there are any functions at all.
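
To make the distinction concrete, here is a minimal sketch of the two questions being asked. The helper names `isTemplate` and `hasAnyTemplate` are hypothetical, and strings stand in for `func::FuncOp` so the snippet compiles without MLIR.

```cpp
#include <cassert>
#include <set>
#include <string>

// Membership test: is this particular element in the set?
// For std::set, count() is logarithmic in the size of the set.
bool isTemplate(const std::set<std::string> &templates, const std::string &name) {
    return templates.count(name) > 0;
}

// Emptiness test: does the set contain anything at all? This is constant
// time, but it answers a different question than count().
bool hasAnyTemplate(const std::set<std::string> &templates) {
    return !templates.empty();
}
```

Replacing `count(function)` with `empty()` in the loop would therefore change behavior: a non-empty set would be treated as if it contained every function.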

continue;
if(!inferTypesInFunction(function)) {
return signalPassFailure();
@@ -431,7 +433,7 @@ void SpecializeGenericFunctionsPass::runOnOperation() {
continue;
// Remove a function that was present before creating specializations,
// if it is never called.
-if(!called.count(f.second))
+if(!called.count(f.second) || templateFunctions.count(f.second))
f.second.erase();
}
}
19 changes: 19 additions & 0 deletions src/compiler/utils/CompilerUtils.cpp
@@ -20,6 +20,25 @@

#include <string>

// **************************************************************************************************
// Specializations of isConstantHelper for string types
// **************************************************************************************************

template<>
std::pair<bool, std::string> CompilerUtils::isConstantHelper<std::string, mlir::StringAttr>(mlir::Value v, std::function<std::string(const mlir::StringAttr&)> func) {
if(auto co = v.getDefiningOp<mlir::daphne::ConstantOp>()) {
if(auto attr = co.getValue().dyn_cast<mlir::StringAttr>()) {
return std::make_pair(true, func(attr));
}
}
if(auto co = v.getDefiningOp<mlir::arith::ConstantOp>()) {
if(auto attr = co.getValue().dyn_cast<mlir::StringAttr>()) {
return std::make_pair(true, func(attr));
}
}
return std::make_pair(false, std::string());
}

// **************************************************************************************************
// Specializations of isConstant for various types
// **************************************************************************************************
18 changes: 18 additions & 0 deletions src/ir/daphneir/DaphneInferFrameLabelsOpInterface.cpp
@@ -48,6 +48,24 @@ void inferFrameLabels_ExtractOrFilterRowOp(ExtractOrFilterRowOp * op) {
// Frame label inference implementations
// ****************************************************************************

void daphne::ReadOp::inferFrameLabels() {
auto p = CompilerUtils::isConstant<std::string>(getFileName());
if (auto resType = getRes().getType().dyn_cast<daphne::FrameType>()) {
if (p.first) {
std::vector<std::string> * labels;
FileMetaData fmd = CompilerUtils::getFileMetaData(getFileName());
if (fmd.labels.empty()) {
labels = nullptr;
} else {
labels = new std::vector<std::string>(fmd.labels);
}

Value res = getResult();
res.setType(res.getType().dyn_cast<daphne::FrameType>().withLabels(labels));
}
}
}

void daphne::ColBindOp::inferFrameLabels() {
auto ftLhs = getLhs().getType().dyn_cast<daphne::FrameType>();
auto ftRhs = getRhs().getType().dyn_cast<daphne::FrameType>();
9 changes: 7 additions & 2 deletions src/ir/daphneir/DaphneInferShapeOpInterface.cpp
@@ -206,8 +206,13 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::MatMulOp::inferShape() {
}

std::vector<std::pair<ssize_t, ssize_t>> daphne::ReadOp::inferShape() {
-FileMetaData fmd = CompilerUtils::getFileMetaData(getFileName());
-return {{fmd.numRows, fmd.numCols}};
+auto p = CompilerUtils::isConstant<std::string>(getFileName());
+if (p.first) {
+FileMetaData fmd = CompilerUtils::getFileMetaData(getFileName());
+return {{fmd.numRows, fmd.numCols}};
+} else {
+return {{-1, -1}};
+}
}

std::vector<std::pair<ssize_t, ssize_t>> daphne::OrderOp::inferShape() {
54 changes: 54 additions & 0 deletions src/ir/daphneir/DaphneInferTypesOpInterface.cpp
@@ -276,6 +276,60 @@ std::vector<Type> daphne::OrderOp::inferTypes() {
return {t};
}


mlir::Type mlirTypeForCode(ValueTypeCode type, Builder builder) {
switch(type) {
case ValueTypeCode::SI8: return builder.getIntegerType(8, true);
case ValueTypeCode::SI32: return builder.getIntegerType(32, true);
case ValueTypeCode::SI64: return builder.getIntegerType(64, true);
case ValueTypeCode::UI8: return builder.getIntegerType(8, false);
case ValueTypeCode::UI32: return builder.getIntegerType(32, false);
case ValueTypeCode::UI64: return builder.getIntegerType(64, false);
case ValueTypeCode::F32: return builder.getF32Type();
case ValueTypeCode::F64: return builder.getF64Type();
default: throw std::runtime_error("mlirTypeForCode: unknown value type code");
}
}

std::vector<Type> daphne::ReadOp::inferTypes() {

auto p = CompilerUtils::isConstant<std::string>(getFileName());
Builder builder(getContext());
if (auto resType = getRes().getType().dyn_cast<daphne::MatrixType>()) {
// If an individual value type was specified per column
// (fmd.isSingleValueType == false), then this silently uses the
// type of the first column.
Comment on lines +299 to +301
Collaborator

This comment about isSingleValueType is not correct and can be removed.

Contributor Author

Just moved it over, but I can also remove it altogether.

Collaborator

The comment referred to reading a matrix from a file that has a schema defined. So it is not utterly wrong. You could move it a few lines up to the code that handles reading matrices.

Contributor Author

done.
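
For readers following the thread, the behavior the comment describes can be modelled in a few lines: a matrix has a single value type, so only the first schema entry is used, while a frame keeps one type per column. This is a simplified sketch, not the DAPHNE code. `MetaSketch`, `matrixValueType`, and `frameColumnTypes` are hypothetical names, and type names are plain strings instead of `ValueTypeCode`/`mlir::Type`.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified model of the file meta data used in the diff.
struct MetaSketch {
    std::size_t numCols;
    bool isSingleValueType;
    std::vector<std::string> schema; // value type names, e.g. "f64"
};

// A matrix has one value type: even with a per-column schema, only the
// first column's type is used.
std::string matrixValueType(const MetaSketch &fmd) {
    if (fmd.schema.empty())
        throw std::runtime_error("schema must not be empty");
    return fmd.schema[0];
}

// A frame keeps one type per column; a single value type is repeated
// for every column.
std::vector<std::string> frameColumnTypes(const MetaSketch &fmd) {
    if (fmd.schema.empty())
        throw std::runtime_error("schema must not be empty");
    if (fmd.isSingleValueType)
        return std::vector<std::string>(fmd.numCols, fmd.schema[0]);
    return fmd.schema;
}
```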

// TODO: add sparsity information here already (if present), currently not possible as many other ops
// just take input types as output types, which is incorrect for sparsity
if (p.first) {
FileMetaData fmd = CompilerUtils::getFileMetaData(getFileName());
mlir::Type valType = mlirTypeForCode(fmd.schema[0], builder);
return {mlir::daphne::MatrixType::get(getContext(), valType)};
} else {
return {mlir::daphne::MatrixType::get(getContext(), daphne::UnknownType::get(getContext()))};
}
}
else if (auto resType = getRes().getType().dyn_cast<daphne::FrameType>()) {
if (p.first) {
FileMetaData fmd = CompilerUtils::getFileMetaData(getFileName());
std::vector<mlir::Type> cts;
if (fmd.isSingleValueType) {
for (size_t i = 0; i < fmd.numCols; i++) {
cts.push_back(mlirTypeForCode(fmd.schema[0], builder));
}
} else {
for (ValueTypeCode vtc : fmd.schema) {
cts.push_back(mlirTypeForCode(vtc, builder));
}
}
return {mlir::daphne::FrameType::get(builder.getContext(), cts)};
} else {
return {mlir::daphne::FrameType::get(builder.getContext(), {daphne::UnknownType::get(getContext())})};
}
}
return {daphne::UnknownType::get(getContext())};
}

std::vector<Type> daphne::SliceColOp::inferTypes() {
Type u = daphne::UnknownType::get(getContext());
Type srcTy = getSource().getType();
2 changes: 2 additions & 0 deletions src/ir/daphneir/DaphneOps.td
@@ -1293,7 +1293,9 @@ def Daphne_PrintOp : Daphne_Op<"print"> {

// TODO Take asynchronous read into account.
def Daphne_ReadOp : Daphne_Op<"read", [
DeclareOpInterfaceMethods<InferTypesOpInterface>,
DeclareOpInterfaceMethods<InferShapeOpInterface>,
DeclareOpInterfaceMethods<InferFrameLabelsOpInterface>,
DeclareOpInterfaceMethods<InferSparsityOpInterface>
]> {
// TODO We might add arguments for a UDF later.
14 changes: 0 additions & 14 deletions src/parser/ParserUtils.h
@@ -188,20 +188,6 @@ class ParserUtils {
throw std::runtime_error("unsupported value type: " + name);
}

mlir::Type mlirTypeForCode(ValueTypeCode type) {
switch(type) {
case ValueTypeCode::SI8: return builder.getIntegerType(8, true);
case ValueTypeCode::SI32: return builder.getIntegerType(32, true);
case ValueTypeCode::SI64: return builder.getIntegerType(64, true);
case ValueTypeCode::UI8: return builder.getIntegerType(8, false);
case ValueTypeCode::UI32: return builder.getIntegerType(32, false);
case ValueTypeCode::UI64: return builder.getIntegerType(64, false);
case ValueTypeCode::F32: return builder.getF32Type();
case ValueTypeCode::F64: return builder.getF64Type();
default: throw std::runtime_error("ParserUtils::mlirTypeForCode: unknown value type code");
}
}

// ************************************************************************
// Misc
// ************************************************************************
48 changes: 10 additions & 38 deletions src/parser/daphnedsl/DaphneDSLBuiltins.cpp
@@ -1009,47 +1009,19 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f
loc, arg, newline, err
);
}
if(func == "readFrame" || func == "readMatrix") {
checkNumArgsExact(func, numArgs, 1);

mlir::Value filename = args[0];
FileMetaData fmd = CompilerUtils::getFileMetaData(filename);

mlir::Type resType;

if(func == "readFrame") {
std::vector<mlir::Type> cts;
if(fmd.isSingleValueType)
for(size_t i = 0; i < fmd.numCols; i++)
cts.push_back(utils.mlirTypeForCode(fmd.schema[0]));
else
for(ValueTypeCode vtc : fmd.schema)
cts.push_back(utils.mlirTypeForCode(vtc));

std::vector<std::string> * labels;
if(fmd.labels.empty())
labels = nullptr;
else
labels = new std::vector<std::string>(fmd.labels);
if (func == "readMatrix") {
checkNumArgsExact(func, numArgs, 1);
mlir::Type resType = mlir::daphne::MatrixType::get(builder.getContext(), utils.unknownType);
return static_cast<mlir::Value>(builder.create<ReadOp>(loc, resType, /*filename = */ args[0]));
}

resType = mlir::daphne::FrameType::get(
// TODO Inserting #rows/#cols here could cause problems, if
// the frame is involved in any SCF ops (if/while/for).
builder.getContext(), cts, fmd.numRows, fmd.numCols, labels
);
}
else // func == "read.matrix"
// If an individual value type was specified per column
// (fmd.isSingleValueType == false), then this silently uses the
// type of the first column.
// TODO: add sparsity information here already (if present), currently not possible as many other ops
// just take input types as output types, which is incorrect for sparsity
resType = utils.matrixOf(utils.mlirTypeForCode(fmd.schema[0]));

return static_cast<mlir::Value>(builder.create<ReadOp>(
loc, resType, filename
));
if (func == "readFrame") {
checkNumArgsExact(func, numArgs, 1);
mlir::Type resType = mlir::daphne::FrameType::get(builder.getContext(), {utils.unknownType});
return static_cast<mlir::Value>(builder.create<ReadOp>(loc, resType, /*filename = */ args[0]));
}

if(func == "writeFrame" || func == "writeMatrix" || func == "write") {
// Note that the type of arg already indicates if it is a frame or a
// matrix.
2 changes: 2 additions & 0 deletions test/api/cli/io/ReadCsv1.csv
@@ -0,0 +1,2 @@
-0.1,-0.2,0.1,0.2
3.14,5.41,6.22216,5
6 changes: 6 additions & 0 deletions test/api/cli/io/ReadCsv1.csv.meta
@@ -0,0 +1,6 @@
{
"numRows": 2,
"numCols": 4,
"valueType": "f64",
"numNonZeros": 0
}
18 changes: 17 additions & 1 deletion test/api/cli/io/ReadTest.cpp
@@ -38,4 +38,20 @@ TEST_CASE("readSparse", TAG_IO) {
"--args",
arg.c_str());
}
-#endif
+#endif

TEST_CASE("readFrameFromCSV", TAG_IO)
{
compareDaphneToRef(dirPath + "testReadFrame.txt", dirPath + "testReadFrame.daphne");
}

TEST_CASE("readMatrixFromCSV", TAG_IO)
{
compareDaphneToRef(dirPath + "testReadMatrix.txt", dirPath + "testReadMatrix.daphne");
}

// does not yet work!
// TEST_CASE("readReadMatrixFromCSV_DynamicPath", TAG_IO)
// {
// compareDaphneToRef(dirPath + "testReadMatrix.txt", dirPath + "testReadMatrix_DynamicPath.daphne");
// }
6 changes: 6 additions & 0 deletions test/api/cli/io/testReadFrame.daphne
@@ -0,0 +1,6 @@
# Test reading from a file when the file path is not trivially constant (i.e., a parameter to a UDF)
def readFrameFromCSV(path: str) {
print(readFrame(path));
}

readFrameFromCSV("test/api/cli/io/ReadCsv1.csv");
3 changes: 3 additions & 0 deletions test/api/cli/io/testReadFrame.txt
@@ -0,0 +1,3 @@
Frame(2x4, [col_0:double, col_1:double, col_2:double, col_3:double])
-0.1 -0.2 0.1 0.2
3.14 5.41 6.22216 5
6 changes: 6 additions & 0 deletions test/api/cli/io/testReadMatrix.daphne
@@ -0,0 +1,6 @@
# Test reading from a file when the file path is not trivially constant (i.e., a parameter to a UDF)
def readMatrixFromCSV(path: str) {
print(readMatrix(path));
}

readMatrixFromCSV("test/api/cli/io/ReadCsv1.csv");
3 changes: 3 additions & 0 deletions test/api/cli/io/testReadMatrix.txt
@@ -0,0 +1,3 @@
DenseMatrix(2x4, double)
-0.1 -0.2 0.1 0.2
3.14 5.41 6.22216 5
5 changes: 5 additions & 0 deletions test/api/cli/io/testReadMatrix_DynamicPath.daphne
@@ -0,0 +1,5 @@
# Test dynamic computation of string path -> does not yet work!
i = 1;
filename = "test/api/cli/io/ReadCsv" + i + ".csv";
m = readMatrix(filename);
print(m);
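
For a file name that is not a compile-time constant, as in this script, the diff makes `ReadOp::inferShape` fall back to an unknown shape rather than read the meta data. The sketch below models only that decision outside of MLIR. `Shape`, `inferReadShape`, and the meta data map are hypothetical. The `std::pair<bool, std::string>` argument mirrors the return type of the `isConstantHelper` specialization in the diff and is assumed here to also be what `isConstant<std::string>` returns.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>

// Hypothetical shape record; -1 marks an unknown dimension, as in the diff.
struct Shape {
    std::int64_t numRows;
    std::int64_t numCols;
};

// If the file name is a known constant (first == true), look up its shape in
// the meta data; otherwise report an unknown shape instead of failing.
// Note: std::map::at throws std::out_of_range for a file without meta data.
Shape inferReadShape(const std::pair<bool, std::string> &constFileName,
                     const std::map<std::string, Shape> &metaData) {
    if (constFileName.first)
        return metaData.at(constFileName.second);
    return Shape{-1, -1};
}
```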