From bdb1b0e2ac5c68aec4214d08925db1a20e9d63ac Mon Sep 17 00:00:00 2001 From: Sonny Piers Date: Thu, 2 Jan 2020 20:44:32 +0100 Subject: [PATCH] starttls: only upgrade net.Socket (#809) --- .eslintrc.yaml | 2 +- packages/connection/index.js | 25 +++------------- .../test/promise.test.js} | 12 ++++---- packages/resolve/index.js | 5 ++-- packages/resolve/package.json | 2 +- packages/starttls/client.js | 30 ++++++++----------- packages/starttls/package.json | 1 + packages/starttls/starttls.js | 20 +++++++++++++ packages/starttls/starttls.test.js | 13 ++++++++ packages/starttls/test.js | 15 ++++++++-- test/see-other-host.js | 2 +- 11 files changed, 75 insertions(+), 52 deletions(-) rename packages/{connection/test/socketConnect.js => events/test/promise.test.js} (80%) create mode 100644 packages/starttls/starttls.js create mode 100644 packages/starttls/starttls.test.js diff --git a/.eslintrc.yaml b/.eslintrc.yaml index 60154298b..010332834 100644 --- a/.eslintrc.yaml +++ b/.eslintrc.yaml @@ -28,7 +28,7 @@ rules: operator-linebreak: [error, after, {overrides: {'?': before, ':': 'before'}}] capitalized-comments: [ - error, + warn, always, {ignorePattern: prettier-ignore, ignoreConsecutiveComments: true}, ] diff --git a/packages/connection/index.js b/packages/connection/index.js index c14d4f1c7..50ae188b4 100644 --- a/packages/connection/index.js +++ b/packages/connection/index.js @@ -8,24 +8,6 @@ const {parseHost, parseService} = require('./lib/util') const NS_STREAM = 'urn:ietf:params:xml:ns:xmpp-streams' -function socketConnect(socket, ...params) { - return new Promise((resolve, reject) => { - function onError(err) { - socket.removeListener('connect', onConnect) - reject(err) - } - - function onConnect(value) { - socket.removeListener('error', onError) - resolve(value) - } - - socket.once('error', onError) - socket.once('connect', onConnect) - socket.connect(...params) - }) -} - class Connection extends EventEmitter { constructor(options = {}) { super() @@ -241,9 +223,11 @@ class Connection extends EventEmitter { */ async connect(service) { this._status('connecting', service) - this._attachSocket(new this.Socket()) + const socket = new this.Socket() + this._attachSocket(socket) // The 'connect' status is set by the socket 'connect' listener - return socketConnect(this.socket, this.socketParameters(service)) + socket.connect(this.socketParameters(service)) + return promise(socket, 'connect') } /** @@ -397,4 +381,3 @@ Connection.prototype.Socket = null Connection.prototype.Parser = null module.exports = Connection -module.exports.socketConnect = socketConnect diff --git a/packages/connection/test/socketConnect.js b/packages/events/test/promise.test.js similarity index 80% rename from packages/connection/test/socketConnect.js rename to packages/events/test/promise.test.js index b4d467e4c..e41988577 100644 --- a/packages/connection/test/socketConnect.js +++ b/packages/events/test/promise.test.js @@ -1,7 +1,7 @@ 'use strict' const test = require('ava') -const {socketConnect} = require('..') +const {promise} = require('..') const EventEmitter = require('events') class Socket extends EventEmitter { @@ -17,14 +17,15 @@ class Socket extends EventEmitter { } } -test('resolves if "connect" is emitted', async t => { +test('resolves if "event" is emitted', async t => { const value = {} const socket = new Socket(function() { this.emit('connect', value) }) t.is(socket.listenerCount('error'), 0) t.is(socket.listenerCount('connect'), 0) - const p = socketConnect(socket, 'foo') + socket.connect() + const p = promise(socket, 'connect') t.is(socket.listenerCount('error'), 1) t.is(socket.listenerCount('connect'), 1) const result = await p @@ -33,14 +34,15 @@ test('resolves if "connect" is emitted', async t => { t.is(socket.listenerCount('connect'), 0) }) -test('rejects if "error" is emitted', t => { +test('rejects if "errorEvent" is emitted', t => { const error = new Error('foobar') const socket = new Socket(function() { this.emit('error', error) }) t.is(socket.listenerCount('error'), 0) t.is(socket.listenerCount('connect'), 0) - const p = socketConnect(socket, 'foo') + socket.connect() + const p = promise(socket, 'connect', 'error') t.is(socket.listenerCount('error'), 1) t.is(socket.listenerCount('connect'), 1) return p.catch(err => { diff --git a/packages/resolve/index.js b/packages/resolve/index.js index 157649d9d..8c7799c66 100644 --- a/packages/resolve/index.js +++ b/packages/resolve/index.js @@ -1,7 +1,7 @@ 'use strict' const resolve = require('./resolve') -const {socketConnect} = require('@xmpp/connection') +const {promise} = require('@xmpp/events') async function fetchURIs(domain) { return [ @@ -44,7 +44,8 @@ async function fallbackConnect(entity, uris) { const socket = new Transport.prototype.Socket() try { - await socketConnect(socket, params) + socket.connect(params) + await promise(socket, 'connect') // eslint-disable-next-line no-unused-vars } catch (err) { return fallbackConnect(entity, uris) diff --git a/packages/resolve/package.json b/packages/resolve/package.json index dc43cb429..02b1f4a23 100644 --- a/packages/resolve/package.json +++ b/packages/resolve/package.json @@ -21,7 +21,7 @@ "node-fetch": false }, "dependencies": { - "@xmpp/connection": "^0.9.1", + "@xmpp/events": "^0.9.0", "@xmpp/xml": "^0.9.1", "node-fetch": "^2.3.0" }, diff --git a/packages/starttls/client.js b/packages/starttls/client.js index 114f6c052..73cd8c995 100644 --- a/packages/starttls/client.js +++ b/packages/starttls/client.js @@ -1,7 +1,7 @@ 'use strict' const xml = require('@xmpp/xml') -const tls = require('tls') +const {canUpgrade, upgrade} = require('./starttls') /* * References @@ -10,20 +10,7 @@ const tls = require('tls') const NS = 'urn:ietf:params:xml:ns:xmpp-tls' -function proceed(entity, options = {}) { - return new Promise((resolve, reject) => { - const tlsSocket = tls.connect( - {socket: entity._detachSocket(), host: entity.options.domain, ...options}, - err => { - if (err) return reject(err) - entity._attachSocket(tlsSocket) - resolve() - } - ) - }) -} - -async function starttls(entity) { +async function negotiate(entity) { const element = await entity.sendReceive(xml('starttls', {xmlns: NS})) if (element.is('proceed', NS)) { return element @@ -33,9 +20,16 @@ async function starttls(entity) { } module.exports = function({streamFeatures}) { - return streamFeatures.use('starttls', NS, async ({entity}) => { - await starttls(entity) - await proceed(entity) + return streamFeatures.use('starttls', NS, async ({entity}, next) => { + const {socket} = entity + if (!canUpgrade(socket)) { + return next() + } + + await negotiate(entity) + const tlsSocket = await upgrade(socket, {host: entity.options.domain}) + entity._attachSocket(tlsSocket) + await entity.restart() }) } diff --git a/packages/starttls/package.json b/packages/starttls/package.json index 2bb2fd881..9d28ea4c8 100644 --- a/packages/starttls/package.json +++ b/packages/starttls/package.json @@ -11,6 +11,7 @@ "STARTTLS" ], "dependencies": { + "@xmpp/events": "^0.9.0", "@xmpp/xml": "^0.9.1" }, "engines": { diff --git a/packages/starttls/starttls.js b/packages/starttls/starttls.js new file mode 100644 index 000000000..0d51e1220 --- /dev/null +++ b/packages/starttls/starttls.js @@ -0,0 +1,20 @@ +'use strict' + +const tls = require('tls') +const net = require('net') +const {promise} = require('@xmpp/events') + +function canUpgrade(socket) { + return socket instanceof net.Socket && !(socket instanceof tls.TLSSocket) +} + +module.exports.canUpgrade = canUpgrade + +async function upgrade(socket, options = {}) { + const tlsSocket = tls.connect({socket, ...options}) + await promise(tlsSocket, 'secureConnect') + + return tlsSocket +} + +module.exports.upgrade = upgrade diff --git a/packages/starttls/starttls.test.js b/packages/starttls/starttls.test.js new file mode 100644 index 000000000..ed21f09fd --- /dev/null +++ b/packages/starttls/starttls.test.js @@ -0,0 +1,13 @@ +'use strict' + +const test = require('ava') +const tls = require('tls') +const {canUpgrade} = require('./starttls') +const net = require('net') +const WebSocket = require('../websocket/lib/Socket') + +test('canUpgrade', t => { + t.is(canUpgrade(new WebSocket()), false) + t.is(canUpgrade(new tls.TLSSocket()), false) + t.is(canUpgrade(new net.Socket()), true) +}) diff --git a/packages/starttls/test.js b/packages/starttls/test.js index 22fafa710..6fe944b12 100644 --- a/packages/starttls/test.js +++ b/packages/starttls/test.js @@ -4,9 +4,18 @@ const {mock, stub} = require('sinon') const test = require('ava') const {mockClient, promise, delay} = require('@xmpp/test') const tls = require('tls') +const net = require('net') +const EventEmitter = require('events') + +function mockSocket() { + const socket = new net.Socket() + socket.write = (data, cb) => cb() + return socket +} test('success', async t => { const {entity} = mockClient() + entity.socket = mockSocket() const {socket} = entity const host = (entity.options.domain = 'foobar') @@ -15,9 +24,8 @@ test('success', async t => { .expects('connect') .once() .withArgs({socket, host}) - .callsFake((options, callback) => { - process.nextTick(callback) - return {} + .callsFake(() => { + return new EventEmitter() }) stub(entity, '_attachSocket') @@ -43,6 +51,7 @@ test('success', async t => { test('failure', async t => { const {entity} = mockClient() + entity.socket = mockSocket() entity.mockInput( diff --git a/test/see-other-host.js b/test/see-other-host.js index 4a4488cc0..ca6b8b759 100644 --- a/test/see-other-host.js +++ b/test/see-other-host.js @@ -23,7 +23,7 @@ test.afterEach(t => { } }) -test.serial.only('see-other-host', async t => { +test.serial('see-other-host', async t => { const net = require('net') const Connection = require('../packages/connection-tcp') const {promise} = require('../packages/events')