Skip to content

Commit 3bb6c86

Browse files
committed
ruby: Add TLS support to the benchmark script
1 parent d0531c4 commit 3bb6c86

File tree

4 files changed

+79
-8
lines changed

4 files changed

+79
-8
lines changed

lib/rb/benchmark/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.pem

lib/rb/benchmark/benchmark.rb

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@ def initialize(opts)
4040
@interpreter = opts.fetch(:interpreter, "ruby")
4141
@host = opts.fetch(:host, ::HOST)
4242
@port = opts.fetch(:port, ::PORT)
43+
@tls = opts.fetch(:tls, false)
4344
end
4445

4546
def start
4647
return if @serverclass == Object
4748
args = (File.basename(@interpreter) == "jruby" ? "-J-server" : "")
48-
@pipe = IO.popen("#{@interpreter} #{args} #{File.dirname(__FILE__)}/server.rb #{@host} #{@port} #{@serverclass.name}", "r+")
49+
@pipe = IO.popen("#{@interpreter} #{args} #{File.dirname(__FILE__)}/server.rb #{"-tls" if @tls} #{@host} #{@port} #{@serverclass.name}", "r+")
4950
Marshal.load(@pipe) # wait until the server has started
5051
sleep 0.4 # give the server time to actually start spawning sockets
5152
end
@@ -75,6 +76,7 @@ def initialize(opts, server)
7576
@interpreter = opts.fetch(:interpreter, "ruby")
7677
@server = server
7778
@log_exceptions = opts.fetch(:log_exceptions, false)
79+
@tls = opts.fetch(:tls, false)
7880
end
7981

8082
def run
@@ -93,7 +95,7 @@ def run
9395
end
9496

9597
def spawn
96-
pipe = IO.popen("#{@interpreter} #{File.dirname(__FILE__)}/client.rb #{"-log-exceptions" if @log_exceptions} #{@host} #{@port} #{@clients_per_process} #{@calls_per_client}")
98+
pipe = IO.popen("#{@interpreter} #{File.dirname(__FILE__)}/client.rb #{"-log-exceptions" if @log_exceptions} #{"-tls" if @tls} #{@host} #{@port} #{@clients_per_process} #{@calls_per_client}")
9799
@pool << pipe
98100
end
99101

@@ -249,18 +251,53 @@ def resolve_const(const)
249251
const and const.split('::').inject(Object) { |k,c| k.const_get(c) }
250252
end
251253

254+
def generate_certificate
255+
key = OpenSSL::PKey::EC.generate("prime256v1")
256+
257+
cert = OpenSSL::X509::Certificate.new
258+
cert.version = 2
259+
cert.serial = 1
260+
cert.subject = OpenSSL::X509::Name.parse("/C=US/O=Benchmark/CN=localhost")
261+
cert.issuer = cert.subject
262+
cert.public_key = key
263+
cert.not_before = Time.now
264+
cert.not_after = Time.now + 3600
265+
266+
# Add extensions
267+
ef = OpenSSL::X509::ExtensionFactory.new
268+
ef.subject_certificate = cert
269+
ef.issuer_certificate = cert
270+
cert.add_extension(ef.create_extension("basicConstraints", "CA:TRUE", true))
271+
cert.add_extension(ef.create_extension("subjectAltName", "DNS:localhost,IP:127.0.0.1", false))
272+
273+
cert.sign(key, OpenSSL::Digest.new("SHA256"))
274+
275+
[cert, key]
276+
end
277+
278+
if ENV['THRIFT_TLS']
279+
puts "Generating TLS certificate and key..."
280+
require 'openssl'
281+
282+
cert, key = generate_certificate
283+
File.write(File.expand_path("cert.pem", __dir__), cert.to_pem)
284+
File.write(File.expand_path("key.pem", __dir__), key.to_pem)
285+
end
286+
252287
puts "Starting server..."
253288
args = {}
254289
args[:interpreter] = ENV['THRIFT_SERVER_INTERPRETER'] || ENV['THRIFT_INTERPRETER'] || "ruby"
255290
args[:class] = resolve_const(ENV['THRIFT_SERVER']) || Thrift::NonblockingServer
256291
args[:host] = ENV['THRIFT_HOST'] || HOST
257292
args[:port] = (ENV['THRIFT_PORT'] || PORT).to_i
293+
args[:tls] = ENV['THRIFT_TLS'] == 'true'
258294
server = Server.new(args)
259295
server.start
260296

261297
args = {}
262298
args[:host] = ENV['THRIFT_HOST'] || HOST
263299
args[:port] = (ENV['THRIFT_PORT'] || PORT).to_i
300+
args[:tls] = ENV['THRIFT_TLS'] == 'true'
264301
args[:num_processes] = (ENV['THRIFT_NUM_PROCESSES'] || 40).to_i
265302
args[:clients_per_process] = (ENV['THRIFT_NUM_CLIENTS'] || 5).to_i
266303
args[:calls_per_client] = (ENV['THRIFT_NUM_CALLS'] || 50).to_i

lib/rb/benchmark/client.rb

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,33 @@
2323
require 'benchmark_service'
2424

2525
class Client
26-
def initialize(host, port, clients_per_process, calls_per_client, log_exceptions)
26+
def initialize(host, port, clients_per_process, calls_per_client, log_exceptions, tls)
2727
@host = host
2828
@port = port
2929
@clients_per_process = clients_per_process
3030
@calls_per_client = calls_per_client
3131
@log_exceptions = log_exceptions
32+
@tls = tls
3233
end
3334

3435
def run
3536
@clients_per_process.times do
36-
socket = Thrift::Socket.new(@host, @port)
37+
socket = if @tls
38+
ssl_context = OpenSSL::SSL::SSLContext.new.tap do |ctx|
39+
ctx.verify_mode = OpenSSL::SSL::VERIFY_NONE
40+
ctx.min_version = OpenSSL::SSL::TLS1_2_VERSION
41+
42+
# Load certificate chain and private key
43+
certs = OpenSSL::X509::Certificate.load_file(File.expand_path("cert.pem", __dir__))
44+
pkey = OpenSSL::PKey.read(File.binread(File.expand_path("key.pem", __dir__)))
45+
ctx.add_certificate(certs.first, pkey, *certs[1..])
46+
47+
ctx
48+
end
49+
Thrift::SSLSocket.new(@host, @port, nil, ssl_context)
50+
else
51+
Thrift::Socket.new(@host, @port)
52+
end
3753
transport = Thrift::FramedTransport.new(socket)
3854
protocol = Thrift::BinaryProtocol.new(transport)
3955
client = ThriftBenchmark::BenchmarkService::Client.new(protocol)
@@ -68,7 +84,8 @@ def print_exception(e)
6884
end
6985

7086
log_exceptions = true if ARGV[0] == '-log-exceptions' and ARGV.shift
87+
tls = true if ARGV[0] == '-tls' and ARGV.shift
7188

7289
host, port, clients_per_process, calls_per_client = ARGV
7390

74-
Client.new(host, port.to_i, clients_per_process.to_i, calls_per_client.to_i, log_exceptions).run
91+
Client.new(host, port.to_i, clients_per_process.to_i, calls_per_client.to_i, log_exceptions, tls).run

lib/rb/benchmark/server.rb

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,24 @@ def fibonacci(n)
3636
end
3737
end
3838

39-
def self.start_server(host, port, serverClass)
39+
def self.start_server(host, port, serverClass, tls)
4040
handler = BenchmarkHandler.new
4141
processor = ThriftBenchmark::BenchmarkService::Processor.new(handler)
42-
transport = ServerSocket.new(host, port)
42+
transport = if tls
43+
ssl_context = OpenSSL::SSL::SSLContext.new.tap do |ctx|
44+
ctx.verify_mode = OpenSSL::SSL::VERIFY_NONE
45+
ctx.min_version = OpenSSL::SSL::TLS1_2_VERSION
46+
47+
certs = OpenSSL::X509::Certificate.load_file(File.expand_path("cert.pem", __dir__))
48+
pkey = OpenSSL::PKey.read(File.binread(File.expand_path("key.pem", __dir__)))
49+
ctx.add_certificate(certs.first, pkey, *certs[1..])
50+
51+
ctx
52+
end
53+
Thrift::SSLServerSocket.new(host, port, ssl_context)
54+
else
55+
ServerSocket.new(host, port)
56+
end
4357
transport_factory = FramedTransportFactory.new
4458
args = [processor, transport, transport_factory, nil, 20]
4559
if serverClass == NonblockingServer
@@ -68,9 +82,11 @@ def resolve_const(const)
6882
const and const.split('::').inject(Object) { |k,c| k.const_get(c) }
6983
end
7084

85+
tls = true if ARGV[0] == '-tls' and ARGV.shift
86+
7187
host, port, serverklass = ARGV
7288

73-
Server.start_server(host, port.to_i, resolve_const(serverklass))
89+
Server.start_server(host, port.to_i, resolve_const(serverklass), tls)
7490

7591
# let our host know that the interpreter has started
7692
# ideally we'd wait until the server was serving, but we don't have a hook for that

0 commit comments

Comments
 (0)