Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

<ItemGroup>
<PackageReference Include="ApacheThrift" Version="0.21.0" />
<PackageReference Include="K4os.Compression.LZ4" Version="1.3.8" />
<PackageReference Include="K4os.Compression.LZ4.Streams" Version="1.3.8" />
<PackageReference Include="System.Net.Http" Version="4.3.4" />
<PackageReference Include="System.Text.Json" Version="8.0.5" />
</ItemGroup>
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ internal async Task OpenAsync()

internal abstract SchemaParser SchemaParser { get; }

internal abstract IArrowArrayStream NewReader<T>(T statement, Schema schema) where T : HiveServer2Statement;
internal abstract IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) where T : HiveServer2Statement;

public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList<string>? tableTypes, string? columnNamePattern)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public override AdbcStatement CreateStatement()
return new HiveServer2Statement(this);
}

internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(
internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(
statement,
schema,
dataTypeConversion: statement.Connection.DataTypeConversion,
Expand Down
14 changes: 5 additions & 9 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,11 @@ private async Task<QueryResult> ExecuteQueryAsyncInternal(CancellationToken canc
// take QueryTimeoutSeconds (but this could be restricting)
await ExecuteStatementAsync(cancellationToken); // --> get QueryTimeout +
await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout
Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); // + get the result, up to QueryTimeout
TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken);
Comment thread
jadewang-db marked this conversation as resolved.
Schema schema = Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion);

return new QueryResult(-1, Connection.NewReader(this, schema));
// Store metadata for use in readers
return new QueryResult(-1, Connection.NewReader(this, schema, response));
}

public override async ValueTask<QueryResult> ExecuteQueryAsync()
Expand All @@ -108,12 +110,6 @@ public override async ValueTask<QueryResult> ExecuteQueryAsync()
}
}

private async Task<Schema> GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default)
{
TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(operationHandle, client, cancellationToken);
return Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion);
}

public async Task<UpdateResult> ExecuteUpdateAsyncInternal(CancellationToken cancellationToken = default)
{
const string NumberOfAffectedRowsColumnName = "num_affected_rows";
Expand Down Expand Up @@ -195,7 +191,7 @@ public override void SetOption(string key, string value)

protected async Task ExecuteStatementAsync(CancellationToken cancellationToken = default)
{
TExecuteStatementReq executeRequest = new TExecuteStatementReq(Connection.SessionHandle, SqlQuery);
TExecuteStatementReq executeRequest = new TExecuteStatementReq(Connection.SessionHandle!, SqlQuery!);
SetStatementProperties(executeRequest);
TExecuteStatementResp executeResponse = await Connection.Client.ExecuteStatement(executeRequest, cancellationToken);
if (executeResponse.Status.StatusCode == TStatusCode.ERROR_STATUS)
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ protected override void ValidateOptions()
}
}

internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);

protected override TTransport CreateTransport()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ protected override TOpenSessionReq CreateSessionRequest()
return request;
}

internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);

internal override ImpalaServerType ServerType => ImpalaServerType.Standard;

Expand Down
318 changes: 318 additions & 0 deletions csharp/src/Drivers/Apache/Spark/CloudFetch/SparkCloudFetchReader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Ipc;
using Apache.Hive.Service.Rpc.Thrift;
using K4os.Compression.LZ4.Streams;

namespace Apache.Arrow.Adbc.Drivers.Apache.Spark.CloudFetch
{
/// <summary>
/// Reader for CloudFetch results from Databricks Spark Thrift server.
/// Handles downloading and processing URL-based result sets.
/// </summary>
internal sealed class SparkCloudFetchReader : IArrowArrayStream
{
// Default values used if not specified in connection properties
private const int DefaultMaxRetries = 3;
private const int DefaultRetryDelayMs = 500;
private const int DefaultTimeoutMinutes = 5;

private readonly int maxRetries;
private readonly int retryDelayMs;
private readonly int timeoutMinutes;

private HiveServer2Statement? statement;
private readonly Schema schema;
private List<TSparkArrowResultLink>? resultLinks;
private int linkIndex;
private ArrowStreamReader? currentReader;
private readonly bool isLz4Compressed;
private long startOffset;

// Lazy initialization of HttpClient
private readonly Lazy<HttpClient> httpClient;

/// <summary>
/// Initializes a new instance of the <see cref="SparkCloudFetchReader"/> class.
/// </summary>
/// <remarks>
/// The retry count, retry delay and HTTP timeout are read from the statement's connection
/// properties (<c>SparkParameters.CloudFetchMaxRetries</c>,
/// <c>SparkParameters.CloudFetchRetryDelayMs</c> and
/// <c>SparkParameters.CloudFetchTimeoutMinutes</c>). A value that is missing, is not an
/// integer, or is not greater than zero is replaced by the corresponding default.
/// </remarks>
/// <param name="statement">The HiveServer2 statement.</param>
/// <param name="schema">The Arrow schema.</param>
/// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param>
/// <exception cref="ArgumentNullException">
/// <paramref name="statement"/> or <paramref name="schema"/> is null.
/// </exception>
public SparkCloudFetchReader(HiveServer2Statement statement, Schema schema, bool isLz4Compressed)
{
    this.statement = statement ?? throw new ArgumentNullException(nameof(statement));
    this.schema = schema ?? throw new ArgumentNullException(nameof(schema));
    this.isLz4Compressed = isLz4Compressed;

    var connectionProps = statement.Connection.Properties;

    // Single place for the "positive integer or default" rule that all three settings share.
    // Parsing is culture-invariant so the result does not depend on the machine's locale.
    int ReadPositiveInt(string key, int fallback)
    {
        bool usable =
            connectionProps.TryGetValue(key, out string? raw) &&
            int.TryParse(raw, NumberStyles.Integer, CultureInfo.InvariantCulture, out int parsed) &&
            parsed > 0;

        // TryParse leaves 'parsed' at 0 on failure, so it is only meaningful when 'usable' is true.
        return usable ? ParsedOrFallback(raw, fallback) : fallback;
    }

    // Re-parses a value already known to be valid; keeps the definite-assignment rules simple.
    static int ParsedOrFallback(string? raw, int fallback)
    {
        return int.TryParse(raw, NumberStyles.Integer, CultureInfo.InvariantCulture, out int value)
            ? value
            : fallback;
    }

    this.maxRetries = ReadPositiveInt(SparkParameters.CloudFetchMaxRetries, DefaultMaxRetries);
    this.retryDelayMs = ReadPositiveInt(SparkParameters.CloudFetchRetryDelayMs, DefaultRetryDelayMs);
    this.timeoutMinutes = ReadPositiveInt(SparkParameters.CloudFetchTimeoutMinutes, DefaultTimeoutMinutes);

    // The client is created on first use and picks up the configured timeout at that point.
    this.httpClient = new Lazy<HttpClient>(() =>
    {
        var client = new HttpClient();
        client.Timeout = TimeSpan.FromMinutes(this.timeoutMinutes);
        return client;
    });
}

/// <summary>
/// Gets the Arrow schema that was supplied to the constructor.
/// </summary>
public Schema Schema => schema;

// Unwraps the lazily created client; the first access is what instantiates it.
private HttpClient HttpClient => httpClient.Value;

/// <summary>
/// Reads the next record batch from the result set.
/// </summary>
/// <remarks>
/// Each pass of the loop tries, in order: the Arrow reader for the file currently being
/// consumed, the next unprocessed result link from the last fetch, and finally a new
/// FetchResults call to the server. The method returns null once the statement reference
/// has been cleared and no links remain, when FetchResults throws, or when a fetch
/// response carries no result links.
/// </remarks>
/// <param name="cancellationToken">The cancellation token. It is passed to the batch read, the file download, the retry delay and the FetchResults call.</param>
/// <returns>The next record batch, or null if there are no more batches.</returns>
public async ValueTask<RecordBatch?> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
{
while (true)
{
// If we have a current reader, try to read the next batch
if (this.currentReader != null)
{
RecordBatch? next = await this.currentReader.ReadNextRecordBatchAsync(cancellationToken);
if (next != null)
{
return next;
}
else
{
// Current file is exhausted; release it and fall through to the next link
this.currentReader.Dispose();
this.currentReader = null;
}
}

// If we have more links to process, download and process the next one
if (this.resultLinks != null && this.linkIndex < this.resultLinks.Count)
{
// The index is advanced before the download, so a link that fails below is not revisited
var link = this.resultLinks[this.linkIndex++];
byte[]? fileData = null;

// Retry logic for downloading files
for (int retry = 0; retry < this.maxRetries; retry++)
{
try
{
fileData = await DownloadFileAsync(link.FileLink, cancellationToken);
break; // Success, exit retry loop
}
// The filter excludes the final attempt, so its exception propagates to the caller
catch (Exception ex) when (retry < this.maxRetries - 1)
{
// Log the error and retry
Debug.WriteLine($"Error downloading file (attempt {retry + 1}/{this.maxRetries}): {ex.Message}");
// The wait grows linearly with the attempt number
await Task.Delay(this.retryDelayMs * (retry + 1), cancellationToken);
}
}

// Process the downloaded file data
MemoryStream dataStream;

// If the data is LZ4 compressed, decompress it
if (this.isLz4Compressed)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to leverage the Apache.Arrow.Compression assembly to do decompression? It works by passing a CompressionCodecFactory to the ArrowStreamReader constructor.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you have code pointers? I tried it, seems not working.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't, no. I can try to figure it out later; this doesn't need to be blocking.

{
try
{
// Decompress the whole file into memory, then rewind so the Arrow reader starts at byte 0
dataStream = new MemoryStream();
// fileData is non-null here: the constructor forces maxRetries > 0 and a failed final attempt throws
using (var inputStream = new MemoryStream(fileData!))
using (var decompressor = LZ4Stream.Decode(inputStream))
{
// NOTE(review): this copy does not receive the cancellation token
await decompressor.CopyToAsync(dataStream);
}
dataStream.Position = 0;
}
catch (Exception ex)
{
// NOTE(review): the failure is only written to Debug output; the rows in this file
// are dropped from the stream and the caller is not told. Confirm this is intended.
Debug.WriteLine($"Error decompressing data: {ex.Message}");
continue; // Skip this link and try the next one
}
}
else
{
dataStream = new MemoryStream(fileData!);
}

try
{
this.currentReader = new ArrowStreamReader(dataStream);
// Loop back to the top so the batch is read from the reader just created
continue;
}
catch (Exception ex)
{
// NOTE(review): same silent skip as the decompression failure above
Debug.WriteLine($"Error creating Arrow reader: {ex.Message}");
dataStream.Dispose();
continue; // Skip this link and try the next one
}
}

// Every link from the last fetch has been consumed; reset before asking for more
this.resultLinks = null;
this.linkIndex = 0;

// If there's no statement, we're done
if (this.statement == null)
{
return null;
}

// Fetch more results from the server
TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle!, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize);

// Set the start row offset if we have processed some links already
if (this.startOffset > 0)
{
request.StartRowOffset = this.startOffset;
}

TFetchResultsResp response;
try
{
response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken);
}
catch (Exception ex)
{
// NOTE(review): every exception, including cancellation, ends the stream with null here,
// so the caller cannot tell a failed fetch from a completed result set. Confirm this is intended.
Debug.WriteLine($"Error fetching results from server: {ex.Message}");
this.statement = null; // Mark as done due to error
return null;
}

// Check if we have URL-based results
if (response.Results.__isset.resultLinks &&
response.Results.ResultLinks != null &&
response.Results.ResultLinks.Count > 0)
{
this.resultLinks = response.Results.ResultLinks;

// Update the start offset for the next fetch by calculating it from the links
// (this check is always true here; the enclosing condition already required Count > 0)
if (this.resultLinks.Count > 0)
{
// The next fetch starts at the row after the last row of the last link
var lastLink = this.resultLinks[this.resultLinks.Count - 1];
this.startOffset = lastLink.StartRowOffset + lastLink.RowCount;
}

// If the server indicates there are no more rows, drop the statement reference so that
// no further FetchResults call is made once these links have been consumed
if (!response.HasMoreRows)
{
this.statement = null;
}
}
else
{
// If there are no more results, we're done
this.statement = null;
return null;
}
}
}

/// <summary>
/// Fetches the complete body of the given URL into memory.
/// </summary>
/// <param name="url">The URL to download from.</param>
/// <param name="cancellationToken">Token passed to the HTTP GET request.</param>
/// <returns>The response body as a byte array.</returns>
/// <exception cref="HttpRequestException">The response status code does not indicate success.</exception>
private async Task<byte[]> DownloadFileAsync(string url, CancellationToken cancellationToken)
{
    // ResponseHeadersRead returns as soon as the headers arrive; the body is read further down.
    using (HttpResponseMessage httpResponse = await HttpClient.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken))
    {
        httpResponse.EnsureSuccessStatusCode();

        // The size is only reported for diagnostics, and only when the server sent a positive length.
        long? contentLength = httpResponse.Content.Headers.ContentLength;
        if (contentLength.HasValue && contentLength.Value > 0)
        {
            Debug.WriteLine($"Downloading file of size: {contentLength.Value / 1024.0 / 1024.0:F2} MB");
        }

        byte[] payload = await httpResponse.Content.ReadAsByteArrayAsync();
        return payload;
    }
}

/// <summary>
/// Releases the active Arrow reader, if there is one, and the HTTP client if it was ever created.
/// </summary>
public void Dispose()
{
    this.currentReader?.Dispose();
    this.currentReader = null;

    // Touching Value on an unused Lazy would build a client only to dispose it.
    if (!httpClient.IsValueCreated)
    {
        return;
    }

    httpClient.Value.Dispose();
}
}
}
Loading