Skip to content

Commit

Permalink
[PECO-1042] Add proxy support (#193)
Browse files Browse the repository at this point in the history
* Introduce client context

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Pass context to all relevant classes

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Make driver a part of context

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Fix tests

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* [PECO-1042] Add proxy support

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Tidy up code a bit

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Tidy up code a bit

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

---------

Signed-off-by: Levko Kravets <levko.ne@gmail.com>
  • Loading branch information
kravets-levko committed Oct 4, 2023
1 parent 9eb3807 commit fd36b2e
Show file tree
Hide file tree
Showing 12 changed files with 621 additions and 54 deletions.
26 changes: 26 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
# Release History

## 1.x (unreleased)

### Highlights

- Proxy support added

### Proxy support

This feature allows to pass through proxy all the requests library makes. By default, proxy is disabled.
To enable proxy, pass a configuration object to `DBSQLClient.connect`:

```ts
client.connect({
// pass host, path, auth options as usual
proxy: {
protocol: 'http', // supported protocols: 'http', 'https', 'socks', 'socks4', 'socks4a', 'socks5', 'socks5h'
host: 'localhost', // proxy host (string)
port: 8070, // proxy port (number)
auth: { // optional proxy basic auth config
username: ...
password: ...
},
},
})
```

## 1.5.0

### Highlights
Expand Down
1 change: 1 addition & 0 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
path: prependSlash(options.path),
https: true,
socketTimeout: options.socketTimeout,
proxy: options.proxy,
headers: {
'User-Agent': buildUserAgentString(options.clientId),
},
Expand Down
17 changes: 14 additions & 3 deletions lib/DBSQLSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,11 @@ export default class DBSQLSession implements IDBSQLSession {
if (localFile === undefined) {
throw new StagingError('Local file path not provided');
}
const response = await fetch(presignedUrl, { method: 'GET', headers });

const connectionProvider = await this.context.getConnectionProvider();
const agent = await connectionProvider.getAgent();

const response = await fetch(presignedUrl, { method: 'GET', headers, agent });
if (!response.ok) {
throw new StagingError(`HTTP error ${response.status} ${response.statusText}`);
}
Expand All @@ -283,7 +287,10 @@ export default class DBSQLSession implements IDBSQLSession {
}

private async handleStagingRemove(presignedUrl: string, headers: HeadersInit): Promise<void> {
const response = await fetch(presignedUrl, { method: 'DELETE', headers });
const connectionProvider = await this.context.getConnectionProvider();
const agent = await connectionProvider.getAgent();

const response = await fetch(presignedUrl, { method: 'DELETE', headers, agent });
if (!response.ok) {
throw new StagingError(`HTTP error ${response.status} ${response.statusText}`);
}
Expand All @@ -297,8 +304,12 @@ export default class DBSQLSession implements IDBSQLSession {
if (localFile === undefined) {
throw new StagingError('Local file path not provided');
}

const connectionProvider = await this.context.getConnectionProvider();
const agent = await connectionProvider.getAgent();

const data = fs.readFileSync(localFile);
const response = await fetch(presignedUrl, { method: 'PUT', headers, body: data });
const response = await fetch(presignedUrl, { method: 'PUT', headers, agent, body: data });
if (!response.ok) {
throw new StagingError(`HTTP error ${response.status} ${response.statusText}`);
}
Expand Down
30 changes: 28 additions & 2 deletions lib/connection/auth/DatabricksOAuth/OAuthManager.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Issuer, BaseClient } from 'openid-client';
import http from 'http';
import { Issuer, BaseClient, custom } from 'openid-client';
import HiveDriverError from '../../../errors/HiveDriverError';
import { LogLevel } from '../../../contracts/IDBSQLLogger';
import OAuthToken from './OAuthToken';
Expand Down Expand Up @@ -26,6 +27,8 @@ export default abstract class OAuthManager {

protected readonly options: OAuthManagerOptions;

protected agent?: http.Agent;

protected issuer?: Issuer;

protected client?: BaseClient;
Expand All @@ -48,14 +51,35 @@ export default abstract class OAuthManager {
}

protected async getClient(): Promise<BaseClient> {
// Obtain http agent each time when we need an OAuth client
// to ensure that we always use a valid agent instance
const connectionProvider = await this.context.getConnectionProvider();
this.agent = await connectionProvider.getAgent();

const getHttpOptions = () => ({
agent: this.agent,
});

if (!this.issuer) {
const issuer = await Issuer.discover(this.getOIDCConfigUrl());
// To use custom http agent in Issuer.discover(), we'd have to set Issuer[custom.http_options].
// However, that's a static field, and if multiple instances of OAuthManager used, race condition
// may occur when they simultaneously override that field and then try to use Issuer.discover().
// Therefore we create a local class derived from Issuer, and set that field for it, thus making
// sure that it will not interfere with other instances (or other code that may use Issuer)
class CustomIssuer extends Issuer {
static [custom.http_options] = getHttpOptions;
}

const issuer = await CustomIssuer.discover(this.getOIDCConfigUrl());

// Overwrite `authorization_endpoint` in default config (specifically needed for Azure flow
// where this URL has to be different)
this.issuer = new Issuer({
...issuer.metadata,
authorization_endpoint: this.getAuthorizationUrl(),
});

this.issuer[custom.http_options] = getHttpOptions;
}

if (!this.client) {
Expand All @@ -64,6 +88,8 @@ export default abstract class OAuthManager {
client_secret: this.options.clientSecret,
token_endpoint_auth_method: this.options.clientSecret === undefined ? 'none' : 'client_secret_basic',
});

this.client[custom.http_options] = getHttpOptions;
}

return this.client;
Expand Down
56 changes: 46 additions & 10 deletions lib/connection/connections/HttpConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ import thrift from 'thrift';
import https from 'https';
import http from 'http';
import { HeadersInit } from 'node-fetch';
import { ProxyAgent } from 'proxy-agent';

import IConnectionProvider from '../contracts/IConnectionProvider';
import IConnectionOptions from '../contracts/IConnectionOptions';
import IConnectionOptions, { ProxyOptions } from '../contracts/IConnectionOptions';
import globalConfig from '../../globalConfig';

import ThriftHttpConnection from './ThriftHttpConnection';
Expand All @@ -16,6 +17,8 @@ export default class HttpConnection implements IConnectionProvider {

private connection?: ThriftHttpConnection;

private agent?: http.Agent;

constructor(options: IConnectionOptions) {
this.options = options;
}
Expand All @@ -28,26 +31,59 @@ export default class HttpConnection implements IConnectionProvider {
});
}

private async getAgent(): Promise<http.Agent> {
const { options } = this;
public async getAgent(): Promise<http.Agent> {
if (!this.agent) {
if (this.options.proxy !== undefined) {
this.agent = this.createProxyAgent(this.options.proxy);
} else {
this.agent = this.options.https ? this.createHttpsAgent() : this.createHttpAgent();
}
}

return this.agent;
}

const httpAgentOptions: http.AgentOptions = {
private getAgentDefaultOptions(): http.AgentOptions {
return {
keepAlive: true,
maxSockets: 5,
keepAliveMsecs: 10000,
timeout: options.socketTimeout ?? globalConfig.socketTimeout,
timeout: this.options.socketTimeout ?? globalConfig.socketTimeout,
};
}

private createHttpAgent(): http.Agent {
const httpAgentOptions = this.getAgentDefaultOptions();
return new http.Agent(httpAgentOptions);
}

private createHttpsAgent(): https.Agent {
const httpsAgentOptions: https.AgentOptions = {
...httpAgentOptions,
...this.getAgentDefaultOptions(),
minVersion: 'TLSv1.2',
rejectUnauthorized: false,
ca: options.ca,
cert: options.cert,
key: options.key,
ca: this.options.ca,
cert: this.options.cert,
key: this.options.key,
};
return new https.Agent(httpsAgentOptions);
}

private createProxyAgent(proxyOptions: ProxyOptions): ProxyAgent {
const proxyAuth = proxyOptions.auth?.username
? `${proxyOptions.auth.username}:${proxyOptions.auth?.password ?? ''}@`
: '';
const proxyUrl = `${proxyOptions.protocol}://${proxyAuth}${proxyOptions.host}:${proxyOptions.port}`;

return options.https ? new https.Agent(httpsAgentOptions) : new http.Agent(httpAgentOptions);
const proxyProtocol = `${proxyOptions.protocol}:`;

return new ProxyAgent({
...this.getAgentDefaultOptions(),
getProxyForUrl: () => proxyUrl,
httpsAgent: this.createHttpsAgent(),
httpAgent: this.createHttpAgent(),
protocol: proxyProtocol,
});
}

public async getThriftConnection(): Promise<any> {
Expand Down
11 changes: 11 additions & 0 deletions lib/connection/contracts/IConnectionOptions.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
import { HeadersInit } from 'node-fetch';

export interface ProxyOptions {
protocol: 'http' | 'https' | 'socks' | 'socks4' | 'socks4a' | 'socks5' | 'socks5h';
host: string;
port: number;
auth?: {
username?: string;
password?: string;
};
}

export default interface IConnectionOptions {
host: string;
port: number;
path?: string;
https?: boolean;
headers?: HeadersInit;
socketTimeout?: number;
proxy?: ProxyOptions;

ca?: Buffer | string;
cert?: Buffer | string;
Expand Down
3 changes: 3 additions & 0 deletions lib/connection/contracts/IConnectionProvider.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import http from 'http';
import { HeadersInit } from 'node-fetch';

export default interface IConnectionProvider {
getThriftConnection(): Promise<any>;

getAgent(): Promise<http.Agent>;

setHeaders(headers: HeadersInit): void;
}
2 changes: 2 additions & 0 deletions lib/contracts/IDBSQLClient.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import IDBSQLLogger from './IDBSQLLogger';
import IDBSQLSession from './IDBSQLSession';
import IAuthentication from '../connection/contracts/IAuthentication';
import { ProxyOptions } from '../connection/contracts/IConnectionOptions';
import OAuthPersistence from '../connection/auth/DatabricksOAuth/OAuthPersistence';

export interface ClientOptions {
Expand Down Expand Up @@ -30,6 +31,7 @@ export type ConnectionOptions = {
path: string;
clientId?: string;
socketTimeout?: number;
proxy?: ProxyOptions;
} & AuthOptions;

export interface OpenSessionRequest {
Expand Down
8 changes: 7 additions & 1 deletion lib/result/CloudFetchResult.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ export default class CloudFetchResult extends ArrowResult {
}

private async fetch(url: RequestInfo, init?: RequestInit) {
return fetch(url, init);
const connectionProvider = await this.context.getConnectionProvider();
const agent = await connectionProvider.getAgent();

return fetch(url, {
agent,
...init,
});
}
}
Loading

0 comments on commit fd36b2e

Please sign in to comment.