30
30
import com .github .shyiko .mysql .binlog .event .deserialization .GtidEventDataDeserializer ;
31
31
import com .github .shyiko .mysql .binlog .event .deserialization .QueryEventDataDeserializer ;
32
32
import com .github .shyiko .mysql .binlog .event .deserialization .RotateEventDataDeserializer ;
33
- import com .github .shyiko .mysql .binlog .io .BufferedSocketInputStream ;
34
33
import com .github .shyiko .mysql .binlog .io .ByteArrayInputStream ;
35
34
import com .github .shyiko .mysql .binlog .jmx .BinaryLogClientMXBean ;
36
35
import com .github .shyiko .mysql .binlog .network .AuthenticationException ;
36
+ import com .github .shyiko .mysql .binlog .network .ClientCapabilities ;
37
+ import com .github .shyiko .mysql .binlog .network .DefaultSSLSocketFactory ;
38
+ import com .github .shyiko .mysql .binlog .network .DefaultSocketFactory ;
39
+ import com .github .shyiko .mysql .binlog .network .SSLMode ;
40
+ import com .github .shyiko .mysql .binlog .network .SSLSocketFactory ;
37
41
import com .github .shyiko .mysql .binlog .network .ServerException ;
38
42
import com .github .shyiko .mysql .binlog .network .SocketFactory ;
43
+ import com .github .shyiko .mysql .binlog .network .TLSHostnameVerifier ;
39
44
import com .github .shyiko .mysql .binlog .network .protocol .ErrorPacket ;
40
45
import com .github .shyiko .mysql .binlog .network .protocol .GreetingPacket ;
41
46
import com .github .shyiko .mysql .binlog .network .protocol .Packet ;
47
52
import com .github .shyiko .mysql .binlog .network .protocol .command .DumpBinaryLogGtidCommand ;
48
53
import com .github .shyiko .mysql .binlog .network .protocol .command .PingCommand ;
49
54
import com .github .shyiko .mysql .binlog .network .protocol .command .QueryCommand ;
55
+ import com .github .shyiko .mysql .binlog .network .protocol .command .SSLRequestCommand ;
50
56
51
57
import java .io .EOFException ;
52
58
import java .io .IOException ;
53
- import java .io .InputStream ;
54
59
import java .net .InetSocketAddress ;
55
60
import java .net .Socket ;
56
61
import java .net .SocketException ;
78
83
*/
79
84
public class BinaryLogClient implements BinaryLogClientMXBean {
80
85
81
- private static final SocketFactory DEFAULT_SOCKET_FACTORY = new SocketFactory () {
82
-
83
- @ Override
84
- public Socket createSocket () throws SocketException {
85
- return new Socket () {
86
-
87
- private InputStream inputStream ;
88
-
89
- @ Override
90
- public synchronized InputStream getInputStream () throws IOException {
91
- return inputStream != null ? inputStream :
92
- (inputStream = new BufferedSocketInputStream (super .getInputStream ()));
93
- }
94
- };
95
- }
96
- };
86
+ private static final SocketFactory DEFAULT_SOCKET_FACTORY = new DefaultSocketFactory ();
87
+ private static final SSLSocketFactory DEFAULT_SSL_SOCKET_FACTORY = new DefaultSSLSocketFactory ();
97
88
98
89
private final Logger logger = Logger .getLogger (getClass ().getName ());
99
90
@@ -108,6 +99,7 @@ public synchronized InputStream getInputStream() throws IOException {
108
99
private volatile String binlogFilename ;
109
100
private volatile long binlogPosition = 4 ;
110
101
private volatile long connectionId ;
102
+ private SSLMode sslMode = SSLMode .DISABLED ;
111
103
112
104
private volatile GtidSet gtidSet ;
113
105
private final Object gtidSetAccessLock = new Object ();
@@ -118,6 +110,7 @@ public synchronized InputStream getInputStream() throws IOException {
118
110
private final List <LifecycleListener > lifecycleListeners = new LinkedList <LifecycleListener >();
119
111
120
112
private SocketFactory socketFactory ;
113
+ private SSLSocketFactory sslSocketFactory ;
121
114
122
115
private PacketChannel channel ;
123
116
private volatile boolean connected ;
@@ -187,6 +180,17 @@ public void setBlocking(boolean blocking) {
187
180
this .blocking = blocking ;
188
181
}
189
182
183
+ public SSLMode getSSLMode () {
184
+ return sslMode ;
185
+ }
186
+
187
+ public void setSSLMode (SSLMode sslMode ) {
188
+ if (sslMode == null ) {
189
+ throw new IllegalArgumentException ("SSL mode cannot be NULL" );
190
+ }
191
+ this .sslMode = sslMode ;
192
+ }
193
+
190
194
/**
191
195
* @return server id (65535 by default)
192
196
* @see #setServerId(long)
@@ -347,6 +351,13 @@ public void setSocketFactory(SocketFactory socketFactory) {
347
351
this .socketFactory = socketFactory ;
348
352
}
349
353
354
+ /**
355
+ * @param sslSocketFactory custom ssl socket factory
356
+ */
357
+ public void setSslSocketFactory (SSLSocketFactory sslSocketFactory ) {
358
+ this .sslSocketFactory = sslSocketFactory ;
359
+ }
360
+
350
361
/**
351
362
* @param threadFactory custom thread factory. If not provided, threads will be created using simple "new Thread()".
352
363
*/
@@ -367,7 +378,7 @@ public void connect() throws IOException {
367
378
try {
368
379
establishConnection ();
369
380
GreetingPacket greetingPacket = receiveGreeting ();
370
- authenticate (greetingPacket . getScramble (), greetingPacket . getServerCollation () );
381
+ authenticate (greetingPacket );
371
382
connectionId = greetingPacket .getThreadId ();
372
383
if (binlogFilename == null && gtidSet == null ) {
373
384
autoPosition ();
@@ -474,10 +485,29 @@ private void ensureEventDataDeserializer(EventType eventType,
474
485
}
475
486
}
476
487
477
- private void authenticate (String salt , int collation ) throws IOException {
478
- AuthenticateCommand authenticateCommand = new AuthenticateCommand (schema , username , password , salt );
488
+ private void authenticate (GreetingPacket greetingPacket ) throws IOException {
489
+ int collation = greetingPacket .getServerCollation ();
490
+ int packetNumber = 1 ;
491
+ if (sslMode != SSLMode .DISABLED ) {
492
+ boolean serverSupportsSSL = (greetingPacket .getServerCapabilities () & ClientCapabilities .SSL ) != 0 ;
493
+ if (!serverSupportsSSL && (sslMode == SSLMode .REQUIRED || sslMode == SSLMode .VERIFY_CA ||
494
+ sslMode == SSLMode .VERIFY_IDENTITY )) {
495
+ throw new IOException ("MySQL server does not support SSL" );
496
+ }
497
+ if (serverSupportsSSL ) {
498
+ SSLRequestCommand sslRequestCommand = new SSLRequestCommand ();
499
+ sslRequestCommand .setCollation (collation );
500
+ channel .write (sslRequestCommand , packetNumber ++);
501
+ SSLSocketFactory sslSocketFactory = this .sslSocketFactory != null ? this .sslSocketFactory :
502
+ DEFAULT_SSL_SOCKET_FACTORY ;
503
+ channel .upgradeToSSL (sslSocketFactory ,
504
+ sslMode == SSLMode .VERIFY_IDENTITY ? new TLSHostnameVerifier () : null );
505
+ }
506
+ }
507
+ AuthenticateCommand authenticateCommand = new AuthenticateCommand (schema , username , password ,
508
+ greetingPacket .getScramble ());
479
509
authenticateCommand .setCollation (collation );
480
- channel .write (authenticateCommand );
510
+ channel .write (authenticateCommand , packetNumber );
481
511
byte [] authenticationResult = channel .read ();
482
512
if (authenticationResult [0 ] == ErrorPacket .HEADER ) {
483
513
ErrorPacket errorPacket = new ErrorPacket (authenticationResult , 1 );
0 commit comments