3333#include " llvm/ADT/TypeSwitch.h"
3434#include " llvm/Support/ErrorHandling.h"
3535#include " llvm/Support/MathExtras.h"
36+ #include < cassert>
3637#include < optional>
3738
3839using cir::MissingFeatures;
@@ -41,12 +42,13 @@ using cir::MissingFeatures;
4142// CIR Custom Parser/Printer Signatures
4243// ===----------------------------------------------------------------------===//
4344
44- static mlir::ParseResult
45- parseFuncTypeArgs (mlir::AsmParser &p, llvm::SmallVector<mlir::Type> ¶ms,
46- bool &isVarArg);
47- static void printFuncTypeArgs (mlir::AsmPrinter &p,
48- mlir::ArrayRef<mlir::Type> params, bool isVarArg);
45+ static mlir::ParseResult parseFuncType (mlir::AsmParser &p,
46+ mlir::Type &optionalReturnTypes,
47+ llvm::SmallVector<mlir::Type> ¶ms,
48+ bool &isVarArg);
4949
50+ static void printFuncType (mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
51+ mlir::ArrayRef<mlir::Type> params, bool isVarArg);
5052static mlir::ParseResult parsePointerAddrSpace (mlir::AsmParser &p,
5153 mlir::Attribute &addrSpaceAttr);
5254static void printPointerAddrSpace (mlir::AsmPrinter &p,
@@ -813,9 +815,38 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
813815 return get (llvm::to_vector (inputs), results[0 ], isVarArg ());
814816}
815817
816- mlir::ParseResult parseFuncTypeArgs (mlir::AsmParser &p,
817- llvm::SmallVector<mlir::Type> ¶ms,
818- bool &isVarArg) {
818+ // A special parser is needed for function returning void to handle the missing
819+ // type.
820+ static mlir::ParseResult parseFuncTypeReturn (mlir::AsmParser &p,
821+ mlir::Type &optionalReturnType) {
822+ if (succeeded (p.parseOptionalLParen ())) {
823+ // If we have already a '(', the function has no return type
824+ optionalReturnType = {};
825+ return mlir::success ();
826+ }
827+ mlir::Type type;
828+ if (p.parseType (type))
829+ return mlir::failure ();
830+ if (isa<cir::VoidType>(type))
831+ // An explicit !cir.void means also no return type.
832+ optionalReturnType = {};
833+ else
834+ // Otherwise use the actual type.
835+ optionalReturnType = type;
836+ return p.parseLParen ();
837+ }
838+
839+ // A special pretty-printer for function returning or not a result.
840+ static void printFuncTypeReturn (mlir::AsmPrinter &p,
841+ mlir::Type optionalReturnType) {
842+ if (optionalReturnType)
843+ p << optionalReturnType << ' ' ;
844+ p << ' (' ;
845+ }
846+
847+ static mlir::ParseResult
848+ parseFuncTypeArgs (mlir::AsmParser &p, llvm::SmallVector<mlir::Type> ¶ms,
849+ bool &isVarArg) {
819850 isVarArg = false ;
820851 // `(` `)`
821852 if (succeeded (p.parseOptionalRParen ()))
@@ -845,8 +876,9 @@ mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
845876 return p.parseRParen ();
846877}
847878
848- void printFuncTypeArgs (mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
849- bool isVarArg) {
879+ static void printFuncTypeArgs (mlir::AsmPrinter &p,
880+ mlir::ArrayRef<mlir::Type> params,
881+ bool isVarArg) {
850882 llvm::interleaveComma (params, p,
851883 [&p](mlir::Type type) { p.printType (type); });
852884 if (isVarArg) {
@@ -857,11 +889,49 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
857889 p << ' )' ;
858890}
859891
892+ // Use a custom parser to handle the optional return and argument types without
893+ // an optional anchor.
894+ static mlir::ParseResult parseFuncType (mlir::AsmParser &p,
895+ mlir::Type &optionalReturnTypes,
896+ llvm::SmallVector<mlir::Type> ¶ms,
897+ bool &isVarArg) {
898+ if (failed (parseFuncTypeReturn (p, optionalReturnTypes)))
899+ return failure ();
900+ return parseFuncTypeArgs (p, params, isVarArg);
901+ }
902+
903+ static void printFuncType (mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
904+ mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
905+ printFuncTypeReturn (p, optionalReturnTypes);
906+ printFuncTypeArgs (p, params, isVarArg);
907+ }
908+
909+ // Return the actual return type or an explicit !cir.void if the function does
910+ // not return anything
911+ mlir::Type FuncType::getReturnType () const {
912+ if (isVoid ())
913+ return cir::VoidType::get (getContext ());
914+ return static_cast <detail::FuncTypeStorage *>(getImpl ())->optionalReturnType ;
915+ }
916+
917+ // / Returns the result type of the function as an ArrayRef, enabling better
918+ // / integration with generic MLIR utilities.
860919llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes () const {
861- return static_cast <detail::FuncTypeStorage *>(getImpl ())->returnType ;
920+ if (isVoid ())
921+ return {};
922+ return static_cast <detail::FuncTypeStorage *>(getImpl ())->optionalReturnType ;
862923}
863924
864- bool FuncType::isVoid () const { return mlir::isa<VoidType>(getReturnType ()); }
925+ // Whether the function returns void
926+ bool FuncType::isVoid () const {
927+ auto rt =
928+ static_cast <detail::FuncTypeStorage *>(getImpl ())->optionalReturnType ;
929+ assert (!rt ||
930+ !mlir::isa<cir::VoidType>(rt) &&
931+ " The return type for a function returning void should be empty "
932+ " instead of a real !cir.void" );
933+ return !rt;
934+ }
865935
866936// ===----------------------------------------------------------------------===//
867937// MethodType Definitions
0 commit comments