diff --git a/src/core_functions/scalar/generic/typeof.cpp b/src/core_functions/scalar/generic/typeof.cpp index a1b01f8c0ca1..564988053368 100644 --- a/src/core_functions/scalar/generic/typeof.cpp +++ b/src/core_functions/scalar/generic/typeof.cpp @@ -1,15 +1,57 @@ #include "duckdb/core_functions/scalar/generic_functions.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" namespace duckdb { +struct ConstantReturnBindData : public FunctionData { + Value val; + + explicit ConstantReturnBindData(Value val_p) : val(std::move(val_p)) { + } + + unique_ptr Copy() const override { + return make_uniq(val); + } + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return val == other.val; + } + static void Serialize(Serializer &serializer, const optional_ptr bind_data, + const ScalarFunction &function) { + auto &info = bind_data->Cast(); + serializer.WriteProperty(100, "constant_value", info.val); + } + + static unique_ptr Deserialize(Deserializer &deserializer, ScalarFunction &bound_function) { + auto value = deserializer.ReadProperty(100, "constant_value"); + return make_uniq(std::move(value)); + } +}; + static void TypeOfFunction(DataChunk &args, ExpressionState &state, Vector &result) { - Value v(args.data[0].GetType().ToString()); - result.Reference(v); + if (args.ColumnCount() == 1) { + Value v(args.data[0].GetType().ToString()); + result.Reference(v); + } else { + auto &bind_data = state.expr.Cast().bind_info->Cast(); + result.Reference(bind_data.val); + } +} + +unique_ptr BindTypeOfFunction(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + Value return_value(arguments[0]->return_type.ToString()); + arguments.clear(); + return make_uniq(return_value); } ScalarFunction TypeOfFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, TypeOfFunction); + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, TypeOfFunction, BindTypeOfFunction); fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + fun.serialize = ConstantReturnBindData::Serialize; + fun.deserialize = ConstantReturnBindData::Deserialize; return fun; }