Skip to content

SASL / SCRAM-SHA-256 Authentication #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Support SASL / SCRAM-SHA-256 Authentication, [#6](https://github.com/isoos/postgresql-dart/pull/6).
- Decoder for type `numeric` / `decimal`, [#7](https://github.com/isoos/postgresql-dart/pull/7).

## 2.3.2
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

[![CI](https://github.com/isoos/postgresql-dart/actions/workflows/dart.yml/badge.svg)](https://github.com/isoos/postgresql-dart/actions/workflows/dart.yml)

A library for connecting to and querying PostgreSQL databases.
A library for connecting to and querying PostgreSQL databases (see [Postgres Protocol](https://www.postgresql.org/docs/13/protocol-overview.html)).

This driver uses the more efficient and secure extended query format of the PostgreSQL protocol.

Expand Down
30 changes: 30 additions & 0 deletions lib/src/auth/auth.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import 'package:crypto/crypto.dart';
import 'package:sasl_scram/sasl_scram.dart';

import '../../postgres.dart';
import '../server_messages.dart';
import 'md5_authenticator.dart';
import 'sasl_authenticator.dart';

enum AuthenticationScheme { MD5, SCRAM_SHA_256 }

abstract class PostgresAuthenticator {
static String? name;
late final PostgreSQLConnection connection;

PostgresAuthenticator(this.connection);

void onMessage(AuthenticationMessage message);
}

PostgresAuthenticator createAuthenticator(PostgreSQLConnection connection, AuthenticationScheme authenticationScheme) {
switch (authenticationScheme) {
case AuthenticationScheme.MD5:
return MD5Authenticator(connection);
case AuthenticationScheme.SCRAM_SHA_256:
final credentials = UsernamePasswordCredential(username: connection.username, password: connection.password);
return PostgresSaslAuthenticator(connection, ScramAuthenticator('SCRAM-SHA-256', sha256, credentials));
default:
throw PostgreSQLException("Authenticator wasn't specified");
}
}
43 changes: 43 additions & 0 deletions lib/src/auth/md5_authenticator.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import 'package:buffer/buffer.dart';
import 'package:crypto/crypto.dart';

import '../../postgres.dart';
import '../client_messages.dart';
import '../server_messages.dart';
import '../utf8_backed_string.dart';
import 'auth.dart';

class MD5Authenticator extends PostgresAuthenticator {
static final String name = 'MD5';

MD5Authenticator(PostgreSQLConnection connection) : super(connection);

@override
void onMessage(AuthenticationMessage message) {
final reader = ByteDataReader()..add(message.bytes);
final salt = reader.read(4, copy: true);

final authMessage = AuthMD5Message(connection.username!, connection.password!, salt);

connection.socket!.add(authMessage.asBytes());
}
}

class AuthMD5Message extends ClientMessage {
UTF8BackedString? _hashedAuthString;

AuthMD5Message(String username, String password, List<int> saltBytes) {
final passwordHash = md5.convert('$password$username'.codeUnits).toString();
final saltString = String.fromCharCodes(saltBytes);
final md5Hash = md5.convert('$passwordHash$saltString'.codeUnits).toString();
_hashedAuthString = UTF8BackedString('md5$md5Hash');
}

@override
void applyToBuffer(ByteDataWriter buffer) {
buffer.writeUint8(ClientMessage.PasswordIdentifier);
final length = 5 + _hashedAuthString!.utf8Length;
buffer.writeUint32(length);
_hashedAuthString!.applyToBuffer(buffer);
}
}
82 changes: 82 additions & 0 deletions lib/src/auth/sasl_authenticator.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import 'dart:typed_data';

import 'package:buffer/buffer.dart';
import 'package:sasl_scram/sasl_scram.dart';

import '../../postgres.dart';
import '../client_messages.dart';
import '../server_messages.dart';
import '../utf8_backed_string.dart';
import 'auth.dart';

/// Structure for SASL Authenticator
class PostgresSaslAuthenticator extends PostgresAuthenticator {
final SaslAuthenticator authenticator;

PostgresSaslAuthenticator(PostgreSQLConnection connection, this.authenticator) : super(connection);

@override
void onMessage(AuthenticationMessage message) {
ClientMessage msg;
switch (message.type) {
case AuthenticationMessage.KindSASL:
final bytesToSend = authenticator.handleMessage(SaslMessageType.AuthenticationSASL, message.bytes);
if (bytesToSend == null) throw PostgreSQLException('KindSASL: No bytes to send');
msg = SaslClientFirstMessage(bytesToSend, authenticator.mechanism.name);
break;
case AuthenticationMessage.KindSASLContinue:
final bytesToSend = authenticator.handleMessage(SaslMessageType.AuthenticationSASLContinue, message.bytes);
if (bytesToSend == null) throw PostgreSQLException('KindSASLContinue: No bytes to send');
msg = SaslClientLastMessage(bytesToSend);
break;
case AuthenticationMessage.KindSASLFinal:
authenticator.handleMessage(SaslMessageType.AuthenticationSASLFinal, message.bytes);
return;
default:
throw PostgreSQLException('Unsupported authentication type ${message.type}, closing connection.');
}
connection.socket!.add(msg.asBytes());
}
}

class SaslClientFirstMessage extends ClientMessage {
Uint8List bytesToSendToServer;
String mechanismName;

SaslClientFirstMessage(this.bytesToSendToServer, this.mechanismName);

@override
void applyToBuffer(ByteDataWriter buffer) {
buffer.writeUint8(ClientMessage.PasswordIdentifier);

final utf8CachedMechanismName = UTF8BackedString(mechanismName);

final msgLength = bytesToSendToServer.length;
// No Identifier bit + 4 byte counts (for whole length) + mechanism bytes + zero byte + 4 byte counts (for msg length) + msg bytes
final length = 4 + utf8CachedMechanismName.utf8Length + 1 + 4 + msgLength;

buffer.writeUint32(length);
utf8CachedMechanismName.applyToBuffer(buffer);

// do not add the msg byte count for whatever reason
buffer.writeUint32(msgLength);
buffer.write(bytesToSendToServer);
}
}

class SaslClientLastMessage extends ClientMessage {
Uint8List bytesToSendToServer;

SaslClientLastMessage(this.bytesToSendToServer);

@override
void applyToBuffer(ByteDataWriter buffer) {
buffer.writeUint8(ClientMessage.PasswordIdentifier);

// No Identifier bit + 4 byte counts (for msg length) + msg bytes
final length = 4 + bytesToSendToServer.length;

buffer.writeUint32(length);
buffer.write(bytesToSendToServer);
}
}
56 changes: 15 additions & 41 deletions lib/src/client_messages.dart
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import 'dart:typed_data';

import 'package:buffer/buffer.dart';
import 'package:crypto/crypto.dart';

import 'constants.dart';
import 'query.dart';
Expand All @@ -13,13 +12,13 @@ abstract class ClientMessage {

static const int ProtocolVersion = 196608;

static const int BindIdentifier = 66;
static const int DescribeIdentifier = 68;
static const int ExecuteIdentifier = 69;
static const int ParseIdentifier = 80;
static const int QueryIdentifier = 81;
static const int SyncIdentifier = 83;
static const int PasswordIdentifier = 112;
static const int BindIdentifier = 66; // B
static const int DescribeIdentifier = 68; // D
static const int ExecuteIdentifier = 69; // E
static const int ParseIdentifier = 80; //P
static const int QueryIdentifier = 81; // Q
static const int SyncIdentifier = 83; // S
static const int PasswordIdentifier = 112; //p

void applyToBuffer(ByteDataWriter buffer);

Expand All @@ -36,11 +35,6 @@ abstract class ClientMessage {
}
}

void _applyStringToBuffer(UTF8BackedString string, ByteDataWriter buffer) {
buffer.write(string.utf8Bytes);
buffer.writeInt8(0);
}

class StartupMessage extends ClientMessage {
final UTF8BackedString? _username;
final UTF8BackedString _databaseName;
Expand All @@ -66,42 +60,22 @@ class StartupMessage extends ClientMessage {

if (_username != null) {
buffer.write(UTF8ByteConstants.user);
_applyStringToBuffer(_username!, buffer);
_username!.applyToBuffer(buffer);
}

buffer.write(UTF8ByteConstants.database);
_applyStringToBuffer(_databaseName, buffer);
_databaseName.applyToBuffer(buffer);

buffer.write(UTF8ByteConstants.clientEncoding);
buffer.write(UTF8ByteConstants.utf8);

buffer.write(UTF8ByteConstants.timeZone);
_applyStringToBuffer(_timeZone, buffer);
_timeZone.applyToBuffer(buffer);

buffer.writeInt8(0);
}
}

class AuthMD5Message extends ClientMessage {
UTF8BackedString? _hashedAuthString;

AuthMD5Message(String username, String password, List<int> saltBytes) {
final passwordHash = md5.convert('$password$username'.codeUnits).toString();
final saltString = String.fromCharCodes(saltBytes);
final md5Hash =
md5.convert('$passwordHash$saltString'.codeUnits).toString();
_hashedAuthString = UTF8BackedString('md5$md5Hash');
}

@override
void applyToBuffer(ByteDataWriter buffer) {
buffer.writeUint8(ClientMessage.PasswordIdentifier);
final length = 5 + _hashedAuthString!.utf8Length;
buffer.writeUint32(length);
_applyStringToBuffer(_hashedAuthString!, buffer);
}
}

class QueryMessage extends ClientMessage {
final UTF8BackedString _queryString;

Expand All @@ -113,7 +87,7 @@ class QueryMessage extends ClientMessage {
buffer.writeUint8(ClientMessage.QueryIdentifier);
final length = 5 + _queryString.utf8Length;
buffer.writeUint32(length);
_applyStringToBuffer(_queryString, buffer);
_queryString.applyToBuffer(buffer);
}
}

Expand All @@ -131,8 +105,8 @@ class ParseMessage extends ClientMessage {
final length = 8 + _statement.utf8Length + _statementName.utf8Length;
buffer.writeUint32(length);
// Name of prepared statement
_applyStringToBuffer(_statementName, buffer);
_applyStringToBuffer(_statement, buffer); // Query string
_statementName.applyToBuffer(buffer);
_statement.applyToBuffer(buffer); // Query string
buffer.writeUint16(0);
}
}
Expand All @@ -149,7 +123,7 @@ class DescribeMessage extends ClientMessage {
final length = 6 + _statementName.utf8Length;
buffer.writeUint32(length);
buffer.writeUint8(83);
_applyStringToBuffer(_statementName, buffer); // Name of prepared statement
_statementName.applyToBuffer(buffer); // Name of prepared statement
}
}

Expand Down Expand Up @@ -193,7 +167,7 @@ class BindMessage extends ClientMessage {
// Name of portal - currently unnamed portal.
buffer.writeUint8(0);
// Name of prepared statement.
_applyStringToBuffer(_statementName, buffer);
_statementName.applyToBuffer(buffer);

// OK, if we have no specified types at all, we can use 0. If we have all specified types, we can use 1. If we have a mix, we have to individually
// call out each type.
Expand Down
9 changes: 5 additions & 4 deletions lib/src/connection.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import 'dart:io';
import 'dart:typed_data';

import 'package:buffer/buffer.dart';
import 'auth/auth.dart';

import 'client_messages.dart';
import 'execution_context.dart';
Expand Down Expand Up @@ -118,7 +119,6 @@ class PostgreSQLConnection extends Object
late int _processID;
// ignore: unused_field
late int _secretKey;
late List<int> _salt;

bool _hasConnectedPreviously = false;
late _PostgreSQLConnectionState _connectionState;
Expand All @@ -129,6 +129,8 @@ class PostgreSQLConnection extends Object
@override
PostgreSQLConnection get _connection => this;

Socket? get socket => _socket;

/// Establishes a connection with a PostgreSQL database.
///
/// This method will return a [Future] that completes when the connection is established. Queries can be executed
Expand Down Expand Up @@ -246,8 +248,7 @@ class PostgreSQLConnection extends Object
_connectionState = newState;
_connectionState.connection = this;

_connectionState = _connectionState.onEnter();
_connectionState.connection = this;
_transitionToState(_connectionState.onEnter());
}

Future _close([dynamic error, StackTrace? trace]) async {
Expand Down Expand Up @@ -292,7 +293,7 @@ class PostgreSQLConnection extends Object
originalSocket.listen((data) {
if (data.length != 1) {
sslCompleter.completeError(PostgreSQLException(
'Could not initalize SSL connection, received unknown byte stream.'));
'Could not initialize SSL connection, received unknown byte stream.'));
return;
}

Expand Down
Loading