Skip to content

Commit 332cdd0

Browse files
committed
Support "ipaddress" module objects in extension
1 parent a609d64 commit 332cdd0

File tree

3 files changed

+76
-5
lines changed

3 files changed

+76
-5
lines changed

extension/maxminddb.c

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,64 @@ static int Reader_init(PyObject *self, PyObject *args, PyObject *kwds)
107107
static PyObject *Reader_get(PyObject *self, PyObject *args)
108108
{
109109
char *ip_address = NULL;
110+
PyObject *object = NULL;
111+
#if PY_MAJOR_VERSION > 2
112+
PyObject *bytes = NULL;
113+
#endif
110114

111115
Reader_obj *mmdb_obj = (Reader_obj *)self;
112-
if (!PyArg_ParseTuple(args, "s", &ip_address)) {
116+
117+
if (PyArg_ParseTuple(args, "s", &ip_address)) {
118+
// pass
119+
} else if (PyErr_Clear(), PyArg_ParseTuple(args, "O", &object)) {
120+
PyObject* module = PyImport_ImportModule("ipaddress");
121+
if (module == NULL) {
122+
return NULL;
123+
}
124+
125+
PyObject* module_dict = PyModule_GetDict(module);
126+
if (module_dict == NULL) {
127+
return NULL;
128+
}
129+
130+
PyObject* ipaddress_IPv4Address = PyDict_GetItemString(
131+
module_dict, "IPv4Address");
132+
int is_ipaddress_object = 0;
133+
134+
if (ipaddress_IPv4Address != NULL) {
135+
is_ipaddress_object = PyObject_IsInstance(ipaddress_IPv4Address,
136+
object);
137+
}
138+
139+
PyObject* ipaddress_IPv6Address;
140+
if (!is_ipaddress_object &&
141+
(ipaddress_IPv6Address = PyDict_GetItemString(module_dict,
142+
"IPv6Address")) != NULL) {
143+
is_ipaddress_object = PyObject_IsInstance(ipaddress_IPv6Address,
144+
object);
145+
}
146+
147+
PyErr_Clear();
148+
PyObject *str;
149+
150+
if (!is_ipaddress_object) {
151+
PyErr_SetString(PyExc_TypeError, "IP address must be a string, "
152+
" ipaddress.IPv4Address or ipaddress.IPv6Address object");
153+
} else if ((str = PyObject_Str(object)) != NULL) {
154+
#if PY_MAJOR_VERSION > 2
155+
bytes = PyUnicode_AsEncodedString(str, "UTF-8", "strict");
156+
ip_address = PyBytes_AS_STRING(bytes);
157+
#else
158+
ip_address = PyString_AsString(str);
159+
#endif
160+
}
161+
162+
Py_DECREF(module);
163+
} else {
164+
return NULL;
165+
}
166+
167+
if (ip_address == NULL) {
113168
return NULL;
114169
}
115170

@@ -127,6 +182,12 @@ static PyObject *Reader_get(PyObject *self, PyObject *args)
127182
MMDB_lookup_string(mmdb, ip_address, &gai_error,
128183
&mmdb_error);
129184

185+
#if PY_MAJOR_VERSION > 2
186+
if (bytes != NULL) {
187+
Py_DECREF(bytes);
188+
}
189+
#endif
190+
130191
if (0 != gai_error) {
131192
PyErr_Format(PyExc_ValueError,
132193
"'%s' does not appear to be an IPv4 or IPv6 address.",

tests/data

Submodule data updated 59 files

tests/reader_test.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import unicode_literals
55

6+
import ipaddress
67
import logging
78
import mock
89
import os
@@ -125,9 +126,7 @@ def test_no_extension_exception(self):
125126
def test_ip_object_lookup(self):
126127
reader = open_database('tests/data/test-data/GeoIP2-City-Test.mmdb',
127128
self.mode)
128-
with self.assertRaisesRegex(TypeError,
129-
"must be str(?:ing)?, not IPv6Address"):
130-
reader.get(compat_ip_address('2001:220::'))
129+
reader.get(compat_ip_address('2001:220::'))
131130
reader.close()
132131

133132
def test_broken_database(self):
@@ -362,6 +361,11 @@ def _check_ip_v4(self, reader, file_name):
362361
'found expected data record for ' + key_address + ' in ' +
363362
file_name)
364363

364+
self.assertEqual(
365+
data, reader.get(ipaddress.ip_address(key_address)),
366+
'found expected data record for ' + key_address + ' in ' +
367+
file_name)
368+
365369
for ip in ['1.1.1.33', '255.254.253.123']:
366370
self.assertIsNone(reader.get(ip))
367371

@@ -391,8 +395,14 @@ def _check_ip_v6(self, reader, file_name):
391395
'found expected data record for ' + key_address +
392396
' in ' + file_name)
393397

398+
self.assertEqual({'ip': value_address},
399+
reader.get(ipaddress.ip_address(key_address)),
400+
'found expected data record for ' + key_address +
401+
' in ' + file_name)
402+
394403
for ip in ['1.1.1.33', '255.254.253.123', '89fa::']:
395404
self.assertIsNone(reader.get(ip))
405+
self.assertIsNone(reader.get(ipaddress.ip_address(ip)))
396406

397407

398408
def has_maxminddb_extension():

0 commit comments

Comments
 (0)