diff --git a/csharp/Apache.Arrow.sln b/csharp/Apache.Arrow.sln
index baf4bc6129598..7e7f7c6331e88 100644
--- a/csharp/Apache.Arrow.sln
+++ b/csharp/Apache.Arrow.sln
@@ -23,6 +23,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Compression",
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Compression.Tests", "test\Apache.Arrow.Compression.Tests\Apache.Arrow.Compression.Tests.csproj", "{5D7FF380-B7DF-4752-B415-7C08C70C9F06}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.Tests", "test\Apache.Arrow.Flight.Sql.Tests\Apache.Arrow.Flight.Sql.Tests.csproj", "{DCC99EB1-4E60-4F0D-AEA9-C44A4C0C8B1D}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql", "src\Apache.Arrow.Flight.Sql\Apache.Arrow.Flight.Sql.csproj", "{2ADE087A-B424-4895-8CC5-10170D10BA62}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -69,6 +73,14 @@ Global
{5D7FF380-B7DF-4752-B415-7C08C70C9F06}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5D7FF380-B7DF-4752-B415-7C08C70C9F06}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5D7FF380-B7DF-4752-B415-7C08C70C9F06}.Release|Any CPU.Build.0 = Release|Any CPU
+ {DCC99EB1-4E60-4F0D-AEA9-C44A4C0C8B1D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {DCC99EB1-4E60-4F0D-AEA9-C44A4C0C8B1D}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {DCC99EB1-4E60-4F0D-AEA9-C44A4C0C8B1D}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {DCC99EB1-4E60-4F0D-AEA9-C44A4C0C8B1D}.Release|Any CPU.Build.0 = Release|Any CPU
+ {2ADE087A-B424-4895-8CC5-10170D10BA62}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {2ADE087A-B424-4895-8CC5-10170D10BA62}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {2ADE087A-B424-4895-8CC5-10170D10BA62}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {2ADE087A-B424-4895-8CC5-10170D10BA62}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj b/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj
new file mode 100644
index 0000000000000..50570d628924b
--- /dev/null
+++ b/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj
@@ -0,0 +1,18 @@
+
+
+ netstandard2.1
+ enable
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs
new file mode 100644
index 0000000000000..dbfc1f7c7ea49
--- /dev/null
+++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs
@@ -0,0 +1,389 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Apache.Arrow.Flight.Server;
+using Apache.Arrow.Types;
+using Arrow.Flight.Protocol.Sql;
+using Google.Protobuf;
+using Google.Protobuf.WellKnownTypes;
+using Grpc.Core;
+using Microsoft.Extensions.Logging;
+
+namespace Apache.Arrow.Flight.Sql;
+
+public abstract class FlightSqlServer : FlightServer
+{
+ private ILogger? Logger { get; }
+ public static readonly Schema CatalogSchema = new(new List {new("catalog_name", StringType.Default, false)}, null);
+ public static readonly Schema TableTypesSchema = new(new List {new("table_type", StringType.Default, false)}, null);
+ public static readonly Schema DbSchemaFlightSchema = new(new List {new("catalog_name", StringType.Default, true), new("db_schema_name", StringType.Default, false)}, null);
+
+ public static readonly Schema PrimaryKeysSchema = new(new List
+ {
+ new("catalog_name", StringType.Default, true),
+ new("db_schema_name", StringType.Default, true),
+ new("table_name", StringType.Default, false),
+ new("column_name", StringType.Default, false),
+ new("key_sequence", Int32Type.Default, false),
+ new("key_name", StringType.Default, true)
+ }, null);
+
+ public static readonly Schema KeyImportExportSchema = new(new List
+ {
+ new("pk_catalog_name", StringType.Default, true),
+ new("pk_db_schema_name", StringType.Default, true),
+ new("pk_table_name", StringType.Default, false),
+ new("pk_column_name", StringType.Default, false),
+ new("fk_catalog_name", StringType.Default, true),
+ new("fk_db_schema_name", StringType.Default, true),
+ new("fk_table_name", StringType.Default, false),
+ new("fk_column_name", StringType.Default, false),
+ new("key_sequence", Int32Type.Default, false),
+ new("fk_key_name", StringType.Default, true),
+ new("pk_key_name", StringType.Default, true),
+ new("update_rule", UInt8Type.Default, false),
+ new("delete_rule", UInt8Type.Default, false)
+ }, null);
+
+ public static readonly Schema TypeInfoSchema = new(new List
+ {
+ new("type_name", StringType.Default, false),
+ new("data_type", Int32Type.Default, false),
+ new("column_size", Int32Type.Default, true),
+ new("literal_prefix", StringType.Default, true),
+ new("literal_suffix", StringType.Default, true),
+ new("create_params", new ListType(new Field("item", StringType.Default, false)), true),
+ new("nullable", Int32Type.Default, false),
+ new("case_sensitive", BooleanType.Default, false),
+ new("searchable", Int32Type.Default, false),
+ new("unsigned_attribute", BooleanType.Default, true),
+ new("fixed_prec_scale", BooleanType.Default, false),
+ new("auto_increment", BooleanType.Default, true),
+ new("local_type_name", StringType.Default, true),
+ new("minimum_scale", Int32Type.Default, true),
+ new("maximum_scale", Int32Type.Default, true),
+ new("sql_data_type", Int32Type.Default, false),
+ new("datetime_subcode", Int32Type.Default, true),
+ new("num_prec_radix", Int32Type.Default, true),
+ new("interval_precision", Int32Type.Default, true)
+ }, null);
+
+ public static readonly Schema SqlInfoSchema = new(new List
+ {
+ new("info_name", UInt32Type.Default, false)
+ //TODO: once we have union serialization in Arrow Flight for .Net we should to add these fields
+ // fieldList.Add(new Field("value", new UnionType(new List(), new List()), false));
+ // fieldList.Add(new Field("value", new UnionType(new []
+ // {
+ // new Field("string_value", StringType.Default, false),
+ // new Field("bool_value", BooleanType.Default, false),
+ // new Field("bigint_value", Int64Type.Default, false),
+ // new Field("bool_value", BooleanType.Default, false),
+ // new Field("bigint_value", Int64Type.Default, false),
+ // new Field("int32_bitmask", Int32Type.Default, false),
+ // new Field("string_list", new ListType(new Field("item", StringType.Default, false)), false),
+ // new Field("int32_to_int32_list_map", new DictionaryType(Int32Type.Default, new ListType(Int32Type.Default), false), false),
+ // }, new []{(byte)ArrowTypeId.String, (byte)ArrowTypeId.Boolean, (byte)ArrowTypeId.Int64,/* (byte)3, (byte)4, (byte)5*/}, UnionMode.Dense), false));
+ }, null);
+
+ private static readonly Schema s_tableSchema = new(new List
+ {
+ new("catalog_name", StringType.Default, true),
+ new("db_schema_name", StringType.Default, true),
+ new("table_name", StringType.Default, false),
+ new("table_type", StringType.Default, false)
+ }, null);
+
+ public static Schema GetTableSchema(bool includeTableSchemaField)
+ {
+ if (!includeTableSchemaField)
+ {
+ return s_tableSchema;
+ }
+
+ var fields = s_tableSchema.FieldsList.ToList();
+ fields.Add(new Field("table_schema", BinaryType.Default, false));
+
+ return new Schema(fields, s_tableSchema.Metadata);
+ }
+
+ public static IMessage? GetCommand(FlightTicket ticket)
+ {
+ try
+ {
+ return GetCommand(Any.Parser.ParseFrom(ticket.Ticket));
+ }
+ catch (InvalidProtocolBufferException) { } //The ticket is not a flight sql command
+
+ return null;
+ }
+
+ public static async Task GetCommand(FlightServerRecordBatchStreamReader requestStream)
+ {
+ return GetCommand(await requestStream.FlightDescriptor.ConfigureAwait(false));
+ }
+
+ public static IMessage? GetCommand(FlightDescriptor? request)
+ {
+ if (request == null) return null;
+ if (request.Type == FlightDescriptorType.Command && request.ParsedAndUnpackedMessage() is { } command)
+ {
+ return command;
+ }
+
+ return null;
+ }
+
+ private static IMessage? GetCommand(Any command)
+ {
+ if (command.Is(CommandPreparedStatementQuery.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ if (command.Is(CommandGetSqlInfo.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ if (command.Is(CommandGetCatalogs.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ if (command.Is(CommandGetTableTypes.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ if (command.Is(CommandGetTables.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ if (command.Is(CommandGetDbSchemas.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ if (command.Is(CommandGetPrimaryKeys.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ if (command.Is(CommandGetExportedKeys.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ if (command.Is(CommandGetImportedKeys.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ if (command.Is(CommandGetCrossReference.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ if (command.Is(CommandGetXdbcTypeInfo.Descriptor))
+ {
+ return command.Unpack();
+ }
+
+ return null;
+ }
+
+ protected FlightSqlServer(ILoggerFactory? factory = null)
+ {
+ Logger = factory?.CreateLogger();
+ }
+
+ ///
+ /// Lists actions supported by Flight SQL. For Flight RPC actions support
+ /// implementers should extend this method to return additional supported actions.
+ ///
+ public override async Task ListActions(IAsyncStreamWriter responseStream, ServerCallContext context)
+ {
+ foreach (var actionType in FlightSqlUtils.FlightSqlActions)
+ {
+ await responseStream.WriteAsync(actionType).ConfigureAwait(false);
+ }
+ }
+
+ ///
+ /// Attempts to execute a valid Flight SQL command. For Flight RPC calls
+ /// implementers should extend this method in order to handle RPC messages.
+ ///
+ public override Task GetFlightInfo(FlightDescriptor flightDescriptor, ServerCallContext context)
+ {
+ var sqlCommand = GetCommand(flightDescriptor);
+ Logger?.LogTrace("Executing Flight SQL FlightInfo command: {DescriptorName}", sqlCommand?.Descriptor.Name);
+ return sqlCommand switch
+ {
+ CommandStatementQuery command => GetStatementQueryFlightInfo(command, flightDescriptor, context),
+ CommandPreparedStatementQuery command => GetPreparedStatementQueryFlightInfo(command, flightDescriptor, context),
+ CommandGetCatalogs command => GetCatalogFlightInfo(command, flightDescriptor, context),
+ CommandGetDbSchemas command => GetDbSchemaFlightInfo(command, flightDescriptor, context),
+ CommandGetTables command => GetTablesFlightInfo(command, flightDescriptor, context),
+ CommandGetTableTypes command => GetTableTypesFlightInfo(command, flightDescriptor, context),
+ CommandGetSqlInfo command => GetSqlFlightInfo(command, flightDescriptor, context),
+ CommandGetPrimaryKeys command => GetPrimaryKeysFlightInfo(command, flightDescriptor, context),
+ CommandGetExportedKeys command => GetExportedKeysFlightInfo(command, flightDescriptor, context),
+ CommandGetImportedKeys command => GetImportedKeysFlightInfo(command, flightDescriptor, context),
+ CommandGetCrossReference command => GetCrossReferenceFlightInfo(command, flightDescriptor, context),
+ CommandGetXdbcTypeInfo command => GetXdbcTypeFlightInfo(command, flightDescriptor, context),
+ _ => throw new InvalidOperationException($"command type {sqlCommand?.Descriptor?.Name} not supported")
+ };
+ }
+
+ ///
+ /// Attempts to execute a valid Flight SQL command. For Flight RPC calls
+ /// implementers should extend this method in order to handle RPC messages.
+ ///
+ public override Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context)
+ {
+ var sqlCommand = GetCommand(ticket);
+ Logger?.LogTrace("Executing Flight SQL DoGet command: {SqlCommandDescriptor}", sqlCommand?.Descriptor);
+ return sqlCommand switch
+ {
+ CommandPreparedStatementQuery command => DoGetPreparedStatementQuery(command, responseStream, context),
+ CommandGetSqlInfo command => DoGetSqlInfo(command, responseStream, context),
+ CommandGetCatalogs command => DoGetCatalog(command, responseStream, context),
+ CommandGetTableTypes command => DoGetTableType(command, responseStream, context),
+ CommandGetTables command => DoGetTables(command, responseStream, context),
+ CommandGetDbSchemas command => DoGetDbSchema(command, responseStream, context),
+ CommandGetPrimaryKeys command => DoGetPrimaryKeys(command, responseStream, context),
+ CommandGetExportedKeys command => DoGetExportedKeys(command, responseStream, context),
+ CommandGetImportedKeys command => DoGetImportedKeys(command, responseStream, context),
+ CommandGetCrossReference command => DoGetCrossReference(command, responseStream, context),
+ CommandGetXdbcTypeInfo command => DoGetXbdcTypeInfo(command, responseStream, context),
+ _ => throw new RpcException(new Status(StatusCode.InvalidArgument, $"DoGet command {sqlCommand?.Descriptor} is not supported."))
+ };
+ }
+
+ ///
+ /// Attempts to execute a valid Flight SQL command. For Flight RPC calls
+ /// implementers should extend this method in order to handle RPC messages.
+ ///
+ public override Task DoAction(FlightAction action, IAsyncStreamWriter responseStream, ServerCallContext context)
+ {
+ Logger?.LogTrace("Executing Flight SQL DoAction: {ActionType}", action.Type);
+ switch (action.Type)
+ {
+ case SqlAction.CreateRequest:
+ var command = FlightSqlUtils.ParseAndUnpack(action.Body);
+ return CreatePreparedStatement(command, action, responseStream, context);
+ case SqlAction.CloseRequest:
+ var closeCommand = FlightSqlUtils.ParseAndUnpack(action.Body);
+ return ClosePreparedStatement(closeCommand, action, responseStream, context);
+ default:
+ throw new NotImplementedException($"Action type {action.Type} not supported");
+ }
+ }
+
+ ///
+ /// Attempts to execute a valid Flight SQL command. For Flight RPC calls
+ /// implementers should extend this method in order to handle RPC messages.
+ ///
+ public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context)
+ {
+ if (await GetCommand(requestStream).ConfigureAwait(false) is { } command)
+ {
+ await DoPutInternal(command, requestStream, responseStream, context).ConfigureAwait(false);
+ }
+ else
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ private Task DoPutInternal(IMessage command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context)
+ {
+ Logger?.LogTrace("Executing Flight SQL DoAction: {DescriptorName}", command.Descriptor.Name);
+ return command switch
+ {
+ CommandStatementUpdate statementUpdate => PutStatementUpdate(statementUpdate, requestStream, responseStream, context),
+ CommandPreparedStatementQuery preparedStatementQuery => PutPreparedStatementQuery(preparedStatementQuery, requestStream, responseStream, context),
+ CommandPreparedStatementUpdate preparedStatementUpdate => PutPreparedStatementUpdate(preparedStatementUpdate, requestStream, responseStream, context),
+ _ => throw new NotImplementedException($"Command {command.Descriptor.Name} not supported")
+ };
+ }
+
+ public static bool SupportsAction(FlightAction action)
+ {
+ switch (action.Type)
+ {
+ case SqlAction.CreateRequest:
+ case SqlAction.CloseRequest:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+
+ #region FlightInfo
+
+ protected abstract Task GetStatementQueryFlightInfo(CommandStatementQuery commandStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetPreparedStatementQueryFlightInfo(CommandPreparedStatementQuery preparedStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetCatalogFlightInfo(CommandGetCatalogs command, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetDbSchemaFlightInfo(CommandGetDbSchemas command, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetTablesFlightInfo(CommandGetTables command, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetTableTypesFlightInfo(CommandGetTableTypes command, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetSqlFlightInfo(CommandGetSqlInfo commandGetSqlInfo, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetPrimaryKeysFlightInfo(CommandGetPrimaryKeys command, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetExportedKeysFlightInfo(CommandGetExportedKeys command, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetImportedKeysFlightInfo(CommandGetImportedKeys command, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetCrossReferenceFlightInfo(CommandGetCrossReference command, FlightDescriptor flightDescriptor, ServerCallContext context);
+ protected abstract Task GetXdbcTypeFlightInfo(CommandGetXdbcTypeInfo command, FlightDescriptor flightDescriptor, ServerCallContext context);
+
+ #endregion
+
+ #region DoGet
+
+ protected abstract Task DoGetPreparedStatementQuery(CommandPreparedStatementQuery preparedStatementQuery, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task DoGetSqlInfo(CommandGetSqlInfo getSqlInfo, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task DoGetCatalog(CommandGetCatalogs command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task DoGetTableType(CommandGetTableTypes command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task DoGetTables(CommandGetTables command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task DoGetPrimaryKeys(CommandGetPrimaryKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task DoGetDbSchema(CommandGetDbSchemas command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task DoGetExportedKeys(CommandGetExportedKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task DoGetImportedKeys(CommandGetImportedKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task DoGetCrossReference(CommandGetCrossReference command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task DoGetXbdcTypeInfo(CommandGetXdbcTypeInfo command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context);
+
+ #endregion
+
+ #region DoAction
+
+ protected abstract Task CreatePreparedStatement(ActionCreatePreparedStatementRequest request, FlightAction action, IAsyncStreamWriter streamWriter, ServerCallContext context);
+ protected abstract Task ClosePreparedStatement(ActionClosePreparedStatementRequest request, FlightAction action, IAsyncStreamWriter streamWriter, ServerCallContext context);
+
+ #endregion
+
+ #region DoPut
+
+ protected abstract Task PutPreparedStatementUpdate(CommandPreparedStatementUpdate command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task PutStatementUpdate(CommandStatementUpdate command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context);
+ protected abstract Task PutPreparedStatementQuery(CommandPreparedStatementQuery command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context);
+
+ #endregion
+}
diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlUtils.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlUtils.cs
new file mode 100644
index 0000000000000..295fe4d32a9f6
--- /dev/null
+++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlUtils.cs
@@ -0,0 +1,158 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using Arrow.Flight.Protocol.Sql;
+using Google.Protobuf;
+using Google.Protobuf.WellKnownTypes;
+
+namespace Apache.Arrow.Flight.Sql
+{
+ ///
+ /// Helper methods for doing common Flight Sql tasks and conversions
+ ///
+ public class FlightSqlUtils
+ {
+ public static readonly FlightActionType FlightSqlCreatePreparedStatement = new("CreatePreparedStatement",
+ "Creates a reusable prepared statement resource on the server. \n" +
+ "Request Message: ActionCreatePreparedStatementRequest\n" +
+ "Response Message: ActionCreatePreparedStatementResult");
+
+ public static readonly FlightActionType FlightSqlClosePreparedStatement = new("ClosePreparedStatement",
+ "Closes a reusable prepared statement resource on the server. \n" +
+ "Request Message: ActionClosePreparedStatementRequest\n" +
+ "Response Message: N/A");
+
+ ///
+ /// List of possible actions
+ ///
+ public static readonly List FlightSqlActions = new()
+ {
+ FlightSqlCreatePreparedStatement,
+ FlightSqlClosePreparedStatement
+ };
+
+ ///
+ /// Helper to parse {@link com.google.protobuf.Any} objects to the specific protobuf object.
+ ///
+ /// the raw bytes source value.
+ /// the materialized protobuf object.
+ public static Any Parse(ByteString source)
+ {
+ return Any.Parser.ParseFrom(source);
+ }
+
+ ///
+ /// Helper to unpack {@link com.google.protobuf.Any} objects to the specific protobuf object.
+ ///
+ /// the parsed Source value.
+ /// IMessage
+ /// the materialized protobuf object.
+ public static T Unpack(Any source) where T : IMessage, new()
+ {
+ return source.Unpack();
+ }
+
+ ///
+ /// Helper to parse and unpack {@link com.google.protobuf.Any} objects to the specific protobuf object.
+ ///
+ /// the raw bytes source value.
+ /// IMessage
+ /// the materialized protobuf object.
+ public static T ParseAndUnpack(ByteString source) where T : IMessage, new()
+ {
+ return Unpack(Parse(source));
+ }
+ }
+
+ ///
+ /// A set of helper functions for converting encoded commands to IMessage types
+ ///
+ public static class FlightSqlExtensions
+ {
+ private static Any ParsedCommand(this FlightDescriptor descriptor)
+ {
+ return FlightSqlUtils.Parse(descriptor.Command);
+ }
+
+ private static IMessage UnpackMessage(this Any command)
+ {
+ if (command.Is(CommandStatementQuery.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandPreparedStatementQuery.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandGetCatalogs.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandGetDbSchemas.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandGetTables.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandGetTableTypes.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandGetSqlInfo.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandGetPrimaryKeys.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandGetExportedKeys.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandGetImportedKeys.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandGetCrossReference.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandGetXdbcTypeInfo.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(TicketStatementQuery.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(TicketStatementQuery.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandStatementUpdate.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandPreparedStatementUpdate.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+ if (command.Is(CommandPreparedStatementQuery.Descriptor))
+ return FlightSqlUtils.Unpack(command);
+
+ throw new ArgumentException("The defined request is invalid.");
+ }
+
+ ///
+ /// Extracts a command from a FlightDescriptor
+ ///
+ ///
+ /// An IMessage that has been parsed and unpacked
+ public static IMessage? ParsedAndUnpackedMessage(this FlightDescriptor descriptor)
+ {
+ try
+ {
+ return descriptor.ParsedCommand().UnpackMessage();
+ }
+ catch (ArgumentException)
+ {
+ return null;
+ }
+ }
+
+ public static ByteString Serialize(this IBufferMessage message)
+ {
+ int size = message.CalculateSize();
+ var writer = new ArrayBufferWriter(size);
+ message.WriteTo(writer);
+ var schemaBytes = writer.WrittenSpan;
+ return ByteString.CopyFrom(schemaBytes);
+ }
+ }
+}
diff --git a/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs b/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs
new file mode 100644
index 0000000000000..f3f3bef1e1d00
--- /dev/null
+++ b/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace Apache.Arrow.Flight.Sql;
+
+public static class SqlAction
+{
+ public const string CreateRequest = "CreatePreparedStatement";
+ public const string CloseRequest = "ClosePreparedStatement";
+}
diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
index 8140e06493dc2..5dc0d1b434b6d 100644
--- a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
+++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
@@ -13,11 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-using System;
-using System.Collections.Generic;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Internal;
using Apache.Arrow.Flight.Protocol;
+using Apache.Arrow.Flight.Server;
+using Apache.Arrow.Flight.Server.Internal;
using Grpc.Core;
using Grpc.Net.Client;
@@ -93,6 +93,22 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc
channels.Dispose);
}
+ public AsyncDuplexStreamingCall Handshake(Metadata headers = null)
+ {
+ var channel = _client.Handshake(headers);
+ var readStream = new StreamReader(channel.ResponseStream, response => new FlightHandshakeResponse(response));
+ var writeStream = new FlightHandshakeStreamWriterAdapter(channel.RequestStream);
+ var call = new AsyncDuplexStreamingCall(
+ writeStream,
+ readStream,
+ channel.ResponseHeadersAsync,
+ channel.GetStatus,
+ channel.GetTrailers,
+ channel.Dispose);
+
+ return call;
+ }
+
public AsyncServerStreamingCall DoAction(FlightAction action, Metadata headers = null)
{
var stream = _client.DoAction(action.ToProtocol(), headers);
diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamReader.cs b/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamReader.cs
index 011af0c831508..73094338be4cd 100644
--- a/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamReader.cs
+++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamReader.cs
@@ -21,7 +21,7 @@ namespace Apache.Arrow.Flight.Client
{
public class FlightClientRecordBatchStreamReader : FlightRecordBatchStreamReader
{
- internal FlightClientRecordBatchStreamReader(IAsyncStreamReader flightDataStream) : base(flightDataStream)
+ internal FlightClientRecordBatchStreamReader(IAsyncStreamReader flightDataStream) : base(flightDataStream)
{
}
}
diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamWriter.cs b/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamWriter.cs
index d2e62c42e8621..e5fa30c9f6aed 100644
--- a/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamWriter.cs
+++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClientRecordBatchStreamWriter.cs
@@ -25,9 +25,9 @@ namespace Apache.Arrow.Flight.Client
{
public class FlightClientRecordBatchStreamWriter : FlightRecordBatchStreamWriter, IClientStreamWriter
{
- private readonly IClientStreamWriter _clientStreamWriter;
+ private readonly IClientStreamWriter _clientStreamWriter;
private bool _completed = false;
- internal FlightClientRecordBatchStreamWriter(IClientStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor) : base(clientStreamWriter, flightDescriptor)
+ internal FlightClientRecordBatchStreamWriter(IClientStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor) : base(clientStreamWriter, flightDescriptor)
{
_clientStreamWriter = clientStreamWriter;
}
diff --git a/csharp/src/Apache.Arrow.Flight/FlightData.cs b/csharp/src/Apache.Arrow.Flight/FlightData.cs
new file mode 100644
index 0000000000000..f38b1de2206a8
--- /dev/null
+++ b/csharp/src/Apache.Arrow.Flight/FlightData.cs
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using Google.Protobuf;
+
+namespace Apache.Arrow.Flight;
+
+public class FlightData
+{
+ public FlightDescriptor Descriptor { get; }
+ public ByteString AppMetadata { get; }
+ public ByteString DataBody { get; }
+ public ByteString DataHeader { get; }
+
+ public FlightData(FlightDescriptor descriptor, ByteString dataBody = null, ByteString dataHeader = null, ByteString appMetadata = null)
+ {
+ Descriptor = descriptor;
+ DataBody = dataBody;
+ DataHeader = dataHeader;
+ AppMetadata = appMetadata;
+ }
+
+ internal FlightData(Protocol.FlightData protocolFlightData)
+ {
+ Descriptor = protocolFlightData.FlightDescriptor == null ? null : new FlightDescriptor(protocolFlightData.FlightDescriptor);
+ DataBody = protocolFlightData.DataBody;
+ DataHeader = protocolFlightData.DataHeader;
+ AppMetadata = protocolFlightData.AppMetadata;
+ }
+
+ internal Protocol.FlightData ToProtocol()
+ {
+ return new Protocol.FlightData
+ {
+ FlightDescriptor = Descriptor?.ToProtocol(),
+ AppMetadata = AppMetadata,
+ DataBody = DataBody,
+ DataHeader = DataHeader
+ };
+ }
+}
diff --git a/csharp/src/Apache.Arrow.Flight/FlightHandshakeRequest.cs b/csharp/src/Apache.Arrow.Flight/FlightHandshakeRequest.cs
new file mode 100644
index 0000000000000..62db6446072a4
--- /dev/null
+++ b/csharp/src/Apache.Arrow.Flight/FlightHandshakeRequest.cs
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using Google.Protobuf;
+
+namespace Apache.Arrow.Flight;
+
+public class FlightHandshakeRequest
+{
+ private readonly Protocol.HandshakeRequest _result;
+ public ByteString Payload => _result.Payload;
+ public ulong ProtocolVersion => _result.ProtocolVersion;
+
+ internal FlightHandshakeRequest(Protocol.HandshakeRequest result)
+ {
+ _result = result;
+ }
+
+ public FlightHandshakeRequest(ByteString payload, ulong protocolVersion = 1)
+ {
+ _result = new Protocol.HandshakeRequest
+ {
+ Payload = payload,
+ ProtocolVersion = protocolVersion
+ };
+ }
+
+ internal Protocol.HandshakeRequest ToProtocol()
+ {
+ return _result;
+ }
+
+ public override bool Equals(object obj)
+ {
+ if(obj is FlightHandshakeRequest other)
+ {
+ return Equals(_result, other._result);
+ }
+ return false;
+ }
+
+ public override int GetHashCode()
+ {
+ return _result.GetHashCode();
+ }
+}
diff --git a/csharp/src/Apache.Arrow.Flight/FlightHandshakeResponse.cs b/csharp/src/Apache.Arrow.Flight/FlightHandshakeResponse.cs
new file mode 100644
index 0000000000000..4ceb288f8eda1
--- /dev/null
+++ b/csharp/src/Apache.Arrow.Flight/FlightHandshakeResponse.cs
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using Google.Protobuf;
+
+namespace Apache.Arrow.Flight;
+
+public class FlightHandshakeResponse
+{
+ public static readonly FlightHandshakeResponse Empty = new FlightHandshakeResponse();
+ private readonly Protocol.HandshakeResponse _handshakeResponse;
+
+ public ulong ProtocolVersion
+ {
+ get => _handshakeResponse.ProtocolVersion;
+ set => _handshakeResponse.ProtocolVersion = value;
+ }
+
+ public ByteString Payload
+ {
+ get => _handshakeResponse.Payload;
+ set => _handshakeResponse.Payload = value;
+ }
+
+ public FlightHandshakeResponse()
+ {
+ _handshakeResponse = new Protocol.HandshakeResponse
+ {
+ ProtocolVersion = 1
+ };
+ }
+
+ internal FlightHandshakeResponse(Protocol.HandshakeResponse handshakeResponse)
+ {
+ _handshakeResponse = handshakeResponse;
+ }
+
+ public FlightHandshakeResponse(ByteString payload, ulong protocolVersion = 1)
+ {
+ _handshakeResponse = new Protocol.HandshakeResponse
+ {
+ ProtocolVersion = protocolVersion,
+ Payload = payload
+ };
+ }
+
+ internal Protocol.HandshakeResponse ToProtocol()
+ {
+ return _handshakeResponse;
+ }
+}
diff --git a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs
index a72be5a823403..f76f08224541f 100644
--- a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs
+++ b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs
@@ -27,12 +27,12 @@ namespace Apache.Arrow.Flight
public abstract class FlightRecordBatchStreamWriter : IAsyncStreamWriter, IDisposable
{
private FlightDataStream _flightDataStream;
- private readonly IAsyncStreamWriter _clientStreamWriter;
+ private readonly IAsyncStreamWriter _clientStreamWriter;
private readonly FlightDescriptor _flightDescriptor;
private bool _disposed;
- private protected FlightRecordBatchStreamWriter(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor)
+ private protected FlightRecordBatchStreamWriter(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor)
{
_clientStreamWriter = clientStreamWriter;
_flightDescriptor = flightDescriptor;
diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
index 3211212c99cb9..72c1551be2917 100644
--- a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
+++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
@@ -34,10 +34,10 @@ namespace Apache.Arrow.Flight.Internal
internal class FlightDataStream : ArrowStreamWriter
{
private readonly FlightDescriptor _flightDescriptor;
- private readonly IAsyncStreamWriter _clientStreamWriter;
+ private readonly IAsyncStreamWriter _clientStreamWriter;
private Protocol.FlightData _currentFlightData;
- public FlightDataStream(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor, Schema schema)
+ public FlightDataStream(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor, Schema schema)
: base(new MemoryStream(), schema)
{
_clientStreamWriter = clientStreamWriter;
diff --git a/csharp/src/Apache.Arrow.Flight/Internal/SchemaWriter.cs b/csharp/src/Apache.Arrow.Flight/Internal/SchemaWriter.cs
index c7e7d8135a1dd..be27cb1e39161 100644
--- a/csharp/src/Apache.Arrow.Flight/Internal/SchemaWriter.cs
+++ b/csharp/src/Apache.Arrow.Flight/Internal/SchemaWriter.cs
@@ -20,6 +20,7 @@
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flatbuf;
+using Apache.Arrow.Flight.Internal;
using Apache.Arrow.Ipc;
using Google.Protobuf;
@@ -30,7 +31,7 @@ namespace Apache.Arrow.Flight.Internal
///
internal class SchemaWriter : ArrowStreamWriter
{
- private SchemaWriter(Stream baseStream, Schema schema) : base(baseStream, schema)
+ internal SchemaWriter(Stream baseStream, Schema schema) : base(baseStream, schema)
{
}
@@ -53,3 +54,12 @@ public void WriteSchema(Schema schema, CancellationToken cancellationToken)
}
}
}
+
+public static class SchemaExtension
+{
+ // Translate an Apache.Arrow.Schema to FlatBuffer Schema to ByteString
+ public static ByteString ToByteString(this Apache.Arrow.Schema schema)
+ {
+ return SchemaWriter.SerializeSchema(schema);
+ }
+}
diff --git a/csharp/src/Apache.Arrow.Flight/Server/FlightServer.cs b/csharp/src/Apache.Arrow.Flight/Server/FlightServer.cs
index 30b0409d422fb..0005caf175f67 100644
--- a/csharp/src/Apache.Arrow.Flight/Server/FlightServer.cs
+++ b/csharp/src/Apache.Arrow.Flight/Server/FlightServer.cs
@@ -14,8 +14,6 @@
// limitations under the License.
using System;
-using System.Collections.Generic;
-using System.Text;
using System.Threading.Tasks;
using Grpc.Core;
@@ -57,5 +55,10 @@ public virtual Task GetFlightInfo(FlightDescriptor request, ServerCa
{
throw new NotImplementedException();
}
+
+ public virtual Task Handshake(IAsyncStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context)
+ {
+ throw new NotImplementedException();
+ }
}
}
diff --git a/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamReader.cs b/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamReader.cs
index 5476d3d0e5ff9..c52b761ad38d9 100644
--- a/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamReader.cs
+++ b/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamReader.cs
@@ -14,7 +14,6 @@
// limitations under the License.
using System.Threading.Tasks;
-using Apache.Arrow.Flight.Protocol;
using Apache.Arrow.Flight.Internal;
using Grpc.Core;
@@ -22,7 +21,11 @@ namespace Apache.Arrow.Flight.Server
{
public class FlightServerRecordBatchStreamReader : FlightRecordBatchStreamReader
{
- internal FlightServerRecordBatchStreamReader(IAsyncStreamReader flightDataStream) : base(flightDataStream)
+ public FlightServerRecordBatchStreamReader(IAsyncStreamReader flightDataStream) : base(new StreamReader(flightDataStream, data => data.ToProtocol()))
+ {
+ }
+
+ internal FlightServerRecordBatchStreamReader(IAsyncStreamReader flightDataStream) : base(flightDataStream)
{
}
diff --git a/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamWriter.cs b/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamWriter.cs
index 6c1987339bdf3..7d1c89ea3df2a 100644
--- a/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamWriter.cs
+++ b/csharp/src/Apache.Arrow.Flight/Server/FlightServerRecordBatchStreamWriter.cs
@@ -13,10 +13,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-using System;
-using System.Collections.Generic;
-using System.Text;
-using Apache.Arrow.Flight.Protocol;
using Apache.Arrow.Flight.Internal;
using Grpc.Core;
@@ -24,7 +20,11 @@ namespace Apache.Arrow.Flight.Server
{
public class FlightServerRecordBatchStreamWriter : FlightRecordBatchStreamWriter, IServerStreamWriter
{
- internal FlightServerRecordBatchStreamWriter(IServerStreamWriter clientStreamWriter) : base(clientStreamWriter, null)
+ public FlightServerRecordBatchStreamWriter(IServerStreamWriter clientStreamWriter) : base(new StreamWriter(clientStreamWriter, data => new FlightData(data)), null)
+ {
+ }
+
+ internal FlightServerRecordBatchStreamWriter(IServerStreamWriter clientStreamWriter) : base(clientStreamWriter, null)
{
}
}
diff --git a/csharp/src/Apache.Arrow.Flight/Server/Internal/FlightServerImplementation.cs b/csharp/src/Apache.Arrow.Flight/Server/Internal/FlightServerImplementation.cs
index dcf6e57681894..f34ffaf92fc81 100644
--- a/csharp/src/Apache.Arrow.Flight/Server/Internal/FlightServerImplementation.cs
+++ b/csharp/src/Apache.Arrow.Flight/Server/Internal/FlightServerImplementation.cs
@@ -14,12 +14,9 @@
// limitations under the License.
using System;
-using System.Collections.Generic;
-using System.Text;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Internal;
using Apache.Arrow.Flight.Protocol;
-using Apache.Arrow.Flight.Server;
using Grpc.Core;
namespace Apache.Arrow.Flight.Server.Internal
@@ -35,21 +32,26 @@ public FlightServerImplementation(FlightServer flightServer)
_flightServer = flightServer;
}
- public override async Task DoPut(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context)
+ public override async Task DoPut(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context)
{
var readStream = new FlightServerRecordBatchStreamReader(requestStream);
- var writeStream = new StreamWriter(responseStream, putResult => putResult.ToProtocol());
+ var writeStream = new StreamWriter(responseStream, putResult => putResult.ToProtocol());
+
await _flightServer.DoPut(readStream, writeStream, context).ConfigureAwait(false);
}
- public override Task DoGet(Protocol.Ticket request, IServerStreamWriter responseStream, ServerCallContext context)
+ public override Task DoGet(Protocol.Ticket request, IServerStreamWriter responseStream, ServerCallContext context)
{
- return _flightServer.DoGet(new FlightTicket(request.Ticket_), new FlightServerRecordBatchStreamWriter(responseStream), context);
+ var flightTicket = new FlightTicket(request.Ticket_);
+ var flightServerRecordBatchStreamWriter = new FlightServerRecordBatchStreamWriter(responseStream);
+
+ return _flightServer.DoGet(flightTicket, flightServerRecordBatchStreamWriter, context);
}
public override Task ListFlights(Protocol.Criteria request, IServerStreamWriter responseStream, ServerCallContext context)
{
var writeStream = new StreamWriter(responseStream, flightInfo => flightInfo.ToProtocol());
+
return _flightServer.ListFlights(new FlightCriteria(request), writeStream, context);
}
@@ -57,6 +59,7 @@ public override Task DoAction(Protocol.Action request, IServerStreamWriter(responseStream, result => result.ToProtocol());
+
return _flightServer.DoAction(action, writeStream, context);
}
@@ -74,12 +77,12 @@ public override async Task GetSchema(Protocol.FlightDescriptor req
public override async Task GetFlightInfo(Protocol.FlightDescriptor request, ServerCallContext context)
{
var flightDescriptor = new FlightDescriptor(request);
- var flightInfo = await _flightServer.GetFlightInfo(flightDescriptor, context).ConfigureAwait(false);
+ FlightInfo flightInfo = await _flightServer.GetFlightInfo(flightDescriptor, context).ConfigureAwait(false);
return flightInfo.ToProtocol();
}
- public override Task DoExchange(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context)
+ public override Task DoExchange(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context)
{
//Exchange is not yet implemented
throw new NotImplementedException();
@@ -87,14 +90,15 @@ public override Task DoExchange(IAsyncStreamReader requestStream, IS
public override Task Handshake(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context)
{
- //Handshake is not yet implemented
- throw new NotImplementedException();
+ var readStream = new StreamReader(requestStream, request => new FlightHandshakeRequest(request));
+ var writeStream = new StreamWriter(responseStream, result => result.ToProtocol());
+ return _flightServer.Handshake(readStream, writeStream, context);
}
- public override Task ListActions(Empty request, IServerStreamWriter responseStream, ServerCallContext context)
+ public override async Task ListActions(Empty request, IServerStreamWriter responseStream, ServerCallContext context)
{
var writeStream = new StreamWriter(responseStream, (actionType) => actionType.ToProtocol());
- return _flightServer.ListActions(writeStream, context);
+ await _flightServer.ListActions(writeStream, context).ConfigureAwait(false);
}
}
}
diff --git a/csharp/src/Apache.Arrow.Flight/Server/Internal/HandshakeAdapters.cs b/csharp/src/Apache.Arrow.Flight/Server/Internal/HandshakeAdapters.cs
new file mode 100644
index 0000000000000..40ac3bebf1b19
--- /dev/null
+++ b/csharp/src/Apache.Arrow.Flight/Server/Internal/HandshakeAdapters.cs
@@ -0,0 +1,40 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System.Threading.Tasks;
+using Apache.Arrow.Flight.Protocol;
+using Grpc.Core;
+
+namespace Apache.Arrow.Flight.Server.Internal;
+
+internal class FlightHandshakeStreamWriterAdapter : IClientStreamWriter
+{
+ private readonly IClientStreamWriter _writeStream;
+
+ public FlightHandshakeStreamWriterAdapter(IClientStreamWriter writeStream)
+ {
+ _writeStream = writeStream;
+ }
+
+ public Task WriteAsync(FlightHandshakeRequest message) => _writeStream.WriteAsync(message.ToProtocol());
+
+ public WriteOptions WriteOptions
+ {
+ get => _writeStream.WriteOptions;
+ set => _writeStream.WriteOptions = value;
+ }
+
+ public Task CompleteAsync() => _writeStream.CompleteAsync();
+}
diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj
new file mode 100644
index 0000000000000..07e341eb27ab2
--- /dev/null
+++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj
@@ -0,0 +1,19 @@
+
+
+
+ net7.0
+ false
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlServerTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlServerTests.cs
new file mode 100644
index 0000000000000..4ad5bde0874a8
--- /dev/null
+++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlServerTests.cs
@@ -0,0 +1,375 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#nullable enable
+using System;
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Flight.Server;
+using Apache.Arrow.Types;
+using Arrow.Flight.Protocol.Sql;
+using Google.Protobuf;
+using Grpc.Core;
+using Xunit;
+
+namespace Apache.Arrow.Flight.Sql.Tests;
+
+public class FlightSqlServerTests
+{
+ [Theory]
+ [InlineData(FlightDescriptorType.Path, null)]
+ [InlineData(FlightDescriptorType.Command, null)]
+ [InlineData(FlightDescriptorType.Command, typeof(CommandGetCatalogs))]
+ public void EnsureGetCommandReturnsTheCorrectResponse(FlightDescriptorType type, Type? expectedResult)
+ {
+ //Given
+ FlightDescriptor descriptor;
+ if (type == FlightDescriptorType.Command)
+ {
+ descriptor = expectedResult != null ?
+ FlightDescriptor.CreateCommandDescriptor(((IMessage) Activator.CreateInstance(expectedResult!)!).PackAndSerialize().ToByteArray()) :
+ FlightDescriptor.CreateCommandDescriptor(ByteString.Empty.ToStringUtf8());
+ }
+ else
+ {
+ descriptor = FlightDescriptor.CreatePathDescriptor(System.Array.Empty());
+ }
+
+ //When
+ var result = FlightSqlServer.GetCommand(descriptor);
+
+ //Then
+ Assert.Equal(expectedResult, result?.GetType());
+ }
+
+ [Fact]
+ public async Task EnsureTheCorrectActionsAreGiven()
+ {
+ //Given
+ var producer = new TestFlightSqlSever();
+ var streamWriter = new MockServerStreamWriter();
+
+ //When
+ await producer.ListActions(streamWriter, new MockServerCallContext()).ConfigureAwait(false);
+ var actions = streamWriter.Messages.ToArray();
+
+ Assert.Equal(FlightSqlUtils.FlightSqlActions, actions);
+ }
+
+ [Theory]
+ [InlineData(false,
+ new[] {"catalog_name", "db_schema_name", "table_name", "table_type"},
+ new[] {typeof(StringType), typeof(StringType), typeof(StringType), typeof(StringType)},
+ new[] {true, true, false, false})
+ ]
+ [InlineData(true,
+ new[] {"catalog_name", "db_schema_name", "table_name", "table_type", "table_schema"},
+ new[] {typeof(StringType), typeof(StringType), typeof(StringType), typeof(StringType), typeof(BinaryType)},
+ new[] {true, true, false, false, false})
+ ]
+ public void EnsureTableSchemaIsCorrectWithoutTableSchema(bool includeTableSchemaField, string[] expectedNames, Type[] expectedTypes, bool[] expectedIsNullable)
+ {
+ // Arrange
+
+ // Act
+ var schema = FlightSqlServer.GetTableSchema(includeTableSchemaField);
+ var fields = schema.FieldsList;
+
+ //Assert
+ Assert.False(schema.HasMetadata);
+ Assert.Equal(expectedNames.Length, fields.Count);
+ for (int i = 0; i < fields.Count; i++)
+ {
+ Assert.Equal(expectedNames[i], fields[i].Name);
+ Assert.Equal(expectedTypes[i], fields[i].DataType.GetType());
+ Assert.Equal(expectedIsNullable[i], fields[i].IsNullable);
+ }
+ }
+
+ #region FlightInfoTests
+ [Theory]
+ [InlineData(typeof(CommandStatementQuery), "GetStatementQueryFlightInfo")]
+ [InlineData(typeof(CommandPreparedStatementQuery), "GetPreparedStatementQueryFlightInfo")]
+ [InlineData(typeof(CommandGetCatalogs), "GetCatalogFlightInfo")]
+ [InlineData(typeof(CommandGetDbSchemas), "GetDbSchemaFlightInfo")]
+ [InlineData(typeof(CommandGetTables), "GetTablesFlightInfo")]
+ [InlineData(typeof(CommandGetTableTypes), "GetTableTypesFlightInfo")]
+ [InlineData(typeof(CommandGetSqlInfo), "GetSqlFlightInfo")]
+ [InlineData(typeof(CommandGetPrimaryKeys), "GetPrimaryKeysFlightInfo")]
+ [InlineData(typeof(CommandGetExportedKeys), "GetExportedKeysFlightInfo")]
+ [InlineData(typeof(CommandGetImportedKeys), "GetImportedKeysFlightInfo")]
+ [InlineData(typeof(CommandGetCrossReference), "GetCrossReferenceFlightInfo")]
+ [InlineData(typeof(CommandGetXdbcTypeInfo), "GetXdbcTypeFlightInfo")]
+ public async void EnsureGetFlightInfoIsCorrectlyRoutedForCommand(Type commandType, string expectedResult)
+ {
+ //Given
+ var command = (IMessage) Activator.CreateInstance(commandType)!;
+ var producer = new TestFlightSqlSever();
+ var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize().ToArray());
+
+ //When
+ var flightInfo = await producer.GetFlightInfo(descriptor, new MockServerCallContext());
+
+ //Then
+ Assert.Equal(expectedResult, flightInfo.Descriptor.Paths.First());
+ }
+
+
+ [Fact]
+ public async void EnsureAnInvalidOperationExceptionIsThrownWhenACommandIsNotSupportedAndHasNoDescriptor()
+ {
+ //Given
+ var producer = new TestFlightSqlSever();
+
+ //When
+ async Task Act() => await producer.GetFlightInfo(FlightDescriptor.CreatePathDescriptor(""), new MockServerCallContext());
+ var exception = await Record.ExceptionAsync(Act);
+
+ //Then
+ Assert.Equal("command type not supported", exception?.Message);
+ }
+
+ [Fact]
+ public async void EnsureAnInvalidOperationExceptionIsThrownWhenACommandIsNotSupported()
+ {
+ //Given
+ var producer = new TestFlightSqlSever();
+ var command = new CommandPreparedStatementUpdate();
+ var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize().ToArray());
+
+ //When
+ async Task Act() => await producer.GetFlightInfo(descriptor, new MockServerCallContext());
+ var exception = await Record.ExceptionAsync(Act);
+
+ //Then
+ Assert.Equal("command type CommandPreparedStatementUpdate not supported", exception?.Message);
+ }
+ #endregion
+
+ #region DoGetTests
+
+ [Theory]
+ [InlineData(typeof(CommandPreparedStatementQuery), "DoGetPreparedStatementQuery")]
+ [InlineData(typeof(CommandGetSqlInfo), "DoGetSqlInfo")]
+ [InlineData(typeof(CommandGetCatalogs), "DoGetCatalog")]
+ [InlineData(typeof(CommandGetTableTypes), "DoGetTableType")]
+ [InlineData(typeof(CommandGetTables), "DoGetTables")]
+ [InlineData(typeof(CommandGetDbSchemas), "DoGetDbSchema")]
+ [InlineData(typeof(CommandGetPrimaryKeys), "DoGetPrimaryKeys")]
+ [InlineData(typeof(CommandGetExportedKeys), "DoGetExportedKeys")]
+ [InlineData(typeof(CommandGetImportedKeys), "DoGetImportedKeys")]
+ [InlineData(typeof(CommandGetCrossReference), "DoGetCrossReference")]
+ [InlineData(typeof(CommandGetXdbcTypeInfo), "DoGetXbdcTypeInfo")]
+ public async void EnsureDoGetIsCorrectlyRoutedForADoGetCommand(Type commandType, string expectedResult)
+ {
+ //Given
+ var producer = new TestFlightSqlSever();
+ var command = (IMessage) Activator.CreateInstance(commandType)!;
+ var ticket = new FlightTicket(command.PackAndSerialize());
+ var streamWriter = new MockServerStreamWriter();
+
+ //When
+ await producer.DoGet(ticket, new FlightServerRecordBatchStreamWriter(streamWriter), new MockServerCallContext());
+ var schema = await streamWriter.Messages.GetSchema();
+
+ //Then
+ Assert.Equal(expectedResult, schema.FieldsList[0].Name);
+ }
+
+ [Fact]
+ public async void EnsureAnInvalidOperationExceptionIsThrownWhenADoGetCommandIsNotSupported()
+ {
+ //Given
+ var producer = new TestFlightSqlSever();
+ var ticket = new FlightTicket("");
+ var streamWriter = new MockServerStreamWriter();
+
+ //When
+ async Task Act() => await producer.DoGet(ticket, new FlightServerRecordBatchStreamWriter(streamWriter), new MockServerCallContext());
+ var exception = await Record.ExceptionAsync(Act);
+
+ //Then
+ Assert.Equal("Status(StatusCode=\"InvalidArgument\", Detail=\"DoGet command is not supported.\")", exception?.Message);
+ }
+ #endregion
+
+ #region DoActionTests
+ [Theory]
+ [InlineData(SqlAction.CloseRequest, typeof(ActionClosePreparedStatementRequest), "ClosePreparedStatement")]
+ [InlineData(SqlAction.CreateRequest, typeof(ActionCreatePreparedStatementRequest), "CreatePreparedStatement")]
+ [InlineData("BadCommand", typeof(ActionCreatePreparedStatementRequest), "Action type BadCommand not supported", true)]
+ public async void EnsureDoActionIsCorrectlyRoutedForAnActionRequest(string actionType, Type actionBodyType, string expectedResponse, bool isException = false)
+ {
+ //Given
+ var producer = new TestFlightSqlSever();
+ var actionBody = (IMessage) Activator.CreateInstance(actionBodyType)!;
+ var action = new FlightAction(actionType, actionBody.PackAndSerialize());
+ var mockStreamWriter = new MockStreamWriter();
+
+ //When
+ async Task Act() => await producer.DoAction(action, mockStreamWriter, new MockServerCallContext());
+ var exception = await Record.ExceptionAsync(Act);
+ string? actualMessage = isException ? exception?.Message : mockStreamWriter.Messages[0].Body.ToStringUtf8();
+
+ //Then
+ Assert.Equal(expectedResponse, actualMessage);
+ }
+ #endregion
+
+ #region DoPutTests
+ [Theory]
+ [InlineData(typeof(CommandStatementUpdate), "PutStatementUpdate")]
+ [InlineData(typeof(CommandPreparedStatementQuery), "PutPreparedStatementQuery")]
+ [InlineData(typeof(CommandPreparedStatementUpdate), "PutPreparedStatementUpdate")]
+ [InlineData(typeof(CommandGetXdbcTypeInfo), "Command CommandGetXdbcTypeInfo not supported", true)]
+ public async void EnsureDoPutIsCorrectlyRoutedForTheCommand(Type commandType, string expectedResponse, bool isException = false)
+ {
+ //Given
+ var command = (IMessage) Activator.CreateInstance(commandType)!;
+ var producer = new TestFlightSqlSever();
+ var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize().ToArray());
+ var recordBatch = new RecordBatch(new Schema(new List(), null), System.Array.Empty(), 0);
+ var reader = new MockStreamReader(await recordBatch.ToFlightData(descriptor).ConfigureAwait(false));
+ var batchReader = new FlightServerRecordBatchStreamReader(reader);
+ var mockStreamWriter = new MockServerStreamWriter();
+
+ //When
+ async Task Act() => await producer.DoPut(batchReader, mockStreamWriter, new MockServerCallContext()).ConfigureAwait(false);
+ var exception = await Record.ExceptionAsync(Act);
+ string? actualMessage = isException ? exception?.Message : mockStreamWriter.Messages[0].ApplicationMetadata.ToStringUtf8();
+
+ //Then
+ Assert.Equal(expectedResponse, actualMessage);
+ }
+ #endregion
+
+ private class MockServerCallContext : ServerCallContext
+ {
+ protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders) => throw new NotImplementedException();
+
+ protected override ContextPropagationToken CreatePropagationTokenCore(ContextPropagationOptions? options) => throw new NotImplementedException();
+ protected override string MethodCore => "";
+ protected override string HostCore => "";
+ protected override string PeerCore => "";
+ protected override DateTime DeadlineCore => new();
+ protected override Metadata RequestHeadersCore => new();
+ protected override CancellationToken CancellationTokenCore => default;
+ protected override Metadata ResponseTrailersCore => new();
+ protected override Status StatusCore { get; set; }
+ protected override WriteOptions WriteOptionsCore { get; set; } = WriteOptions.Default;
+ protected override AuthContext AuthContextCore => new("", new Dictionary>());
+ }
+}
+
+internal class MockStreamWriter : IServerStreamWriter
+{
+ public Task WriteAsync(T message)
+ {
+ _messages.Add(message);
+ return Task.FromResult(message);
+ }
+
+ public IReadOnlyList Messages => new ReadOnlyCollection(_messages);
+ public WriteOptions? WriteOptions { get; set; }
+ private readonly List _messages = new();
+}
+
+internal class MockServerStreamWriter : IServerStreamWriter
+{
+ public Task WriteAsync(T message)
+ {
+ _messages.Add(message);
+ return Task.FromResult(message);
+ }
+
+ public IReadOnlyList Messages => new ReadOnlyCollection(_messages);
+ public WriteOptions? WriteOptions { get; set; }
+ private readonly List _messages = new();
+}
+
+internal static class MockStreamReaderWriterExtensions
+{
+ public static async Task> GetRecordBatches(this IReadOnlyList flightDataList)
+ {
+ var list = new List();
+ var recordBatchReader = new FlightServerRecordBatchStreamReader(new MockStreamReader(flightDataList));
+ while (await recordBatchReader.MoveNext().ConfigureAwait(false))
+ {
+ list.Add(recordBatchReader.Current);
+ }
+
+ return list;
+ }
+
+ public static async Task GetSchema(this IEnumerable flightDataList)
+ {
+ var recordBatchReader = new FlightServerRecordBatchStreamReader(new MockStreamReader(flightDataList));
+ return await recordBatchReader.Schema;
+ }
+
+ public static async Task> ToFlightData(this RecordBatch recordBatch, FlightDescriptor? descriptor = null)
+ {
+ var responseStream = new MockFlightServerRecordBatchStreamWriter();
+ await responseStream.WriteRecordBatchAsync(recordBatch).ConfigureAwait(false);
+ if (descriptor == null)
+ {
+ return responseStream.FlightData;
+ }
+
+ return responseStream.FlightData.Select(
+ flightData => new FlightData(descriptor, flightData.DataBody, flightData.DataHeader, flightData.AppMetadata)
+ );
+ }
+}
+
+internal class MockStreamReader: IAsyncStreamReader
+{
+ private readonly IEnumerator _flightActions;
+
+ public MockStreamReader(IEnumerable flightActions)
+ {
+ _flightActions = flightActions.GetEnumerator();
+ }
+
+ public Task MoveNext(CancellationToken cancellationToken)
+ {
+ return Task.FromResult(_flightActions.MoveNext());
+ }
+
+ public T Current => _flightActions.Current;
+}
+
+internal class MockFlightServerRecordBatchStreamWriter : FlightServerRecordBatchStreamWriter
+{
+ private readonly MockStreamWriter _streamWriter;
+ public MockFlightServerRecordBatchStreamWriter() : this(new MockStreamWriter()) { }
+
+ private MockFlightServerRecordBatchStreamWriter(MockStreamWriter clientStreamWriter) : base(clientStreamWriter)
+ {
+ _streamWriter = clientStreamWriter;
+ }
+
+ public IEnumerable FlightData => _streamWriter.Messages;
+
+ public async Task WriteRecordBatchAsync(RecordBatch recordBatch)
+ {
+ await WriteAsync(recordBatch).ConfigureAwait(false);
+ }
+}
+
+
diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs
new file mode 100644
index 0000000000000..031495fffdcc7
--- /dev/null
+++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using Google.Protobuf;
+using Google.Protobuf.WellKnownTypes;
+
+namespace Apache.Arrow.Flight.Sql.Tests;
+
+public static class FlightSqlTestExtensions
+{
+ public static ByteString PackAndSerialize(this IMessage command)
+ {
+ return Any.Pack(command).Serialize();
+ }
+}
diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlUtilsTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlUtilsTests.cs
new file mode 100644
index 0000000000000..3ea7a8d3f0463
--- /dev/null
+++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlUtilsTests.cs
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using Arrow.Flight.Protocol.Sql;
+using Google.Protobuf;
+using Google.Protobuf.WellKnownTypes;
+using Xunit;
+
+namespace Apache.Arrow.Flight.Sql.Tests
+{
+ public class FlightSqlUtilsTests
+ {
+ [Fact]
+ public void EnsureParseCanCorrectlyReviveTheCommand()
+ {
+ //Given
+ var expectedCommand = new CommandStatementQuery
+ {
+ Query = "select * from database"
+ };
+
+ //When
+ var command = FlightSqlUtils.Parse(Any.Pack(expectedCommand).ToByteString());
+
+ //Then
+ Assert.Equal(command.Unpack(), expectedCommand);
+ }
+
+ [Fact]
+ public void EnsureUnpackCanCreateTheCorrectObject()
+ {
+ //Given
+ var expectedCommand = new CommandPreparedStatementQuery
+ {
+ PreparedStatementHandle = ByteString.Empty
+ };
+
+ //When
+ var command = FlightSqlUtils.Unpack(Any.Pack(expectedCommand));
+
+ //Then
+ Assert.Equal(command, expectedCommand);
+ }
+
+ [Fact]
+ public void EnsureParseAndUnpackProducesTheCorrectObject()
+ {
+ //Given
+ var expectedCommand = new CommandPreparedStatementQuery
+ {
+ PreparedStatementHandle = ByteString.Empty
+ };
+
+ //When
+ var command = FlightSqlUtils.ParseAndUnpack(Any.Pack(expectedCommand).ToByteString());
+
+ //Then
+ Assert.Equal(command, expectedCommand);
+ }
+ }
+}
diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs
new file mode 100644
index 0000000000000..3dca632b5b761
--- /dev/null
+++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System.Collections.Generic;
+using System.Reflection;
+using System.Threading.Tasks;
+using Apache.Arrow.Flight.Server;
+using Apache.Arrow.Types;
+using Arrow.Flight.Protocol.Sql;
+using Grpc.Core;
+
+namespace Apache.Arrow.Flight.Sql.Tests;
+
+public class TestFlightSqlSever : FlightSqlServer
+{
+ protected override Task GetStatementQueryFlightInfo(CommandStatementQuery commandStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetPreparedStatementQueryFlightInfo(CommandPreparedStatementQuery preparedStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetCatalogFlightInfo(CommandGetCatalogs command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetDbSchemaFlightInfo(CommandGetDbSchemas command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetTablesFlightInfo(CommandGetTables command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetTableTypesFlightInfo(CommandGetTableTypes command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetSqlFlightInfo(CommandGetSqlInfo commandGetSqlInfo, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetPrimaryKeysFlightInfo(CommandGetPrimaryKeys command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetExportedKeysFlightInfo(CommandGetExportedKeys command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetImportedKeysFlightInfo(CommandGetImportedKeys command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetCrossReferenceFlightInfo(CommandGetCrossReference command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override Task GetXdbcTypeFlightInfo(CommandGetXdbcTypeInfo command, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty()));
+
+ protected override async Task DoGetPreparedStatementQuery(CommandPreparedStatementQuery preparedStatementQuery, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetPreparedStatementQuery"));
+
+ protected override async Task DoGetSqlInfo(CommandGetSqlInfo getSqlInfo, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetSqlInfo"));
+
+ protected override async Task DoGetCatalog(CommandGetCatalogs command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetCatalog"));
+
+ protected override async Task DoGetTableType(CommandGetTableTypes command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetTableType"));
+
+ protected override async Task DoGetTables(CommandGetTables command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetTables"));
+
+ protected override async Task DoGetPrimaryKeys(CommandGetPrimaryKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetPrimaryKeys"));
+
+ protected override async Task DoGetDbSchema(CommandGetDbSchemas command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetDbSchema"));
+
+ protected override async Task DoGetExportedKeys(CommandGetExportedKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetExportedKeys"));
+
+ protected override async Task DoGetImportedKeys(CommandGetImportedKeys command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetImportedKeys"));
+
+ protected override async Task DoGetCrossReference(CommandGetCrossReference command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetCrossReference"));
+
+ protected override async Task DoGetXbdcTypeInfo(CommandGetXdbcTypeInfo command, FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(MockRecordBatch("DoGetXbdcTypeInfo"));
+
+ protected override async Task CreatePreparedStatement(ActionCreatePreparedStatementRequest request, FlightAction action, IAsyncStreamWriter streamWriter, ServerCallContext serverCallContext) => await streamWriter.WriteAsync(new FlightResult("CreatePreparedStatement"));
+
+ protected override async Task ClosePreparedStatement(ActionClosePreparedStatementRequest request, FlightAction action, IAsyncStreamWriter streamWriter, ServerCallContext serverCallContext) => await streamWriter.WriteAsync(new FlightResult("ClosePreparedStatement"));
+
+ protected override async Task PutPreparedStatementUpdate(CommandPreparedStatementUpdate command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(new FlightPutResult("PutPreparedStatementUpdate"));
+
+ protected override async Task PutStatementUpdate(CommandStatementUpdate command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(new FlightPutResult("PutStatementUpdate"));
+
+ protected override async Task PutPreparedStatementQuery(CommandPreparedStatementQuery command, FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) => await responseStream.WriteAsync(new FlightPutResult("PutPreparedStatementQuery"));
+
+ private RecordBatch MockRecordBatch(string name)
+ {
+ var schema = new Schema(new List {new(name, StringType.Default, false)}, System.Array.Empty>());
+ return new RecordBatch(schema, new []{ new StringArray.Builder().Append(name).Build() }, 1);
+ }
+}
diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs
index ae6e2e4b03b6d..a613b04a32002 100644
--- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs
+++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs
@@ -18,6 +18,7 @@
using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Server;
+using Google.Protobuf;
using Grpc.Core;
using Grpc.Core.Utils;
@@ -86,6 +87,21 @@ public override Task GetFlightInfo(FlightDescriptor request, ServerC
throw new RpcException(new Status(StatusCode.NotFound, "Flight not found"));
}
+ public override async Task Handshake(IAsyncStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context)
+ {
+ while (await requestStream.MoveNext().ConfigureAwait(false))
+ {
+ if (requestStream.Current.Payload.ToStringUtf8() == "Hello")
+ {
+ await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Hello handshake"))).ConfigureAwait(false);
+ }
+ else
+ {
+ await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Done"))).ConfigureAwait(false);
+ }
+ }
+ }
+
public override Task GetSchema(FlightDescriptor request, ServerCallContext context)
{
if(_flightStore.Flights.TryGetValue(request, out var flightHolder))
diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
index 3556c6a17feef..267fe4e4b606d 100644
--- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
+++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
@@ -283,6 +283,32 @@ public async Task TestListFlights()
}
}
+ [Fact]
+ public async Task TestHandshake()
+ {
+ var duplexStreamingCall = _flightClient.Handshake();
+
+ await duplexStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)).ConfigureAwait(false);
+ await duplexStreamingCall.RequestStream.CompleteAsync().ConfigureAwait(false);
+ var results = await duplexStreamingCall.ResponseStream.ToListAsync().ConfigureAwait(false);
+
+ Assert.Single(results);
+ Assert.Equal("Done", results.First().Payload.ToStringUtf8());
+ }
+
+ [Fact]
+ public async Task TestHandshakeWithSpecificMessage()
+ {
+ var duplexStreamingCall = _flightClient.Handshake();
+
+ await duplexStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.CopyFromUtf8("Hello"))).ConfigureAwait(false);
+ await duplexStreamingCall.RequestStream.CompleteAsync().ConfigureAwait(false);
+ var results = await duplexStreamingCall.ResponseStream.ToListAsync().ConfigureAwait(false);
+
+ Assert.Single(results);
+ Assert.Equal("Hello handshake", results.First().Payload.ToStringUtf8());
+ }
+
[Fact]
public async Task TestGetBatchesWithAsyncEnumerable()
{
diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt
index 7bdb692d048e9..aebe321d613ab 100644
--- a/dev/release/rat_exclude_files.txt
+++ b/dev/release/rat_exclude_files.txt
@@ -113,8 +113,10 @@ csharp/src/Apache.Arrow/Properties/Resources.resx
csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj
csharp/src/Apache.Arrow.Flight.AspNetCore/Apache.Arrow.Flight.AspNetCore.csproj
csharp/src/Apache.Arrow.Compression/Apache.Arrow.Compression.csproj
+csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj
csharp/test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj
csharp/test/Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj
+csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj
csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj
csharp/test/Apache.Arrow.IntegrationTest/Apache.Arrow.IntegrationTest.csproj
csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj