Skip to content

Commit

Permalink
Implement client context (#191)
Browse files Browse the repository at this point in the history
* Introduce client context

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Pass context to all relevant classes

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Make driver a part of context

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Fix tests

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

---------

Signed-off-by: Levko Kravets <levko.ne@gmail.com>
  • Loading branch information
kravets-levko authored Oct 4, 2023
1 parent 46c3586 commit 9eb3807
Show file tree
Hide file tree
Showing 26 changed files with 829 additions and 537 deletions.
56 changes: 38 additions & 18 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import { EventEmitter } from 'events';
import TCLIService from '../thrift/TCLIService';
import { TProtocolVersion } from '../thrift/TCLIService_types';
import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient';
import IDriver from './contracts/IDriver';
import IClientContext from './contracts/IClientContext';
import HiveDriver from './hive/HiveDriver';
import { Int64 } from './hive/Types';
import DBSQLSession from './DBSQLSession';
Expand Down Expand Up @@ -41,13 +43,17 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
};
}

export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
export default class DBSQLClient extends EventEmitter implements IDBSQLClient, IClientContext {
private connectionProvider?: IConnectionProvider;

private authProvider?: IAuthentication;

private client?: TCLIService.Client;

private readonly driver = new HiveDriver({
context: this,
});

private readonly logger: IDBSQLLogger;

private readonly thrift = thrift;
Expand All @@ -73,7 +79,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
};
}

private getAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
private initAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
if (authProvider) {
return authProvider;
}
Expand All @@ -84,15 +90,16 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
return new PlainHttpAuthentication({
username: 'token',
password: options.token,
context: this,
});
case 'databricks-oauth':
return new DatabricksOAuth({
host: options.host,
logger: this.logger,
persistence: options.persistence,
azureTenantId: options.azureTenantId,
clientId: options.oauthClientId,
clientSecret: options.oauthClientSecret,
context: this,
});
case 'custom':
return options.provider;
Expand All @@ -110,7 +117,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
* const session = client.connect({host, path, token});
*/
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
this.authProvider = this.getAuthProvider(options, authProvider);
this.authProvider = this.initAuthProvider(options, authProvider);

this.connectionProvider = new HttpConnection(this.getConnectionOptions(options));

Expand Down Expand Up @@ -156,44 +163,57 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
* const session = await client.openSession();
*/
public async openSession(request: OpenSessionRequest = {}): Promise<IDBSQLSession> {
const driver = new HiveDriver(() => this.getClient());

const response = await driver.openSession({
const response = await this.driver.openSession({
client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8),
...getInitialNamespaceOptions(request.initialCatalog, request.initialSchema),
});

Status.assert(response.status);
const session = new DBSQLSession(driver, definedOrError(response.sessionHandle), {
logger: this.logger,
const session = new DBSQLSession({
handle: definedOrError(response.sessionHandle),
context: this,
});
this.sessions.add(session);
return session;
}

private async getClient() {
public async close(): Promise<void> {
await this.sessions.closeAll();

this.client = undefined;
this.connectionProvider = undefined;
this.authProvider = undefined;
}

public getLogger(): IDBSQLLogger {
return this.logger;
}

public async getConnectionProvider(): Promise<IConnectionProvider> {
if (!this.connectionProvider) {
throw new HiveDriverError('DBSQLClient: not connected');
}

return this.connectionProvider;
}

public async getClient(): Promise<TCLIService.Client> {
const connectionProvider = await this.getConnectionProvider();

if (!this.client) {
this.logger.log(LogLevel.info, 'DBSQLClient: initializing thrift client');
this.client = this.thrift.createClient(TCLIService, await this.connectionProvider.getThriftConnection());
this.client = this.thrift.createClient(TCLIService, await connectionProvider.getThriftConnection());
}

if (this.authProvider) {
const authHeaders = await this.authProvider.authenticate();
this.connectionProvider.setHeaders(authHeaders);
connectionProvider.setHeaders(authHeaders);
}

return this.client;
}

public async close(): Promise<void> {
await this.sessions.closeAll();

this.client = undefined;
this.connectionProvider = undefined;
this.authProvider = undefined;
public async getDriver(): Promise<IDriver> {
return this.driver;
}
}
11 changes: 6 additions & 5 deletions lib/DBSQLOperation/FetchResultsHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import {
TRowSet,
} from '../../thrift/TCLIService_types';
import { ColumnCode, FetchType, Int64 } from '../hive/Types';
import HiveDriver from '../hive/HiveDriver';
import Status from '../dto/Status';
import IClientContext from '../contracts/IClientContext';

function checkIfOperationHasMoreRows(response: TFetchResultsResp): boolean {
if (response.hasMoreRows) {
Expand Down Expand Up @@ -36,7 +36,7 @@ function checkIfOperationHasMoreRows(response: TFetchResultsResp): boolean {
}

export default class FetchResultsHelper {
private readonly driver: HiveDriver;
private readonly context: IClientContext;

private readonly operationHandle: TOperationHandle;

Expand All @@ -49,12 +49,12 @@ export default class FetchResultsHelper {
public hasMoreRows: boolean = false;

constructor(
driver: HiveDriver,
context: IClientContext,
operationHandle: TOperationHandle,
prefetchedResults: Array<TFetchResultsResp | undefined>,
returnOnlyPrefetchedResults: boolean,
) {
this.driver = driver;
this.context = context;
this.operationHandle = operationHandle;
prefetchedResults.forEach((item) => {
if (item) {
Expand Down Expand Up @@ -85,7 +85,8 @@ export default class FetchResultsHelper {
return this.processFetchResponse(prefetchedResponse);
}

const response = await this.driver.fetchResults({
const driver = await this.context.getDriver();
const response = await driver.fetchResults({
operationHandle: this.operationHandle,
orientation: this.fetchOrientation,
maxRows: new Int64(maxRows),
Expand Down
68 changes: 34 additions & 34 deletions lib/DBSQLOperation/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import IOperation, {
GetSchemaOptions,
WaitUntilReadyOptions,
} from '../contracts/IOperation';
import HiveDriver from '../hive/HiveDriver';
import {
TGetOperationStatusResp,
TOperationHandle,
Expand All @@ -18,19 +17,22 @@ import {
} from '../../thrift/TCLIService_types';
import Status from '../dto/Status';
import FetchResultsHelper from './FetchResultsHelper';
import IDBSQLLogger, { LogLevel } from '../contracts/IDBSQLLogger';
import { LogLevel } from '../contracts/IDBSQLLogger';
import OperationStateError, { OperationStateErrorCode } from '../errors/OperationStateError';
import IOperationResult from '../result/IOperationResult';
import JsonResult from '../result/JsonResult';
import ArrowResult from '../result/ArrowResult';
import CloudFetchResult from '../result/CloudFetchResult';
import { definedOrError } from '../utils';
import HiveDriverError from '../errors/HiveDriverError';
import IClientContext from '../contracts/IClientContext';

const defaultMaxRows = 100000;

interface DBSQLOperationConstructorOptions {
logger: IDBSQLLogger;
handle: TOperationHandle;
directResults?: TSparkDirectResults;
context: IClientContext;
}

async function delay(ms?: number): Promise<void> {
Expand All @@ -42,12 +44,10 @@ async function delay(ms?: number): Promise<void> {
}

export default class DBSQLOperation implements IOperation {
private readonly driver: HiveDriver;
private readonly context: IClientContext;

private readonly operationHandle: TOperationHandle;

private readonly logger: IDBSQLLogger;

public onClose?: () => void;

private readonly _data: FetchResultsHelper;
Expand All @@ -70,32 +70,26 @@ export default class DBSQLOperation implements IOperation {

private resultHandler?: IOperationResult;

constructor(
driver: HiveDriver,
operationHandle: TOperationHandle,
{ logger }: DBSQLOperationConstructorOptions,
directResults?: TSparkDirectResults,
) {
this.driver = driver;
this.operationHandle = operationHandle;
this.logger = logger;
constructor({ handle, directResults, context }: DBSQLOperationConstructorOptions) {
this.operationHandle = handle;
this.context = context;

const useOnlyPrefetchedResults = Boolean(directResults?.closeOperation);

this.hasResultSet = operationHandle.hasResultSet;
this.hasResultSet = this.operationHandle.hasResultSet;
if (directResults?.operationStatus) {
this.processOperationStatusResponse(directResults.operationStatus);
}

this.metadata = directResults?.resultSetMetadata;
this._data = new FetchResultsHelper(
this.driver,
this.context,
this.operationHandle,
[directResults?.resultSet],
useOnlyPrefetchedResults,
);
this.closeOperation = directResults?.closeOperation;
this.logger.log(LogLevel.debug, `Operation created with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Operation created with id: ${this.getId()}`);
}

public getId() {
Expand All @@ -118,7 +112,7 @@ export default class DBSQLOperation implements IOperation {
const chunk = await this.fetchChunk(options);
data.push(chunk);
} while (await this.hasMoreRows()); // eslint-disable-line no-await-in-loop
this.logger?.log(LogLevel.debug, `Fetched all data from operation with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Fetched all data from operation with id: ${this.getId()}`);

return data.flat();
}
Expand Down Expand Up @@ -149,10 +143,12 @@ export default class DBSQLOperation implements IOperation {
await this.failIfClosed();

const result = await resultHandler.getValue(data ? [data] : []);
this.logger?.log(
LogLevel.debug,
`Fetched chunk of size: ${options?.maxRows || defaultMaxRows} from operation with id: ${this.getId()}`,
);
this.context
.getLogger()
.log(
LogLevel.debug,
`Fetched chunk of size: ${options?.maxRows || defaultMaxRows} from operation with id: ${this.getId()}`,
);
return result;
}

Expand All @@ -163,13 +159,14 @@ export default class DBSQLOperation implements IOperation {
*/
public async status(progress: boolean = false): Promise<TGetOperationStatusResp> {
await this.failIfClosed();
this.logger?.log(LogLevel.debug, `Fetching status for operation with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Fetching status for operation with id: ${this.getId()}`);

if (this.operationStatus) {
return this.operationStatus;
}

const response = await this.driver.getOperationStatus({
const driver = await this.context.getDriver();
const response = await driver.getOperationStatus({
operationHandle: this.operationHandle,
getProgressUpdate: progress,
});
Expand All @@ -186,9 +183,10 @@ export default class DBSQLOperation implements IOperation {
return Status.success();
}

this.logger?.log(LogLevel.debug, `Cancelling operation with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Cancelling operation with id: ${this.getId()}`);

const response = await this.driver.cancelOperation({
const driver = await this.context.getDriver();
const response = await driver.cancelOperation({
operationHandle: this.operationHandle,
});
Status.assert(response.status);
Expand All @@ -209,11 +207,12 @@ export default class DBSQLOperation implements IOperation {
return Status.success();
}

this.logger?.log(LogLevel.debug, `Closing operation with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Closing operation with id: ${this.getId()}`);

const driver = await this.context.getDriver();
const response =
this.closeOperation ??
(await this.driver.closeOperation({
(await driver.closeOperation({
operationHandle: this.operationHandle,
}));
Status.assert(response.status);
Expand Down Expand Up @@ -254,7 +253,7 @@ export default class DBSQLOperation implements IOperation {

await this.waitUntilReady(options);

this.logger?.log(LogLevel.debug, `Fetching schema for operation with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Fetching schema for operation with id: ${this.getId()}`);
const metadata = await this.fetchMetadata();
return metadata.schema ?? null;
}
Expand Down Expand Up @@ -332,7 +331,8 @@ export default class DBSQLOperation implements IOperation {

private async fetchMetadata() {
if (!this.metadata) {
const metadata = await this.driver.getResultSetMetadata({
const driver = await this.context.getDriver();
const metadata = await driver.getResultSetMetadata({
operationHandle: this.operationHandle,
});
Status.assert(metadata.status);
Expand All @@ -349,13 +349,13 @@ export default class DBSQLOperation implements IOperation {
if (!this.resultHandler) {
switch (resultFormat) {
case TSparkRowSetType.COLUMN_BASED_SET:
this.resultHandler = new JsonResult(metadata.schema);
this.resultHandler = new JsonResult(this.context, metadata.schema);
break;
case TSparkRowSetType.ARROW_BASED_SET:
this.resultHandler = new ArrowResult(metadata.schema, metadata.arrowSchema);
this.resultHandler = new ArrowResult(this.context, metadata.schema, metadata.arrowSchema);
break;
case TSparkRowSetType.URL_BASED_SET:
this.resultHandler = new CloudFetchResult(metadata.schema);
this.resultHandler = new CloudFetchResult(this.context, metadata.schema);
break;
default:
this.resultHandler = undefined;
Expand Down
Loading

0 comments on commit 9eb3807

Please sign in to comment.