Skip to content

Commit 909d1f5

Browse files
committed
Test compress argument types and clarify docs
1 parent d5fb7e6 commit 909d1f5

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

django_pandas/io.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def _get_dtypes(fields_to_dtypes, fields):
8989
dtypes = []
9090
f2d = _FIELDS_TO_DTYPES.copy()
9191
f2d.update(fields_to_dtypes)
92+
for k, v in f2d.items():
93+
if not issubclass(k, django.db.models.fields.Field):
94+
raise TypeError(f'Expected a type of field, not {k!r}')
9295
for field in fields:
9396
# Find the lowest subclass among the keys of f2d
9497
t, dtype = object, object
@@ -138,7 +141,7 @@ def read_frame(qs, fieldnames=(), index_col=None, coerce_float=False,
138141
defined in the ``__unicode__`` or ``__str__``
139142
methods of the related class definition
140143
141-
compress: boolean or a mapping, default False
144+
compress: a false value, ``True``, or a mapping, default False
142145
If a true value, infer NumPy data types [#]_ for Pandas dataframe
143146
columns from the corresponding Django field types. For example, Django's
144147
built in ``SmallIntgerField`` is cast to NumPy's ``int16``. If
@@ -202,6 +205,8 @@ def read_frame(qs, fieldnames=(), index_col=None, coerce_float=False,
202205
recs = qs.iterator()
203206

204207
if compress:
208+
if not isinstance(compress, (bool, Mapping)):
209+
raise TypeError(f'Ambiguous compress argument: {compress!r}')
205210
if not isinstance(compress, Mapping):
206211
compress = {}
207212
recs = np.array(list(recs), dtype=_get_dtypes(compress, fields))

django_pandas/tests/test_io.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ def test_compress_basic(self):
6060
# Compress should use less memory
6161
self.assertLess(df.memory_usage().sum(), read_frame(qs).memory_usage().sum())
6262

63+
def test_compress_bad_argument(self):
64+
qs = MyModel.objects.all()
65+
bads = [(models.ByteField, np.int8), range(3), type, object(), 'a', 1.,
66+
{'IntegerField': int}, {int: models.ByteField}]
67+
for bad in bads:
68+
self.assertRaises(TypeError, read_frame, qs, compress=bad)
69+
6370
def assert_default_compressable(self, df):
6471
for field in models.CompressableModel._meta.get_fields():
6572
if field.name == 'id':

0 commit comments

Comments
 (0)