From be5ab860cb122e3c26dd55c2aa49a739089f0790 Mon Sep 17 00:00:00 2001 From: jeremyosterhoudt Date: Mon, 28 Aug 2023 08:25:47 -0700 Subject: [PATCH] GH-36078: [C#] Flight SQL implementation for C# (#36079) Flight SQL implementation for C# that is compatible with the C++ and Java implementations. Closes issue #36078 * Closes: #36078 Lead-authored-by: Jeremy Osterhoudt Co-authored-by: Weston Pace Signed-off-by: Weston Pace --- csharp/Apache.Arrow.sln | 12 + .../Apache.Arrow.Flight.Sql.csproj | 18 + .../FlightSqlServer.cs | 389 ++++++++++++++++++ .../Apache.Arrow.Flight.Sql/FlightSqlUtils.cs | 158 +++++++ .../src/Apache.Arrow.Flight.Sql/SqlActions.cs | 22 + .../Client/FlightClient.cs | 20 +- .../FlightClientRecordBatchStreamReader.cs | 2 +- .../FlightClientRecordBatchStreamWriter.cs | 4 +- csharp/src/Apache.Arrow.Flight/FlightData.cs | 53 +++ .../FlightHandshakeRequest.cs | 58 +++ .../FlightHandshakeResponse.cs | 63 +++ .../FlightRecordBatchStreamWriter.cs | 4 +- .../Internal/FlightDataStream.cs | 4 +- .../Internal/SchemaWriter.cs | 12 +- .../Server/FlightServer.cs | 7 +- .../FlightServerRecordBatchStreamReader.cs | 7 +- .../FlightServerRecordBatchStreamWriter.cs | 10 +- .../Internal/FlightServerImplementation.cs | 30 +- .../Server/Internal/HandshakeAdapters.cs | 40 ++ .../Apache.Arrow.Flight.Sql.Tests.csproj | 19 + .../FlightSqlServerTests.cs | 375 +++++++++++++++++ .../FlightSqlTestExtensions.cs | 27 ++ .../FlightSqlUtilsTests.cs | 73 ++++ .../TestFlightSqlSever.cs | 89 ++++ .../TestFlightServer.cs | 16 + .../Apache.Arrow.Flight.Tests/FlightTests.cs | 26 ++ dev/release/rat_exclude_files.txt | 2 + 27 files changed, 1508 insertions(+), 32 deletions(-) create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/FlightSqlUtils.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs create mode 100644 csharp/src/Apache.Arrow.Flight/FlightData.cs create mode 100644 csharp/src/Apache.Arrow.Flight/FlightHandshakeRequest.cs create mode 100644 csharp/src/Apache.Arrow.Flight/FlightHandshakeResponse.cs create mode 100644 csharp/src/Apache.Arrow.Flight/Server/Internal/HandshakeAdapters.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlServerTests.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlUtilsTests.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs diff --git a/csharp/Apache.Arrow.sln b/csharp/Apache.Arrow.sln index baf4bc6129598..7e7f7c6331e88 100644 --- a/csharp/Apache.Arrow.sln +++ b/csharp/Apache.Arrow.sln @@ -23,6 +23,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Compression", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Compression.Tests", "test\Apache.Arrow.Compression.Tests\Apache.Arrow.Compression.Tests.csproj", "{5D7FF380-B7DF-4752-B415-7C08C70C9F06}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.Tests", "test\Apache.Arrow.Flight.Sql.Tests\Apache.Arrow.Flight.Sql.Tests.csproj", "{DCC99EB1-4E60-4F0D-AEA9-C44A4C0C8B1D}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql", "src\Apache.Arrow.Flight.Sql\Apache.Arrow.Flight.Sql.csproj", "{2ADE087A-B424-4895-8CC5-10170D10BA62}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -69,6 +73,14 @@ Global {5D7FF380-B7DF-4752-B415-7C08C70C9F06}.Debug|Any CPU.Build.0 = Debug|Any CPU {5D7FF380-B7DF-4752-B415-7C08C70C9F06}.Release|Any CPU.ActiveCfg = Release|Any CPU {5D7FF380-B7DF-4752-B415-7C08C70C9F06}.Release|Any CPU.Build.0 = Release|Any CPU + {DCC99EB1-4E60-4F0D-AEA9-C44A4C0C8B1D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DCC99EB1-4E60-4F0D-AEA9-C44A4C0C8B1D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DCC99EB1-4E60-4F0D-AEA9-C44A4C0C8B1D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DCC99EB1-4E60-4F0D-AEA9-C44A4C0C8B1D}.Release|Any CPU.Build.0 = Release|Any CPU + {2ADE087A-B424-4895-8CC5-10170D10BA62}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2ADE087A-B424-4895-8CC5-10170D10BA62}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2ADE087A-B424-4895-8CC5-10170D10BA62}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2ADE087A-B424-4895-8CC5-10170D10BA62}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj b/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj new file mode 100644 index 0000000000000..50570d628924b --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj @@ -0,0 +1,18 @@ + + + netstandard2.1 + enable + + + + + + + + + + + + + + diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs new file mode 100644 index 0000000000000..dbfc1f7c7ea49 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs @@ -0,0 +1,389 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Server; +using Apache.Arrow.Types; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; +using Microsoft.Extensions.Logging; + +namespace Apache.Arrow.Flight.Sql; + +public abstract class FlightSqlServer : FlightServer +{ + private ILogger? Logger { get; } + public static readonly Schema CatalogSchema = new(new List {new("catalog_name", StringType.Default, false)}, null); + public static readonly Schema TableTypesSchema = new(new List {new("table_type", StringType.Default, false)}, null); + public static readonly Schema DbSchemaFlightSchema = new(new List {new("catalog_name", StringType.Default, true), new("db_schema_name", StringType.Default, false)}, null); + + public static readonly Schema PrimaryKeysSchema = new(new List + { + new("catalog_name", StringType.Default, true), + new("db_schema_name", StringType.Default, true), + new("table_name", StringType.Default, false), + new("column_name", StringType.Default, false), + new("key_sequence", Int32Type.Default, false), + new("key_name", StringType.Default, true) + }, null); + + public static readonly Schema KeyImportExportSchema = new(new List + { + new("pk_catalog_name", StringType.Default, true), + new("pk_db_schema_name", StringType.Default, true), + new("pk_table_name", StringType.Default, false), + new("pk_column_name", StringType.Default, false), + new("fk_catalog_name", StringType.Default, true), + new("fk_db_schema_name", StringType.Default, true), + new("fk_table_name", StringType.Default, false), + new("fk_column_name", StringType.Default, false), + new("key_sequence", Int32Type.Default, false), + new("fk_key_name", StringType.Default, true), + new("pk_key_name", StringType.Default, true), + new("update_rule", UInt8Type.Default, false), + new("delete_rule", UInt8Type.Default, false) + }, null); + + public static readonly Schema TypeInfoSchema = new(new List + { + new("type_name", StringType.Default, false), + new("data_type", Int32Type.Default, false), + new("column_size", Int32Type.Default, true), + new("literal_prefix", StringType.Default, true), + new("literal_suffix", StringType.Default, true), + new("create_params", new ListType(new Field("item", StringType.Default, false)), true), + new("nullable", Int32Type.Default, false), + new("case_sensitive", BooleanType.Default, false), + new("searchable", Int32Type.Default, false), + new("unsigned_attribute", BooleanType.Default, true), + new("fixed_prec_scale", BooleanType.Default, false), + new("auto_increment", BooleanType.Default, true), + new("local_type_name", StringType.Default, true), + new("minimum_scale", Int32Type.Default, true), + new("maximum_scale", Int32Type.Default, true), + new("sql_data_type", Int32Type.Default, false), + new("datetime_subcode", Int32Type.Default, true), + new("num_prec_radix", Int32Type.Default, true), + new("interval_precision", Int32Type.Default, true) + }, null); + + public static readonly Schema SqlInfoSchema = new(new List + { + new("info_name", UInt32Type.Default, false) + //TODO: once we have union serialization in Arrow Flight for .Net we should to add these fields + // fieldList.Add(new Field("value", new UnionType(new List(), new List()), false)); + // fieldList.Add(new Field("value", new UnionType(new [] + // { + // new Field("string_value", StringType.Default, false), + // new Field("bool_value", BooleanType.Default, false), + // new Field("bigint_value", Int64Type.Default, false), + // new Field("bool_value", BooleanType.Default, false), + // new Field("bigint_value", Int64Type.Default, false), + // new Field("int32_bitmask", Int32Type.Default, false), + // new Field("string_list", new ListType(new Field("item", StringType.Default, false)), false), + // new Field("int32_to_int32_list_map", new DictionaryType(Int32Type.Default, new ListType(Int32Type.Default), false), false), + // }, new []{(byte)ArrowTypeId.String, (byte)ArrowTypeId.Boolean, (byte)ArrowTypeId.Int64,/* (byte)3, (byte)4, (byte)5*/}, UnionMode.Dense), false)); + }, null); + + private static readonly Schema s_tableSchema = new(new List + { + new("catalog_name", StringType.Default, true), + new("db_schema_name", StringType.Default, true), + new("table_name", StringType.Default, false), + new("table_type", StringType.Default, false) + }, null); + + public static Schema GetTableSchema(bool includeTableSchemaField) + { + if (!includeTableSchemaField) + { + return s_tableSchema; + } + + var fields = s_tableSchema.FieldsList.ToList(); + fields.Add(new Field("table_schema", BinaryType.Default, false)); + + return new Schema(fields, s_tableSchema.Metadata); + } + + public static IMessage? GetCommand(FlightTicket ticket) + { + try + { + return GetCommand(Any.Parser.ParseFrom(ticket.Ticket)); + } + catch (InvalidProtocolBufferException) { } //The ticket is not a flight sql command + + return null; + } + + public static async Task GetCommand(FlightServerRecordBatchStreamReader requestStream) + { + return GetCommand(await requestStream.FlightDescriptor.ConfigureAwait(false)); + } + + public static IMessage? GetCommand(FlightDescriptor? request) + { + if (request == null) return null; + if (request.Type == FlightDescriptorType.Command && request.ParsedAndUnpackedMessage() is { } command) + { + return command; + } + + return null; + } + + private static IMessage? GetCommand(Any command) + { + if (command.Is(CommandPreparedStatementQuery.Descriptor)) + { + return command.Unpack(); + } + + if (command.Is(CommandGetSqlInfo.Descriptor)) + { + return command.Unpack(); + } + + if (command.Is(CommandGetCatalogs.Descriptor)) + { + return command.Unpack(); + } + + if (command.Is(CommandGetTableTypes.Descriptor)) + { + return command.Unpack(); + } + + if (command.Is(CommandGetTables.Descriptor)) + { + return command.Unpack(); + } + + if (command.Is(CommandGetDbSchemas.Descriptor)) + { + return command.Unpack(); + } + + if (command.Is(CommandGetPrimaryKeys.Descriptor)) + { + return command.Unpack(); + } + + if (command.Is(CommandGetExportedKeys.Descriptor)) + { + return command.Unpack(); + } + + if (command.Is(CommandGetImportedKeys.Descriptor)) + { + return command.Unpack(); + } + + if (command.Is(CommandGetCrossReference.Descriptor)) + { + return command.Unpack(); + } + + if (command.Is(CommandGetXdbcTypeInfo.Descriptor)) + { + return command.Unpack(); + } + + return null; + } + + protected FlightSqlServer(ILoggerFactory? factory = null) + { + Logger = factory?.CreateLogger(); + } + + /// + /// Lists actions supported by Flight SQL. For Flight RPC actions support + /// implementers should extend this method to return additional supported actions. + /// + public override async Task ListActions(IAsyncStreamWriter responseStream, ServerCallContext context) + { + foreach (var actionType in FlightSqlUtils.FlightSqlActions) + { + await responseStream.WriteAsync(actionType).ConfigureAwait(false); + } + } + + /// + /// Attempts to execute a valid Flight SQL command. For Flight RPC calls + /// implementers should extend this method in order to handle RPC messages. + /// + public override Task GetFlightInfo(FlightDescriptor flightDescriptor, ServerCallContext context) + { + var sqlCommand = GetCommand(flightDescriptor); + Logger?.LogTrace("Executing Flight SQL FlightInfo command: {DescriptorName}", sqlCommand?.Descriptor.Name); + return sqlCommand switch + { + CommandStatementQuery command => GetStatementQueryFlightInfo(command, flightDescriptor, context), + CommandPreparedStatementQuery command => GetPreparedStatementQueryFlightInfo(command, flightDescriptor, context), + CommandGetCatalogs command => GetCatalogFlightInfo(command, flightDescriptor, context), + CommandGetDbSchemas command => GetDbSchemaFlightInfo(command, flightDescriptor, context), + CommandGetTables command => GetTablesFlightInfo(command, flightDescriptor, context), + CommandGetTableTypes command => GetTableTypesFlightInfo(command, flightDescriptor, context), + CommandGetSqlInfo command => GetSqlFlightInfo(command, flightDescriptor, context), + CommandGetPrimaryKeys command => GetPrimaryKeysFlightInfo(command, flightDescriptor, context), + CommandGetExportedKeys command => GetExportedKeysFlightInfo(command, flightDescriptor, context), + CommandGetImportedKeys command => GetImportedKeysFlightInfo(command, flightDescriptor, context), + CommandGetCrossReference command => GetCrossReferenceFlightInfo(command, flightDescriptor, context), + CommandGetXdbcTypeInfo command => GetXdbcTypeFlightInfo(command, flightDescriptor, context), + _ => throw new InvalidOperationException($"command type {sqlCommand?.Descriptor?.Name} not supported") + }; + } + + /// + /// Attempts to execute a valid Flight SQL command. For Flight RPC calls + /// implementers should extend this method in order to handle RPC messages. + /// + public override Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) + { + var sqlCommand = GetCommand(ticket); + Logger?.LogTrace("Executing Flight SQL DoGet command: {SqlCommandDescriptor}", sqlCommand?.Descriptor); + return sqlCommand switch + { + CommandPreparedStatementQuery command => DoGetPreparedStatementQuery(command, responseStream, context), + CommandGetSqlInfo command => DoGetSqlInfo(command, responseStream, context), + CommandGetCatalogs command => DoGetCatalog(command, responseStream, context), + CommandGetTableTypes command => DoGetTableType(command, responseStream, context), + CommandGetTables command => DoGetTables(command, responseStream, context), + CommandGetDbSchemas command => DoGetDbSchema(command, responseStream, context), + CommandGetPrimaryKeys command => DoGetPrimaryKeys(command, responseStream, context), + CommandGetExportedKeys command => DoGetExportedKeys(command, responseStream, context), + CommandGetImportedKeys command => DoGetImportedKeys(command, responseStream, context), + CommandGetCrossReference command => DoGetCrossReference(command, responseStream, context), + CommandGetXdbcTypeInfo command => DoGetXbdcTypeInfo(command, responseStream, context), + _ => throw new RpcException(new Status(StatusCode.InvalidArgument, $"DoGet command {sqlCommand?.Descriptor} is not supported.")) + }; + } + + /// + /// Attempts to execute a valid Flight SQL command. For Flight RPC calls + /// implementers should extend this method in order to handle RPC messages. + /// + public override Task DoAction(FlightAction action, IAsyncStreamWriter responseStream, ServerCallContext context) + { + Logger?.LogTrace("Executing Flight SQL DoAction: {ActionType}", action.Type); + switch (action.Type) + { + case SqlAction.CreateRequest: + var command = FlightSqlUtils.ParseAndUnpack(action.Body); + return CreatePreparedStatement(command, action, responseStream, context); + case SqlAction.CloseRequest: + var closeCommand = FlightSqlUtils.ParseAndUnpack(action.Body); + return ClosePreparedStatement(closeCommand, action, responseStream, context); + default: + throw new NotImplementedException($"Action type {action.Type} not supported"); + } + } + + /// + /// Attempts to execute a valid Flight SQL command. For Flight RPC calls + /// implementers should extend this method in order to handle RPC messages. + /// + public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) + { + if (await GetCommand(requestStream).ConfigureAwait(false) is { } command) + { + await DoPutInternal(command, requestStream, responseStream, context).ConfigureAwait(false); + } + else + { + throw new NotImplementedException(); + } + } + + private Task DoPutInternal(IMessage command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) + { + Logger?.LogTrace("Executing Flight SQL DoAction: {DescriptorName}", command.Descriptor.Name); + return command switch + { + CommandStatementUpdate statementUpdate => PutStatementUpdate(statementUpdate, requestStream, responseStream, context), + CommandPreparedStatementQuery preparedStatementQuery => PutPreparedStatementQuery(preparedStatementQuery, requestStream, responseStream, context), + CommandPreparedStatementUpdate preparedStatementUpdate => PutPreparedStatementUpdate(preparedStatementUpdate, requestStream, responseStream, context), + _ => throw new NotImplementedException($"Command {command.Descriptor.Name} not supported") + }; + } + + public static bool SupportsAction(FlightAction action) + { + switch (action.Type) + { + case SqlAction.CreateRequest: + case SqlAction.CloseRequest: + return true; + default: + return false; + } + } + + + #region FlightInfo + + protected abstract Task GetStatementQueryFlightInfo(CommandStatementQuery commandStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetPreparedStatementQueryFlightInfo(CommandPreparedStatementQuery preparedStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetCatalogFlightInfo(CommandGetCatalogs command, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetDbSchemaFlightInfo(CommandGetDbSchemas command, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetTablesFlightInfo(CommandGetTables command, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetTableTypesFlightInfo(CommandGetTableTypes command, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetSqlFlightInfo(CommandGetSqlInfo commandGetSqlInfo, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetPrimaryKeysFlightInfo(CommandGetPrimaryKeys command, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetExportedKeysFlightInfo(CommandGetExportedKeys command, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetImportedKeysFlightInfo(CommandGetImportedKeys command, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetCrossReferenceFlightInfo(CommandGetCrossReference command, FlightDescriptor flightDescriptor, ServerCallContext context); + protected abstract Task GetXdbcTypeFlightInfo(CommandGetXdbcTypeInfo command, FlightDescriptor flightDescriptor, ServerCallContext context); + + #endregion + + #region DoGet + + protected abstract Task DoGetPreparedStatementQuery(CommandPreparedStatementQuery preparedStatementQuery, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + protected abstract Task DoGetSqlInfo(CommandGetSqlInfo getSqlInfo, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + protected abstract Task DoGetCatalog(CommandGetCatalogs command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + protected abstract Task DoGetTableType(CommandGetTableTypes command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + protected abstract Task DoGetTables(CommandGetTables command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + protected abstract Task DoGetPrimaryKeys(CommandGetPrimaryKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + protected abstract Task DoGetDbSchema(CommandGetDbSchemas command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + protected abstract Task DoGetExportedKeys(CommandGetExportedKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + protected abstract Task DoGetImportedKeys(CommandGetImportedKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + protected abstract Task DoGetCrossReference(CommandGetCrossReference command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + protected abstract Task DoGetXbdcTypeInfo(CommandGetXdbcTypeInfo command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context); + + #endregion + + #region DoAction + + protected abstract Task CreatePreparedStatement(ActionCreatePreparedStatementRequest request, FlightAction action, IAsyncStreamWriter streamWriter, ServerCallContext context); + protected abstract Task ClosePreparedStatement(ActionClosePreparedStatementRequest request, FlightAction action, IAsyncStreamWriter streamWriter, ServerCallContext context); + + #endregion + + #region DoPut + + protected abstract Task PutPreparedStatementUpdate(CommandPreparedStatementUpdate command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context); + protected abstract Task PutStatementUpdate(CommandStatementUpdate command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context); + protected abstract Task PutPreparedStatementQuery(CommandPreparedStatementQuery command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context); + + #endregion +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlUtils.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlUtils.cs new file mode 100644 index 0000000000000..295fe4d32a9f6 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlUtils.cs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Buffers; +using System.Collections.Generic; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; + +namespace Apache.Arrow.Flight.Sql +{ + /// + /// Helper methods for doing common Flight Sql tasks and conversions + /// + public class FlightSqlUtils + { + public static readonly FlightActionType FlightSqlCreatePreparedStatement = new("CreatePreparedStatement", + "Creates a reusable prepared statement resource on the server. \n" + + "Request Message: ActionCreatePreparedStatementRequest\n" + + "Response Message: ActionCreatePreparedStatementResult"); + + public static readonly FlightActionType FlightSqlClosePreparedStatement = new("ClosePreparedStatement", + "Closes a reusable prepared statement resource on the server. \n" + + "Request Message: ActionClosePreparedStatementRequest\n" + + "Response Message: N/A"); + + /// + /// List of possible actions + /// + public static readonly List FlightSqlActions = new() + { + FlightSqlCreatePreparedStatement, + FlightSqlClosePreparedStatement + }; + + /// + /// Helper to parse {@link com.google.protobuf.Any} objects to the specific protobuf object. + /// + /// the raw bytes source value. + /// the materialized protobuf object. + public static Any Parse(ByteString source) + { + return Any.Parser.ParseFrom(source); + } + + /// + /// Helper to unpack {@link com.google.protobuf.Any} objects to the specific protobuf object. + /// + /// the parsed Source value. + /// IMessage + /// the materialized protobuf object. + public static T Unpack(Any source) where T : IMessage, new() + { + return source.Unpack(); + } + + /// + /// Helper to parse and unpack {@link com.google.protobuf.Any} objects to the specific protobuf object. + /// + /// the raw bytes source value. + /// IMessage + /// the materialized protobuf object. + public static T ParseAndUnpack(ByteString source) where T : IMessage, new() + { + return Unpack(Parse(source)); + } + } + + /// + /// A set of helper functions for converting encoded commands to IMessage types + /// + public static class FlightSqlExtensions + { + private static Any ParsedCommand(this FlightDescriptor descriptor) + { + return FlightSqlUtils.Parse(descriptor.Command); + } + + private static IMessage UnpackMessage(this Any command) + { + if (command.Is(CommandStatementQuery.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandPreparedStatementQuery.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandGetCatalogs.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandGetDbSchemas.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandGetTables.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandGetTableTypes.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandGetSqlInfo.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandGetPrimaryKeys.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandGetExportedKeys.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandGetImportedKeys.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandGetCrossReference.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandGetXdbcTypeInfo.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(TicketStatementQuery.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(TicketStatementQuery.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandStatementUpdate.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandPreparedStatementUpdate.Descriptor)) + return FlightSqlUtils.Unpack(command); + if (command.Is(CommandPreparedStatementQuery.Descriptor)) + return FlightSqlUtils.Unpack(command); + + throw new ArgumentException("The defined request is invalid."); + } + + /// + /// Extracts a command from a FlightDescriptor + /// + /// + /// An IMessage that has been parsed and unpacked + public static IMessage? ParsedAndUnpackedMessage(this FlightDescriptor descriptor) + { + try + { + return descriptor.ParsedCommand().UnpackMessage(); + } + catch (ArgumentException) + { + return null; + } + } + + public static ByteString Serialize(this IBufferMessage message) + { + int size = message.CalculateSize(); + var writer = new ArrayBufferWriter(size); + message.WriteTo(writer); + var schemaBytes = writer.WrittenSpan; + return ByteString.CopyFrom(schemaBytes); + } + } +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs b/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs new file mode 100644 index 0000000000000..f3f3bef1e1d00 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Apache.Arrow.Flight.Sql; + +public static class SqlAction +{ + public const string CreateRequest = "CreatePreparedStatement"; + public const string CloseRequest = "ClosePreparedStatement"; +} diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs index 8140e06493dc2..5dc0d1b434b6d 100644 --- a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs +++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs @@ -13,11 +13,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -using System; -using System.Collections.Generic; using System.Threading.Tasks; using Apache.Arrow.Flight.Internal; using Apache.Arrow.Flight.Protocol; +using Apache.Arrow.Flight.Server; +using Apache.Arrow.Flight.Server.Internal; using Grpc.Core; using Grpc.Net.Client; @@ -93,6 +93,22 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc channels.Dispose); } + public AsyncDuplexStreamingCall Handshake(Metadata headers = null) + { + var channel = _client.Handshake(headers); + var readStream = new StreamReader(channel.ResponseStream, response => new FlightHandshakeResponse(response)); + var writeStream = new FlightHandshakeStreamWriterAdapter(channel.RequestStream); + var call = new AsyncDuplexStreamingCall( + writeStream, + readStream, + channel.ResponseHeadersAsync, + channel.GetStatus, + channel.GetTrailers, + channel.Dispose); + + return call; + } + public AsyncServerStreamingCall DoAction(FlightAction action, Metadata headers = null) { var stream = _client.DoAction(action.ToProtocol(), headers); diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamReader.cs b/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamReader.cs index 011af0c831508..73094338be4cd 100644 --- a/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamReader.cs +++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamReader.cs @@ -21,7 +21,7 @@ namespace Apache.Arrow.Flight.Client { public class FlightClientRecordBatchStreamReader : FlightRecordBatchStreamReader { - internal FlightClientRecordBatchStreamReader(IAsyncStreamReader flightDataStream) : base(flightDataStream) + internal FlightClientRecordBatchStreamReader(IAsyncStreamReader flightDataStream) : base(flightDataStream) { } } diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamWriter.cs b/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamWriter.cs index d2e62c42e8621..e5fa30c9f6aed 100644 --- a/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamWriter.cs +++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamWriter.cs @@ -25,9 +25,9 @@ namespace Apache.Arrow.Flight.Client { public class FlightClientRecordBatchStreamWriter : FlightRecordBatchStreamWriter, IClientStreamWriter { - private readonly IClientStreamWriter _clientStreamWriter; + private readonly IClientStreamWriter _clientStreamWriter; private bool _completed = false; - internal FlightClientRecordBatchStreamWriter(IClientStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor) : base(clientStreamWriter, flightDescriptor) + internal FlightClientRecordBatchStreamWriter(IClientStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor) : base(clientStreamWriter, flightDescriptor) { _clientStreamWriter = clientStreamWriter; } diff --git a/csharp/src/Apache.Arrow.Flight/FlightData.cs b/csharp/src/Apache.Arrow.Flight/FlightData.cs new file mode 100644 index 0000000000000..f38b1de2206a8 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/FlightData.cs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Google.Protobuf; + +namespace Apache.Arrow.Flight; + +public class FlightData +{ + public FlightDescriptor Descriptor { get; } + public ByteString AppMetadata { get; } + public ByteString DataBody { get; } + public ByteString DataHeader { get; } + + public FlightData(FlightDescriptor descriptor, ByteString dataBody = null, ByteString dataHeader = null, ByteString appMetadata = null) + { + Descriptor = descriptor; + DataBody = dataBody; + DataHeader = dataHeader; + AppMetadata = appMetadata; + } + + internal FlightData(Protocol.FlightData protocolFlightData) + { + Descriptor = protocolFlightData.FlightDescriptor == null ? null : new FlightDescriptor(protocolFlightData.FlightDescriptor); + DataBody = protocolFlightData.DataBody; + DataHeader = protocolFlightData.DataHeader; + AppMetadata = protocolFlightData.AppMetadata; + } + + internal Protocol.FlightData ToProtocol() + { + return new Protocol.FlightData + { + FlightDescriptor = Descriptor?.ToProtocol(), + AppMetadata = AppMetadata, + DataBody = DataBody, + DataHeader = DataHeader + }; + } +} diff --git a/csharp/src/Apache.Arrow.Flight/FlightHandshakeRequest.cs b/csharp/src/Apache.Arrow.Flight/FlightHandshakeRequest.cs new file mode 100644 index 0000000000000..62db6446072a4 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/FlightHandshakeRequest.cs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Google.Protobuf; + +namespace Apache.Arrow.Flight; + +public class FlightHandshakeRequest +{ + private readonly Protocol.HandshakeRequest _result; + public ByteString Payload => _result.Payload; + public ulong ProtocolVersion => _result.ProtocolVersion; + + internal FlightHandshakeRequest(Protocol.HandshakeRequest result) + { + _result = result; + } + + public FlightHandshakeRequest(ByteString payload, ulong protocolVersion = 1) + { + _result = new Protocol.HandshakeRequest + { + Payload = payload, + ProtocolVersion = protocolVersion + }; + } + + internal Protocol.HandshakeRequest ToProtocol() + { + return _result; + } + + public override bool Equals(object obj) + { + if(obj is FlightHandshakeRequest other) + { + return Equals(_result, other._result); + } + return false; + } + + public override int GetHashCode() + { + return _result.GetHashCode(); + } +} diff --git a/csharp/src/Apache.Arrow.Flight/FlightHandshakeResponse.cs b/csharp/src/Apache.Arrow.Flight/FlightHandshakeResponse.cs new file mode 100644 index 0000000000000..4ceb288f8eda1 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/FlightHandshakeResponse.cs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Google.Protobuf; + +namespace Apache.Arrow.Flight; + +public class FlightHandshakeResponse +{ + public static readonly FlightHandshakeResponse Empty = new FlightHandshakeResponse(); + private readonly Protocol.HandshakeResponse _handshakeResponse; + + public ulong ProtocolVersion + { + get => _handshakeResponse.ProtocolVersion; + set => _handshakeResponse.ProtocolVersion = value; + } + + public ByteString Payload + { + get => _handshakeResponse.Payload; + set => _handshakeResponse.Payload = value; + } + + public FlightHandshakeResponse() + { + _handshakeResponse = new Protocol.HandshakeResponse + { + ProtocolVersion = 1 + }; + } + + internal FlightHandshakeResponse(Protocol.HandshakeResponse handshakeResponse) + { + _handshakeResponse = handshakeResponse; + } + + public FlightHandshakeResponse(ByteString payload, ulong protocolVersion = 1) + { + _handshakeResponse = new Protocol.HandshakeResponse + { + ProtocolVersion = protocolVersion, + Payload = payload + }; + } + + internal Protocol.HandshakeResponse ToProtocol() + { + return _handshakeResponse; + } +} diff --git a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs index a72be5a823403..f76f08224541f 100644 --- a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs +++ b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs @@ -27,12 +27,12 @@ namespace Apache.Arrow.Flight public abstract class FlightRecordBatchStreamWriter : IAsyncStreamWriter, IDisposable { private FlightDataStream _flightDataStream; - private readonly IAsyncStreamWriter _clientStreamWriter; + private readonly IAsyncStreamWriter _clientStreamWriter; private readonly FlightDescriptor _flightDescriptor; private bool _disposed; - private protected FlightRecordBatchStreamWriter(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor) + private protected FlightRecordBatchStreamWriter(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor) { _clientStreamWriter = clientStreamWriter; _flightDescriptor = flightDescriptor; diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs index 3211212c99cb9..72c1551be2917 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs @@ -34,10 +34,10 @@ namespace Apache.Arrow.Flight.Internal internal class FlightDataStream : ArrowStreamWriter { private readonly FlightDescriptor _flightDescriptor; - private readonly IAsyncStreamWriter _clientStreamWriter; + private readonly IAsyncStreamWriter _clientStreamWriter; private Protocol.FlightData _currentFlightData; - public FlightDataStream(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor, Schema schema) + public FlightDataStream(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor, Schema schema) : base(new MemoryStream(), schema) { _clientStreamWriter = clientStreamWriter; diff --git a/csharp/src/Apache.Arrow.Flight/Internal/SchemaWriter.cs b/csharp/src/Apache.Arrow.Flight/Internal/SchemaWriter.cs index c7e7d8135a1dd..be27cb1e39161 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/SchemaWriter.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/SchemaWriter.cs @@ -20,6 +20,7 @@ using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Flatbuf; +using Apache.Arrow.Flight.Internal; using Apache.Arrow.Ipc; using Google.Protobuf; @@ -30,7 +31,7 @@ namespace Apache.Arrow.Flight.Internal /// internal class SchemaWriter : ArrowStreamWriter { - private SchemaWriter(Stream baseStream, Schema schema) : base(baseStream, schema) + internal SchemaWriter(Stream baseStream, Schema schema) : base(baseStream, schema) { } @@ -53,3 +54,12 @@ public void WriteSchema(Schema schema, CancellationToken cancellationToken) } } } + +public static class SchemaExtension +{ + // Translate an Apache.Arrow.Schema to FlatBuffer Schema to ByteString + public static ByteString ToByteString(this Apache.Arrow.Schema schema) + { + return SchemaWriter.SerializeSchema(schema); + } +} diff --git a/csharp/src/Apache.Arrow.Flight/Server/FlightServer.cs b/csharp/src/Apache.Arrow.Flight/Server/FlightServer.cs index 30b0409d422fb..0005caf175f67 100644 --- a/csharp/src/Apache.Arrow.Flight/Server/FlightServer.cs +++ b/csharp/src/Apache.Arrow.Flight/Server/FlightServer.cs @@ -14,8 +14,6 @@ // limitations under the License. using System; -using System.Collections.Generic; -using System.Text; using System.Threading.Tasks; using Grpc.Core; @@ -57,5 +55,10 @@ public virtual Task GetFlightInfo(FlightDescriptor request, ServerCa { throw new NotImplementedException(); } + + public virtual Task Handshake(IAsyncStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) + { + throw new NotImplementedException(); + } } } diff --git a/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamReader.cs b/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamReader.cs index 5476d3d0e5ff9..c52b761ad38d9 100644 --- a/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamReader.cs +++ b/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamReader.cs @@ -14,7 +14,6 @@ // limitations under the License. using System.Threading.Tasks; -using Apache.Arrow.Flight.Protocol; using Apache.Arrow.Flight.Internal; using Grpc.Core; @@ -22,7 +21,11 @@ namespace Apache.Arrow.Flight.Server { public class FlightServerRecordBatchStreamReader : FlightRecordBatchStreamReader { - internal FlightServerRecordBatchStreamReader(IAsyncStreamReader flightDataStream) : base(flightDataStream) + public FlightServerRecordBatchStreamReader(IAsyncStreamReader flightDataStream) : base(new StreamReader(flightDataStream, data => data.ToProtocol())) + { + } + + internal FlightServerRecordBatchStreamReader(IAsyncStreamReader flightDataStream) : base(flightDataStream) { } diff --git a/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamWriter.cs b/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamWriter.cs index 6c1987339bdf3..7d1c89ea3df2a 100644 --- a/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamWriter.cs +++ b/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamWriter.cs @@ -13,10 +13,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -using System; -using System.Collections.Generic; -using System.Text; -using Apache.Arrow.Flight.Protocol; using Apache.Arrow.Flight.Internal; using Grpc.Core; @@ -24,7 +20,11 @@ namespace Apache.Arrow.Flight.Server { public class FlightServerRecordBatchStreamWriter : FlightRecordBatchStreamWriter, IServerStreamWriter { - internal FlightServerRecordBatchStreamWriter(IServerStreamWriter clientStreamWriter) : base(clientStreamWriter, null) + public FlightServerRecordBatchStreamWriter(IServerStreamWriter clientStreamWriter) : base(new StreamWriter(clientStreamWriter, data => new FlightData(data)), null) + { + } + + internal FlightServerRecordBatchStreamWriter(IServerStreamWriter clientStreamWriter) : base(clientStreamWriter, null) { } } diff --git a/csharp/src/Apache.Arrow.Flight/Server/Internal/FlightServerImplementation.cs b/csharp/src/Apache.Arrow.Flight/Server/Internal/FlightServerImplementation.cs index dcf6e57681894..f34ffaf92fc81 100644 --- a/csharp/src/Apache.Arrow.Flight/Server/Internal/FlightServerImplementation.cs +++ b/csharp/src/Apache.Arrow.Flight/Server/Internal/FlightServerImplementation.cs @@ -14,12 +14,9 @@ // limitations under the License. using System; -using System.Collections.Generic; -using System.Text; using System.Threading.Tasks; using Apache.Arrow.Flight.Internal; using Apache.Arrow.Flight.Protocol; -using Apache.Arrow.Flight.Server; using Grpc.Core; namespace Apache.Arrow.Flight.Server.Internal @@ -35,21 +32,26 @@ public FlightServerImplementation(FlightServer flightServer) _flightServer = flightServer; } - public override async Task DoPut(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) + public override async Task DoPut(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) { var readStream = new FlightServerRecordBatchStreamReader(requestStream); - var writeStream = new StreamWriter(responseStream, putResult => putResult.ToProtocol()); + var writeStream = new StreamWriter(responseStream, putResult => putResult.ToProtocol()); + await _flightServer.DoPut(readStream, writeStream, context).ConfigureAwait(false); } - public override Task DoGet(Protocol.Ticket request, IServerStreamWriter responseStream, ServerCallContext context) + public override Task DoGet(Protocol.Ticket request, IServerStreamWriter responseStream, ServerCallContext context) { - return _flightServer.DoGet(new FlightTicket(request.Ticket_), new FlightServerRecordBatchStreamWriter(responseStream), context); + var flightTicket = new FlightTicket(request.Ticket_); + var flightServerRecordBatchStreamWriter = new FlightServerRecordBatchStreamWriter(responseStream); + + return _flightServer.DoGet(flightTicket, flightServerRecordBatchStreamWriter, context); } public override Task ListFlights(Protocol.Criteria request, IServerStreamWriter responseStream, ServerCallContext context) { var writeStream = new StreamWriter(responseStream, flightInfo => flightInfo.ToProtocol()); + return _flightServer.ListFlights(new FlightCriteria(request), writeStream, context); } @@ -57,6 +59,7 @@ public override Task DoAction(Protocol.Action request, IServerStreamWriter(responseStream, result => result.ToProtocol()); + return _flightServer.DoAction(action, writeStream, context); } @@ -74,12 +77,12 @@ public override async Task GetSchema(Protocol.FlightDescriptor req public override async Task GetFlightInfo(Protocol.FlightDescriptor request, ServerCallContext context) { var flightDescriptor = new FlightDescriptor(request); - var flightInfo = await _flightServer.GetFlightInfo(flightDescriptor, context).ConfigureAwait(false); + FlightInfo flightInfo = await _flightServer.GetFlightInfo(flightDescriptor, context).ConfigureAwait(false); return flightInfo.ToProtocol(); } - public override Task DoExchange(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) + public override Task DoExchange(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) { //Exchange is not yet implemented throw new NotImplementedException(); @@ -87,14 +90,15 @@ public override Task DoExchange(IAsyncStreamReader requestStream, IS public override Task Handshake(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) { - //Handshake is not yet implemented - throw new NotImplementedException(); + var readStream = new StreamReader(requestStream, request => new FlightHandshakeRequest(request)); + var writeStream = new StreamWriter(responseStream, result => result.ToProtocol()); + return _flightServer.Handshake(readStream, writeStream, context); } - public override Task ListActions(Empty request, IServerStreamWriter responseStream, ServerCallContext context) + public override async Task ListActions(Empty request, IServerStreamWriter responseStream, ServerCallContext context) { var writeStream = new StreamWriter(responseStream, (actionType) => actionType.ToProtocol()); - return _flightServer.ListActions(writeStream, context); + await _flightServer.ListActions(writeStream, context).ConfigureAwait(false); } } } diff --git a/csharp/src/Apache.Arrow.Flight/Server/Internal/HandshakeAdapters.cs b/csharp/src/Apache.Arrow.Flight/Server/Internal/HandshakeAdapters.cs new file mode 100644 index 0000000000000..40ac3bebf1b19 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Server/Internal/HandshakeAdapters.cs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Threading.Tasks; +using Apache.Arrow.Flight.Protocol; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Server.Internal; + +internal class FlightHandshakeStreamWriterAdapter : IClientStreamWriter +{ + private readonly IClientStreamWriter _writeStream; + + public FlightHandshakeStreamWriterAdapter(IClientStreamWriter writeStream) + { + _writeStream = writeStream; + } + + public Task WriteAsync(FlightHandshakeRequest message) => _writeStream.WriteAsync(message.ToProtocol()); + + public WriteOptions WriteOptions + { + get => _writeStream.WriteOptions; + set => _writeStream.WriteOptions = value; + } + + public Task CompleteAsync() => _writeStream.CompleteAsync(); +} diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj new file mode 100644 index 0000000000000..07e341eb27ab2 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj @@ -0,0 +1,19 @@ + + + + net7.0 + false + + + + + + + + + + + + + + diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlServerTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlServerTests.cs new file mode 100644 index 0000000000000..4ad5bde0874a8 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlServerTests.cs @@ -0,0 +1,375 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#nullable enable +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Server; +using Apache.Arrow.Types; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Grpc.Core; +using Xunit; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class FlightSqlServerTests +{ + [Theory] + [InlineData(FlightDescriptorType.Path, null)] + [InlineData(FlightDescriptorType.Command, null)] + [InlineData(FlightDescriptorType.Command, typeof(CommandGetCatalogs))] + public void EnsureGetCommandReturnsTheCorrectResponse(FlightDescriptorType type, Type? expectedResult) + { + //Given + FlightDescriptor descriptor; + if (type == FlightDescriptorType.Command) + { + descriptor = expectedResult != null ? + FlightDescriptor.CreateCommandDescriptor(((IMessage) Activator.CreateInstance(expectedResult!)!).PackAndSerialize().ToByteArray()) : + FlightDescriptor.CreateCommandDescriptor(ByteString.Empty.ToStringUtf8()); + } + else + { + descriptor = FlightDescriptor.CreatePathDescriptor(System.Array.Empty()); + } + + //When + var result = FlightSqlServer.GetCommand(descriptor); + + //Then + Assert.Equal(expectedResult, result?.GetType()); + } + + [Fact] + public async Task EnsureTheCorrectActionsAreGiven() + { + //Given + var producer = new TestFlightSqlSever(); + var streamWriter = new MockServerStreamWriter(); + + //When + await producer.ListActions(streamWriter, new MockServerCallContext()).ConfigureAwait(false); + var actions = streamWriter.Messages.ToArray(); + + Assert.Equal(FlightSqlUtils.FlightSqlActions, actions); + } + + [Theory] + [InlineData(false, + new[] {"catalog_name", "db_schema_name", "table_name", "table_type"}, + new[] {typeof(StringType), typeof(StringType), typeof(StringType), typeof(StringType)}, + new[] {true, true, false, false}) + ] + [InlineData(true, + new[] {"catalog_name", "db_schema_name", "table_name", "table_type", "table_schema"}, + new[] {typeof(StringType), typeof(StringType), typeof(StringType), typeof(StringType), typeof(BinaryType)}, + new[] {true, true, false, false, false}) + ] + public void EnsureTableSchemaIsCorrectWithoutTableSchema(bool includeTableSchemaField, string[] expectedNames, Type[] expectedTypes, bool[] expectedIsNullable) + { + // Arrange + + // Act + var schema = FlightSqlServer.GetTableSchema(includeTableSchemaField); + var fields = schema.FieldsList; + + //Assert + Assert.False(schema.HasMetadata); + Assert.Equal(expectedNames.Length, fields.Count); + for (int i = 0; i < fields.Count; i++) + { + Assert.Equal(expectedNames[i], fields[i].Name); + Assert.Equal(expectedTypes[i], fields[i].DataType.GetType()); + Assert.Equal(expectedIsNullable[i], fields[i].IsNullable); + } + } + + #region FlightInfoTests + [Theory] + [InlineData(typeof(CommandStatementQuery), "GetStatementQueryFlightInfo")] + [InlineData(typeof(CommandPreparedStatementQuery), "GetPreparedStatementQueryFlightInfo")] + [InlineData(typeof(CommandGetCatalogs), "GetCatalogFlightInfo")] + [InlineData(typeof(CommandGetDbSchemas), "GetDbSchemaFlightInfo")] + [InlineData(typeof(CommandGetTables), "GetTablesFlightInfo")] + [InlineData(typeof(CommandGetTableTypes), "GetTableTypesFlightInfo")] + [InlineData(typeof(CommandGetSqlInfo), "GetSqlFlightInfo")] + [InlineData(typeof(CommandGetPrimaryKeys), "GetPrimaryKeysFlightInfo")] + [InlineData(typeof(CommandGetExportedKeys), "GetExportedKeysFlightInfo")] + [InlineData(typeof(CommandGetImportedKeys), "GetImportedKeysFlightInfo")] + [InlineData(typeof(CommandGetCrossReference), "GetCrossReferenceFlightInfo")] + [InlineData(typeof(CommandGetXdbcTypeInfo), "GetXdbcTypeFlightInfo")] + public async void EnsureGetFlightInfoIsCorrectlyRoutedForCommand(Type commandType, string expectedResult) + { + //Given + var command = (IMessage) Activator.CreateInstance(commandType)!; + var producer = new TestFlightSqlSever(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize().ToArray()); + + //When + var flightInfo = await producer.GetFlightInfo(descriptor, new MockServerCallContext()); + + //Then + Assert.Equal(expectedResult, flightInfo.Descriptor.Paths.First()); + } + + + [Fact] + public async void EnsureAnInvalidOperationExceptionIsThrownWhenACommandIsNotSupportedAndHasNoDescriptor() + { + //Given + var producer = new TestFlightSqlSever(); + + //When + async Task Act() => await producer.GetFlightInfo(FlightDescriptor.CreatePathDescriptor(""), new MockServerCallContext()); + var exception = await Record.ExceptionAsync(Act); + + //Then + Assert.Equal("command type not supported", exception?.Message); + } + + [Fact] + public async void EnsureAnInvalidOperationExceptionIsThrownWhenACommandIsNotSupported() + { + //Given + var producer = new TestFlightSqlSever(); + var command = new CommandPreparedStatementUpdate(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize().ToArray()); + + //When + async Task Act() => await producer.GetFlightInfo(descriptor, new MockServerCallContext()); + var exception = await Record.ExceptionAsync(Act); + + //Then + Assert.Equal("command type CommandPreparedStatementUpdate not supported", exception?.Message); + } + #endregion + + #region DoGetTests + + [Theory] + [InlineData(typeof(CommandPreparedStatementQuery), "DoGetPreparedStatementQuery")] + [InlineData(typeof(CommandGetSqlInfo), "DoGetSqlInfo")] + [InlineData(typeof(CommandGetCatalogs), "DoGetCatalog")] + [InlineData(typeof(CommandGetTableTypes), "DoGetTableType")] + [InlineData(typeof(CommandGetTables), "DoGetTables")] + [InlineData(typeof(CommandGetDbSchemas), "DoGetDbSchema")] + [InlineData(typeof(CommandGetPrimaryKeys), "DoGetPrimaryKeys")] + [InlineData(typeof(CommandGetExportedKeys), "DoGetExportedKeys")] + [InlineData(typeof(CommandGetImportedKeys), "DoGetImportedKeys")] + [InlineData(typeof(CommandGetCrossReference), "DoGetCrossReference")] + [InlineData(typeof(CommandGetXdbcTypeInfo), "DoGetXbdcTypeInfo")] + public async void EnsureDoGetIsCorrectlyRoutedForADoGetCommand(Type commandType, string expectedResult) + { + //Given + var producer = new TestFlightSqlSever(); + var command = (IMessage) Activator.CreateInstance(commandType)!; + var ticket = new FlightTicket(command.PackAndSerialize()); + var streamWriter = new MockServerStreamWriter(); + + //When + await producer.DoGet(ticket, new FlightServerRecordBatchStreamWriter(streamWriter), new MockServerCallContext()); + var schema = await streamWriter.Messages.GetSchema(); + + //Then + Assert.Equal(expectedResult, schema.FieldsList[0].Name); + } + + [Fact] + public async void EnsureAnInvalidOperationExceptionIsThrownWhenADoGetCommandIsNotSupported() + { + //Given + var producer = new TestFlightSqlSever(); + var ticket = new FlightTicket(""); + var streamWriter = new MockServerStreamWriter(); + + //When + async Task Act() => await producer.DoGet(ticket, new FlightServerRecordBatchStreamWriter(streamWriter), new MockServerCallContext()); + var exception = await Record.ExceptionAsync(Act); + + //Then + Assert.Equal("Status(StatusCode=\"InvalidArgument\", Detail=\"DoGet command is not supported.\")", exception?.Message); + } + #endregion + + #region DoActionTests + [Theory] + [InlineData(SqlAction.CloseRequest, typeof(ActionClosePreparedStatementRequest), "ClosePreparedStatement")] + [InlineData(SqlAction.CreateRequest, typeof(ActionCreatePreparedStatementRequest), "CreatePreparedStatement")] + [InlineData("BadCommand", typeof(ActionCreatePreparedStatementRequest), "Action type BadCommand not supported", true)] + public async void EnsureDoActionIsCorrectlyRoutedForAnActionRequest(string actionType, Type actionBodyType, string expectedResponse, bool isException = false) + { + //Given + var producer = new TestFlightSqlSever(); + var actionBody = (IMessage) Activator.CreateInstance(actionBodyType)!; + var action = new FlightAction(actionType, actionBody.PackAndSerialize()); + var mockStreamWriter = new MockStreamWriter(); + + //When + async Task Act() => await producer.DoAction(action, mockStreamWriter, new MockServerCallContext()); + var exception = await Record.ExceptionAsync(Act); + string? actualMessage = isException ? exception?.Message : mockStreamWriter.Messages[0].Body.ToStringUtf8(); + + //Then + Assert.Equal(expectedResponse, actualMessage); + } + #endregion + + #region DoPutTests + [Theory] + [InlineData(typeof(CommandStatementUpdate), "PutStatementUpdate")] + [InlineData(typeof(CommandPreparedStatementQuery), "PutPreparedStatementQuery")] + [InlineData(typeof(CommandPreparedStatementUpdate), "PutPreparedStatementUpdate")] + [InlineData(typeof(CommandGetXdbcTypeInfo), "Command CommandGetXdbcTypeInfo not supported", true)] + public async void EnsureDoPutIsCorrectlyRoutedForTheCommand(Type commandType, string expectedResponse, bool isException = false) + { + //Given + var command = (IMessage) Activator.CreateInstance(commandType)!; + var producer = new TestFlightSqlSever(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize().ToArray()); + var recordBatch = new RecordBatch(new Schema(new List(), null), System.Array.Empty(), 0); + var reader = new MockStreamReader(await recordBatch.ToFlightData(descriptor).ConfigureAwait(false)); + var batchReader = new FlightServerRecordBatchStreamReader(reader); + var mockStreamWriter = new MockServerStreamWriter(); + + //When + async Task Act() => await producer.DoPut(batchReader, mockStreamWriter, new MockServerCallContext()).ConfigureAwait(false); + var exception = await Record.ExceptionAsync(Act); + string? actualMessage = isException ? exception?.Message : mockStreamWriter.Messages[0].ApplicationMetadata.ToStringUtf8(); + + //Then + Assert.Equal(expectedResponse, actualMessage); + } + #endregion + + private class MockServerCallContext : ServerCallContext + { + protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders) => throw new NotImplementedException(); + + protected override ContextPropagationToken CreatePropagationTokenCore(ContextPropagationOptions? options) => throw new NotImplementedException(); + protected override string MethodCore => ""; + protected override string HostCore => ""; + protected override string PeerCore => ""; + protected override DateTime DeadlineCore => new(); + protected override Metadata RequestHeadersCore => new(); + protected override CancellationToken CancellationTokenCore => default; + protected override Metadata ResponseTrailersCore => new(); + protected override Status StatusCore { get; set; } + protected override WriteOptions WriteOptionsCore { get; set; } = WriteOptions.Default; + protected override AuthContext AuthContextCore => new("", new Dictionary>()); + } +} + +internal class MockStreamWriter : IServerStreamWriter +{ + public Task WriteAsync(T message) + { + _messages.Add(message); + return Task.FromResult(message); + } + + public IReadOnlyList Messages => new ReadOnlyCollection(_messages); + public WriteOptions? WriteOptions { get; set; } + private readonly List _messages = new(); +} + +internal class MockServerStreamWriter : IServerStreamWriter +{ + public Task WriteAsync(T message) + { + _messages.Add(message); + return Task.FromResult(message); + } + + public IReadOnlyList Messages => new ReadOnlyCollection(_messages); + public WriteOptions? WriteOptions { get; set; } + private readonly List _messages = new(); +} + +internal static class MockStreamReaderWriterExtensions +{ + public static async Task> GetRecordBatches(this IReadOnlyList flightDataList) + { + var list = new List(); + var recordBatchReader = new FlightServerRecordBatchStreamReader(new MockStreamReader(flightDataList)); + while (await recordBatchReader.MoveNext().ConfigureAwait(false)) + { + list.Add(recordBatchReader.Current); + } + + return list; + } + + public static async Task GetSchema(this IEnumerable flightDataList) + { + var recordBatchReader = new FlightServerRecordBatchStreamReader(new MockStreamReader(flightDataList)); + return await recordBatchReader.Schema; + } + + public static async Task> ToFlightData(this RecordBatch recordBatch, FlightDescriptor? descriptor = null) + { + var responseStream = new MockFlightServerRecordBatchStreamWriter(); + await responseStream.WriteRecordBatchAsync(recordBatch).ConfigureAwait(false); + if (descriptor == null) + { + return responseStream.FlightData; + } + + return responseStream.FlightData.Select( + flightData => new FlightData(descriptor, flightData.DataBody, flightData.DataHeader, flightData.AppMetadata) + ); + } +} + +internal class MockStreamReader: IAsyncStreamReader +{ + private readonly IEnumerator _flightActions; + + public MockStreamReader(IEnumerable flightActions) + { + _flightActions = flightActions.GetEnumerator(); + } + + public Task MoveNext(CancellationToken cancellationToken) + { + return Task.FromResult(_flightActions.MoveNext()); + } + + public T Current => _flightActions.Current; +} + +internal class MockFlightServerRecordBatchStreamWriter : FlightServerRecordBatchStreamWriter +{ + private readonly MockStreamWriter _streamWriter; + public MockFlightServerRecordBatchStreamWriter() : this(new MockStreamWriter()) { } + + private MockFlightServerRecordBatchStreamWriter(MockStreamWriter clientStreamWriter) : base(clientStreamWriter) + { + _streamWriter = clientStreamWriter; + } + + public IEnumerable FlightData => _streamWriter.Messages; + + public async Task WriteRecordBatchAsync(RecordBatch recordBatch) + { + await WriteAsync(recordBatch).ConfigureAwait(false); + } +} + + diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs new file mode 100644 index 0000000000000..031495fffdcc7 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public static class FlightSqlTestExtensions +{ + public static ByteString PackAndSerialize(this IMessage command) + { + return Any.Pack(command).Serialize(); + } +} diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlUtilsTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlUtilsTests.cs new file mode 100644 index 0000000000000..3ea7a8d3f0463 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlUtilsTests.cs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using Xunit; + +namespace Apache.Arrow.Flight.Sql.Tests +{ + public class FlightSqlUtilsTests + { + [Fact] + public void EnsureParseCanCorrectlyReviveTheCommand() + { + //Given + var expectedCommand = new CommandStatementQuery + { + Query = "select * from database" + }; + + //When + var command = FlightSqlUtils.Parse(Any.Pack(expectedCommand).ToByteString()); + + //Then + Assert.Equal(command.Unpack(), expectedCommand); + } + + [Fact] + public void EnsureUnpackCanCreateTheCorrectObject() + { + //Given + var expectedCommand = new CommandPreparedStatementQuery + { + PreparedStatementHandle = ByteString.Empty + }; + + //When + var command = FlightSqlUtils.Unpack(Any.Pack(expectedCommand)); + + //Then + Assert.Equal(command, expectedCommand); + } + + [Fact] + public void EnsureParseAndUnpackProducesTheCorrectObject() + { + //Given + var expectedCommand = new CommandPreparedStatementQuery + { + PreparedStatementHandle = ByteString.Empty + }; + + //When + var command = FlightSqlUtils.ParseAndUnpack(Any.Pack(expectedCommand).ToByteString()); + + //Then + Assert.Equal(command, expectedCommand); + } + } +} diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs new file mode 100644 index 0000000000000..3dca632b5b761 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using System.Reflection; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Server; +using Apache.Arrow.Types; +using Arrow.Flight.Protocol.Sql; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class TestFlightSqlSever : FlightSqlServer +{ + protected override Task GetStatementQueryFlightInfo(CommandStatementQuery commandStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetPreparedStatementQueryFlightInfo(CommandPreparedStatementQuery preparedStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetCatalogFlightInfo(CommandGetCatalogs command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetDbSchemaFlightInfo(CommandGetDbSchemas command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetTablesFlightInfo(CommandGetTables command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetTableTypesFlightInfo(CommandGetTableTypes command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetSqlFlightInfo(CommandGetSqlInfo commandGetSqlInfo, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetPrimaryKeysFlightInfo(CommandGetPrimaryKeys command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetExportedKeysFlightInfo(CommandGetExportedKeys command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetImportedKeysFlightInfo(CommandGetImportedKeys command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetCrossReferenceFlightInfo(CommandGetCrossReference command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override Task GetXdbcTypeFlightInfo(CommandGetXdbcTypeInfo command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); + + protected override async Task DoGetPreparedStatementQuery(CommandPreparedStatementQuery preparedStatementQuery, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetPreparedStatementQuery")); + + protected override async Task DoGetSqlInfo(CommandGetSqlInfo getSqlInfo, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetSqlInfo")); + + protected override async Task DoGetCatalog(CommandGetCatalogs command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetCatalog")); + + protected override async Task DoGetTableType(CommandGetTableTypes command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetTableType")); + + protected override async Task DoGetTables(CommandGetTables command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetTables")); + + protected override async Task DoGetPrimaryKeys(CommandGetPrimaryKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetPrimaryKeys")); + + protected override async Task DoGetDbSchema(CommandGetDbSchemas command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetDbSchema")); + + protected override async Task DoGetExportedKeys(CommandGetExportedKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetExportedKeys")); + + protected override async Task DoGetImportedKeys(CommandGetImportedKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetImportedKeys")); + + protected override async Task DoGetCrossReference(CommandGetCrossReference command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetCrossReference")); + + protected override async Task DoGetXbdcTypeInfo(CommandGetXdbcTypeInfo command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetXbdcTypeInfo")); + + protected override async Task CreatePreparedStatement(ActionCreatePreparedStatementRequest request, FlightAction action, IAsyncStreamWriter streamWriter, ServerCallContext serverCallContext) => await streamWriter.WriteAsync(new FlightResult("CreatePreparedStatement")); + + protected override async Task ClosePreparedStatement(ActionClosePreparedStatementRequest request, FlightAction action, IAsyncStreamWriter streamWriter, ServerCallContext serverCallContext) => await streamWriter.WriteAsync(new FlightResult("ClosePreparedStatement")); + + protected override async Task PutPreparedStatementUpdate(CommandPreparedStatementUpdate command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(new FlightPutResult("PutPreparedStatementUpdate")); + + protected override async Task PutStatementUpdate(CommandStatementUpdate command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(new FlightPutResult("PutStatementUpdate")); + + protected override async Task PutPreparedStatementQuery(CommandPreparedStatementQuery command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(new FlightPutResult("PutPreparedStatementQuery")); + + private RecordBatch MockRecordBatch(string name) + { + var schema = new Schema(new List {new(name, StringType.Default, false)}, System.Array.Empty>()); + return new RecordBatch(schema, new []{ new StringArray.Builder().Append(name).Build() }, 1); + } +} diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs index ae6e2e4b03b6d..a613b04a32002 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs @@ -18,6 +18,7 @@ using System.Linq; using System.Threading.Tasks; using Apache.Arrow.Flight.Server; +using Google.Protobuf; using Grpc.Core; using Grpc.Core.Utils; @@ -86,6 +87,21 @@ public override Task GetFlightInfo(FlightDescriptor request, ServerC throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); } + public override async Task Handshake(IAsyncStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) + { + while (await requestStream.MoveNext().ConfigureAwait(false)) + { + if (requestStream.Current.Payload.ToStringUtf8() == "Hello") + { + await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Hello handshake"))).ConfigureAwait(false); + } + else + { + await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Done"))).ConfigureAwait(false); + } + } + } + public override Task GetSchema(FlightDescriptor request, ServerCallContext context) { if(_flightStore.Flights.TryGetValue(request, out var flightHolder)) diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs index 3556c6a17feef..267fe4e4b606d 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs @@ -283,6 +283,32 @@ public async Task TestListFlights() } } + [Fact] + public async Task TestHandshake() + { + var duplexStreamingCall = _flightClient.Handshake(); + + await duplexStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)).ConfigureAwait(false); + await duplexStreamingCall.RequestStream.CompleteAsync().ConfigureAwait(false); + var results = await duplexStreamingCall.ResponseStream.ToListAsync().ConfigureAwait(false); + + Assert.Single(results); + Assert.Equal("Done", results.First().Payload.ToStringUtf8()); + } + + [Fact] + public async Task TestHandshakeWithSpecificMessage() + { + var duplexStreamingCall = _flightClient.Handshake(); + + await duplexStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.CopyFromUtf8("Hello"))).ConfigureAwait(false); + await duplexStreamingCall.RequestStream.CompleteAsync().ConfigureAwait(false); + var results = await duplexStreamingCall.ResponseStream.ToListAsync().ConfigureAwait(false); + + Assert.Single(results); + Assert.Equal("Hello handshake", results.First().Payload.ToStringUtf8()); + } + [Fact] public async Task TestGetBatchesWithAsyncEnumerable() { diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 7bdb692d048e9..aebe321d613ab 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -113,8 +113,10 @@ csharp/src/Apache.Arrow/Properties/Resources.resx csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj csharp/src/Apache.Arrow.Flight.AspNetCore/Apache.Arrow.Flight.AspNetCore.csproj csharp/src/Apache.Arrow.Compression/Apache.Arrow.Compression.csproj +csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj csharp/test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj csharp/test/Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj +csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj csharp/test/Apache.Arrow.IntegrationTest/Apache.Arrow.IntegrationTest.csproj csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj