Skip to content

Issue 233 - Add :local_infile option and refactor mysql_options code #252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 1, 2012
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,27 @@ results.each(:as => :array) do |row|
end
```

## Connection options

You may set the following connection options in Mysql2::Client.new(...):

``` ruby
Mysql2::Client.new(
:host,
:username,
:password,
:port,
:database,
:socket = '/path/to/mysql.sock',
:flags = REMEMBER_OPTIONS | LONG_PASSWORD | LONG_FLAG | TRANSACTIONS | PROTOCOL_41 | SECURE_CONNECTION | MULTI_STATEMENTS,
:encoding = 'utf8',
:read_timeout = seconds,
:connect_timeout = seconds,
:reconnect = true/false,
:local_infile = true/false,
)
```

You can also retrieve multiple result sets. For this to work you need to connect with
flags `Mysql2::Client::MULTI_STATEMENTS`. Using multiple result sets is normally used
when calling stored procedures that return more than one result set
Expand Down
141 changes: 94 additions & 47 deletions ext/mysql2/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,23 @@ static VALUE intern_encoding_from_charset;
static VALUE sym_id, sym_version, sym_async, sym_symbolize_keys, sym_as, sym_array, sym_stream;
static ID intern_merge, intern_error_number_eql, intern_sql_state_eql;

#define REQUIRE_OPEN_DB(wrapper) \
if(!wrapper->reconnect_enabled && wrapper->closed) { \
#define REQUIRE_INITIALIZED(wrapper) \
if (!wrapper->initialized) { \
rb_raise(cMysql2Error, "MySQL client is not initialized"); \
}

#define REQUIRE_CONNECTED(wrapper) \
REQUIRE_INITIALIZED(wrapper) \
if (!wrapper->connected && !wrapper->reconnect_enabled) { \
rb_raise(cMysql2Error, "closed MySQL connection"); \
}

#define REQUIRE_NOT_CONNECTED(wrapper) \
REQUIRE_INITIALIZED(wrapper) \
if (wrapper->connected) { \
rb_raise(cMysql2Error, "MySQL connection is already open"); \
}

#define MARK_CONN_INACTIVE(conn) \
wrapper->active_thread = Qnil;

Expand Down Expand Up @@ -138,9 +150,9 @@ static VALUE nogvl_close(void *ptr) {
int flags;
#endif
wrapper = ptr;
if (!wrapper->closed) {
wrapper->closed = 1;
if (wrapper->connected) {
wrapper->active_thread = Qnil;
wrapper->connected = 0;
/*
* we'll send a QUIT message to the server, but that message is more of a
* formality than a hard requirement since the socket is getting shutdown
Expand Down Expand Up @@ -178,7 +190,8 @@ static VALUE allocate(VALUE klass) {
wrapper->encoding = Qnil;
wrapper->active_thread = Qnil;
wrapper->reconnect_enabled = 0;
wrapper->closed = 1;
wrapper->connected = 0; // means that a database connection is open
wrapper->initialized = 0; // means that that the wrapper is initialized
wrapper->client = (MYSQL*)xmalloc(sizeof(MYSQL));
return obj;
}
Expand Down Expand Up @@ -232,6 +245,7 @@ static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE po
return rb_raise_mysql2_error(wrapper);
}

wrapper->connected = 1;
return self;
}

Expand All @@ -244,7 +258,7 @@ static VALUE rb_connect(VALUE self, VALUE user, VALUE pass, VALUE host, VALUE po
static VALUE rb_mysql_client_close(VALUE self) {
GET_CLIENT(self);

if (!wrapper->closed) {
if (wrapper->connected) {
rb_thread_blocking_region(nogvl_close, wrapper, RUBY_UBF_IO, 0);
}

Expand Down Expand Up @@ -332,7 +346,7 @@ static VALUE rb_mysql_client_async_result(VALUE self) {
if (NIL_P(wrapper->active_thread))
return Qnil;

REQUIRE_OPEN_DB(wrapper);
REQUIRE_CONNECTED(wrapper);
if (rb_thread_blocking_region(nogvl_read_query_result, wrapper->client, RUBY_UBF_IO, 0) == Qfalse) {
// an error occurred, mark this connection inactive
MARK_CONN_INACTIVE(self);
Expand Down Expand Up @@ -375,8 +389,8 @@ struct async_query_args {
static VALUE disconnect_and_raise(VALUE self, VALUE error) {
GET_CLIENT(self);

wrapper->closed = 1;
wrapper->active_thread = Qnil;
wrapper->connected = 0;

// manually close the socket for read/write
// this feels dirty, but is there another way?
Expand Down Expand Up @@ -473,7 +487,7 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) {
#endif
GET_CLIENT(self);

REQUIRE_OPEN_DB(wrapper);
REQUIRE_CONNECTED(wrapper);
args.mysql = wrapper->client;


Expand Down Expand Up @@ -550,7 +564,7 @@ static VALUE rb_mysql_client_real_escape(VALUE self, VALUE str) {
#endif
GET_CLIENT(self);

REQUIRE_OPEN_DB(wrapper);
REQUIRE_CONNECTED(wrapper);
Check_Type(str, T_STRING);
#ifdef HAVE_RUBY_ENCODING_H
default_internal_enc = rb_default_internal_encoding();
Expand Down Expand Up @@ -580,6 +594,59 @@ static VALUE rb_mysql_client_real_escape(VALUE self, VALUE str) {
}
}

static VALUE _mysql_client_options(VALUE self, int opt, VALUE value) {
int result;
void *retval = NULL;
unsigned int intval = 0;
my_bool boolval;

GET_CLIENT(self);

REQUIRE_NOT_CONNECTED(wrapper);

if (NIL_P(value))
return Qfalse;

switch(opt) {
case MYSQL_OPT_CONNECT_TIMEOUT:
intval = NUM2INT(value);
retval = &intval;
break;

case MYSQL_OPT_LOCAL_INFILE:
intval = (value == Qfalse ? 0 : 1);
retval = &intval;
break;

case MYSQL_OPT_RECONNECT:
boolval = (value == Qfalse ? 0 : 1);
retval = &boolval;
break;

default:
return Qfalse;
}

result = mysql_options(wrapper->client, opt, retval);

// Zero means success
if (result != 0) {
rb_warn("%s\n", mysql_error(wrapper->client));
} else {
// Special case for reconnect, this option is also stored in the wrapper struct
if (opt == MYSQL_OPT_RECONNECT)
wrapper->reconnect_enabled = boolval;
}

return (result == 0) ? Qtrue : Qfalse;
}

static VALUE rb_mysql_client_options(VALUE self, VALUE option, VALUE value) {
Check_Type(option, T_FIXNUM);
int opt = NUM2INT(option);
return _mysql_client_options(self, opt, value);
}

/* call-seq:
* client.info
*
Expand Down Expand Up @@ -624,7 +691,7 @@ static VALUE rb_mysql_client_server_info(VALUE self) {
#endif
GET_CLIENT(self);

REQUIRE_OPEN_DB(wrapper);
REQUIRE_CONNECTED(wrapper);
#ifdef HAVE_RUBY_ENCODING_H
default_internal_enc = rb_default_internal_encoding();
conn_enc = rb_to_encoding(wrapper->encoding);
Expand All @@ -651,7 +718,7 @@ static VALUE rb_mysql_client_server_info(VALUE self) {
static VALUE rb_mysql_client_socket(VALUE self) {
GET_CLIENT(self);
#ifndef _WIN32
REQUIRE_OPEN_DB(wrapper);
REQUIRE_CONNECTED(wrapper);
int fd_set_fd = wrapper->client->net.fd;
return INT2NUM(fd_set_fd);
#else
Expand All @@ -667,7 +734,7 @@ static VALUE rb_mysql_client_socket(VALUE self) {
*/
static VALUE rb_mysql_client_last_id(VALUE self) {
GET_CLIENT(self);
REQUIRE_OPEN_DB(wrapper);
REQUIRE_CONNECTED(wrapper);
return ULL2NUM(mysql_insert_id(wrapper->client));
}

Expand All @@ -681,7 +748,7 @@ static VALUE rb_mysql_client_affected_rows(VALUE self) {
my_ulonglong retVal;
GET_CLIENT(self);

REQUIRE_OPEN_DB(wrapper);
REQUIRE_CONNECTED(wrapper);
retVal = mysql_affected_rows(wrapper->client);
if (retVal == (my_ulonglong)-1) {
rb_raise_mysql2_error(wrapper);
Expand All @@ -698,7 +765,7 @@ static VALUE rb_mysql_client_thread_id(VALUE self) {
unsigned long retVal;
GET_CLIENT(self);

REQUIRE_OPEN_DB(wrapper);
REQUIRE_CONNECTED(wrapper);
retVal = mysql_thread_id(wrapper->client);
return ULL2NUM(retVal);
}
Expand All @@ -723,7 +790,7 @@ static VALUE rb_mysql_client_select_db(VALUE self, VALUE db)
struct nogvl_select_db_args args;

GET_CLIENT(self);
REQUIRE_OPEN_DB(wrapper);
REQUIRE_CONNECTED(wrapper);

args.mysql = wrapper->client;
args.db = StringValuePtr(db);
Expand Down Expand Up @@ -751,7 +818,7 @@ static VALUE nogvl_ping(void *ptr) {
static VALUE rb_mysql_client_ping(VALUE self) {
GET_CLIENT(self);

if (wrapper->closed) {
if (!wrapper->connected) {
return Qfalse;
} else {
return rb_thread_blocking_region(nogvl_ping, wrapper->client, RUBY_UBF_IO, 0);
Expand Down Expand Up @@ -829,37 +896,15 @@ static VALUE rb_mysql_client_encoding(VALUE self) {
#endif

static VALUE set_reconnect(VALUE self, VALUE value) {
my_bool reconnect;
GET_CLIENT(self);

if(!NIL_P(value)) {
reconnect = value == Qfalse ? 0 : 1;
return _mysql_client_options(self, MYSQL_OPT_RECONNECT, value);
}

wrapper->reconnect_enabled = reconnect;
/* set default reconnect behavior */
if (mysql_options(wrapper->client, MYSQL_OPT_RECONNECT, &reconnect)) {
/* TODO: warning - unable to set reconnect behavior */
rb_warn("%s\n", mysql_error(wrapper->client));
}
}
return value;
static VALUE set_local_infile(VALUE self, VALUE value) {
return _mysql_client_options(self, MYSQL_OPT_LOCAL_INFILE, value);
}

static VALUE set_connect_timeout(VALUE self, VALUE value) {
unsigned int connect_timeout = 0;
GET_CLIENT(self);

if(!NIL_P(value)) {
connect_timeout = NUM2INT(value);
if(0 == connect_timeout) return value;

/* set default connection timeout behavior */
if (mysql_options(wrapper->client, MYSQL_OPT_CONNECT_TIMEOUT, &connect_timeout)) {
/* TODO: warning - unable to set connection timeout */
rb_warn("%s\n", mysql_error(wrapper->client));
}
}
return value;
return _mysql_client_options(self, MYSQL_OPT_CONNECT_TIMEOUT, value);
}

static VALUE set_charset_name(VALUE self, VALUE value) {
Expand Down Expand Up @@ -906,15 +951,15 @@ static VALUE set_ssl_options(VALUE self, VALUE key, VALUE cert, VALUE ca, VALUE
return self;
}

static VALUE init_connection(VALUE self) {
static VALUE initialize_ext(VALUE self) {
GET_CLIENT(self);

if (rb_thread_blocking_region(nogvl_init, wrapper->client, RUBY_UBF_IO, 0) == Qfalse) {
/* TODO: warning - not enough memory? */
return rb_raise_mysql2_error(wrapper);
}

wrapper->closed = 0;
wrapper->initialized = 1;
return self;
}

Expand Down Expand Up @@ -960,15 +1005,17 @@ void init_mysql2_client() {
rb_define_method(cMysql2Client, "more_results", rb_mysql_client_more_results, 0);
rb_define_method(cMysql2Client, "next_result", rb_mysql_client_next_result, 0);
rb_define_method(cMysql2Client, "store_result", rb_mysql_client_store_result, 0);
rb_define_method(cMysql2Client, "options", rb_mysql_client_options, 2);
#ifdef HAVE_RUBY_ENCODING_H
rb_define_method(cMysql2Client, "encoding", rb_mysql_client_encoding, 0);
#endif

rb_define_private_method(cMysql2Client, "reconnect=", set_reconnect, 1);
rb_define_private_method(cMysql2Client, "connect_timeout=", set_connect_timeout, 1);
rb_define_private_method(cMysql2Client, "local_infile=", set_local_infile, 1);
rb_define_private_method(cMysql2Client, "charset_name=", set_charset_name, 1);
rb_define_private_method(cMysql2Client, "ssl_set", set_ssl_options, 5);
rb_define_private_method(cMysql2Client, "init_connection", init_connection, 0);
rb_define_private_method(cMysql2Client, "initialize_ext", initialize_ext, 0);
rb_define_private_method(cMysql2Client, "connect", rb_connect, 7);

intern_encoding_from_charset = rb_intern("encoding_from_charset");
Expand Down
4 changes: 3 additions & 1 deletion ext/mysql2/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ typedef struct {
VALUE encoding;
VALUE active_thread; /* rb_thread_current() or Qnil */
int reconnect_enabled;
int closed;
int active;
int connected;
int initialized;
MYSQL *client;
} mysql_client_wrapper;

Expand Down
6 changes: 4 additions & 2 deletions lib/mysql2/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ def initialize(opts = {})
@query_options = @@default_query_options.dup
@query_options.merge! opts

init_connection
initialize_ext

[:reconnect, :connect_timeout].each do |key|
# Set MySQL connection options (each one is a call to mysql_options())
[:reconnect, :connect_timeout, :local_infile].each do |key|
next unless opts.key?(key)
send(:"#{key}=", opts[key])
end

# force the encoding to utf8
self.charset_name = opts[:encoding] || 'utf8'

Expand Down
10 changes: 9 additions & 1 deletion spec/mysql2/client_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,16 @@ def connect *args
result = @client.async_result
result.class.should eql(Mysql2::Result)
end

it "should not allow options to be set on an open connection" do
lambda {
@client.escape ""
@client.query("SELECT 1")
@client.options(0, 0)
}.should raise_error(Mysql2::Error)
end
end

context "Multiple results sets" do
before(:each) do
@multi_client = Mysql2::Client.new( :flags => Mysql2::Client::MULTI_STATEMENTS)
Expand Down