Skip to content

Commit

Permalink
ENH: Enable unary math operations for pandas, sqlite
Browse files Browse the repository at this point in the history
Implement decimal for pandas
Add SQLite unary ops
Fix operations in postgres that require numeric
  • Loading branch information
cpcloud committed Jul 20, 2017
1 parent 19d6177 commit 0dc84c3
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 16 deletions.
26 changes: 19 additions & 7 deletions ibis/pandas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
}


def pandas_dtypes_to_ibis_schema(df):
def pandas_dtypes_to_ibis_schema(df, schema):
dtypes = df.dtypes

pairs = []
Expand All @@ -41,10 +41,20 @@ def pandas_dtypes_to_ibis_schema(df):
'Column names must be strings to use the pandas backend'
)

if dtype == np.object_:
ibis_type = _INFERRED_DTYPE_TO_IBIS_TYPE[
infer_dtype(df[column_name].dropna())
]
if column_name in schema:
ibis_type = dt.validate_type(schema[column_name])
elif dtype == np.object_:
inferred_dtype = infer_dtype(df[column_name].dropna())

if inferred_dtype == 'mixed':
raise TypeError(
'Unable to infer type of column {0!r}. Try instantiating '
'your table from the client with client.table('
"'my_table', schema={{{0!r}: <explicit type>}})".format(
column_name
)
)
ibis_type = _INFERRED_DTYPE_TO_IBIS_TYPE[inferred_dtype]
elif hasattr(dtype, 'tz'):
ibis_type = dt.Timestamp(str(dtype.tz))
else:
Expand All @@ -60,9 +70,11 @@ class PandasClient(client.Client):
def __init__(self, dictionary):
self.dictionary = dictionary

def table(self, name):
def table(self, name, schema=None):
df = self.dictionary[name]
schema = pandas_dtypes_to_ibis_schema(df)
schema = pandas_dtypes_to_ibis_schema(
df, schema if schema is not None else {}
)
return ops.DatabaseTable(name, schema, self).to_expr()

def execute(self, query, *args, **kwargs):
Expand Down
48 changes: 48 additions & 0 deletions ibis/pandas/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import operator
import datetime
import functools
import decimal
import math

import six

Expand Down Expand Up @@ -56,6 +58,11 @@ def execute_cast_series_generic(op, data, type, scope=None):
return data.astype(_IBIS_TYPE_TO_PANDAS_TYPE[type])


@execute_node.register(ops.Cast, pd.Series, dt.Decimal)
def execute_cast_series_to_decimal(op, data, type, scope=None):
    """Convert every element of ``data`` to a ``decimal.Decimal``.

    ``op``, ``type`` and ``scope`` are accepted for dispatch compatibility
    but are not consulted here; in particular the precision and scale of the
    target type are not applied to the converted values.
    """
    to_decimal = decimal.Decimal
    return data.apply(to_decimal)


@execute_node.register(ops.Cast, pd.Series, dt.Timestamp)
def execute_cast_series_timestamp(op, data, type, scope=None):
# TODO(phillipc): Consistent units
Expand All @@ -82,6 +89,42 @@ def execute_cast_series_date(op, data, _, scope=None):
}


@execute_node.register(ops.UnaryOp, pd.Series)
def execute_series_unary_op(op, data, scope=None):
    """Apply the unary operation ``op`` to a series.

    Object-dtype series are dispatched element by element back through
    ``execute_node`` so that the implementation registered for each value's
    Python type is used. Every other dtype is handled by the numpy function
    whose name is the lower-cased class name of ``op``.

    Raises
    ------
    AttributeError
        If the series is not object-dtype and numpy has no function named
        after the operation.
    """
    if data.dtype == np.dtype(np.object_):
        return data.apply(functools.partial(execute_node, op, scope=scope))

    # Look the numpy function up only on the vectorized path: the
    # element-wise path above never uses it, so a missing numpy name must not
    # prevent object columns from being processed.
    function = getattr(np, type(op).__name__.lower())
    return function(data)


@execute_node.register(ops.Ln, pd.Series)
def execute_series_natural_log(op, data, scope=None):
    """Compute the natural logarithm of a series.

    Non-object series go straight through ``np.log``. Object-dtype series
    are dispatched one element at a time through ``execute_node``.
    """
    is_object_series = data.dtype == np.dtype(np.object_)
    if not is_object_series:
        return np.log(data)

    elementwise = functools.partial(execute_node, op, scope=scope)
    return data.apply(elementwise)


@execute_node.register(ops.Ln, decimal.Decimal)
def execute_decimal_natural_log(op, data, scope=None):
    """Return the natural logarithm of a single ``decimal.Decimal``.

    The result is a float, as produced by ``math.log``. Inputs outside the
    domain of the logarithm do not raise: zero gives ``-inf`` and negative or
    NaN values give ``nan``, mirroring what ``np.log`` produces for float
    series. Raising here would abort the whole ``Series.apply`` over an
    object column because of one bad element.
    """
    # ``is_nan`` must be tested first: ordering comparisons against a
    # Decimal NaN raise ``decimal.InvalidOperation``.
    if data.is_nan() or data < 0:
        return float('nan')
    if not data:
        return float('-inf')
    return math.log(data)


@execute_node.register(ops.UnaryOp, decimal.Decimal)
def execute_decimal_unary(op, data, scope=None):
    """Apply a unary math operation to a single ``decimal.Decimal``.

    The function is taken from the ``math`` module using the lower-cased
    class name of ``op``; an ``AttributeError`` propagates if ``math`` has no
    function of that name.
    """
    function_name = type(op).__name__.lower()
    return getattr(math, function_name)(data)


@execute_node.register(ops.Sign, decimal.Decimal)
def execute_decimal_sign(op, data, scope=None):
    """Return the sign of a single ``decimal.Decimal`` as a float.

    Returns ``-1.0`` for negative values, ``1.0`` for positive values and
    ``0.0`` for zero.
    """
    # ``math.copysign(1, 0)`` is 1.0, so zero has to be handled explicitly or
    # sign(0) would be reported as positive.
    if not data:
        return 0.0
    return math.copysign(1, data)


@execute_node.register(ops.Abs, decimal.Decimal)
def execute_decimal_abs(op, data, scope=None):
    """Return the absolute value of a single ``decimal.Decimal``.

    ``op`` and ``scope`` are accepted for dispatch compatibility only.
    """
    magnitude = abs(data)
    return magnitude


@execute_node.register(ops.Cast, datetime.datetime, dt.String)
def execute_cast_datetime_or_timestamp_to_string(op, data, type, scope=None):
"""Cast timestamps to strings"""
Expand Down Expand Up @@ -160,6 +203,11 @@ def execute_cast_string_literal(op, data, type, scope=None):
return cast_function(data)


@execute_node.register(ops.Round, pd.Series, (pd.Series, int, type(None)))
def execute_round_series(op, data, places, scope=None):
    """Round a series to ``places`` decimal places.

    A ``places`` of None means rounding to zero decimal places.
    """
    if places is None:
        places = 0
    return data.round(places)


@execute_node.register(ops.TableColumn, (pd.DataFrame, DataFrameGroupBy))
def execute_table_column_dataframe_or_dataframe_groupby(op, data, scope=None):
return data[op.name]
Expand Down
26 changes: 25 additions & 1 deletion ibis/pandas/tests/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import operator
import datetime
import decimal

import pytest

Expand Down Expand Up @@ -37,6 +38,7 @@ def df(tz):
'int64_with_zeros': [0, 1, 0],
'float64_with_zeros': [1.0, 0.0, 1.0],
'strings_with_nulls': ['a', None, 'b'],
'decimal': list(map(decimal.Decimal, ['1.0', '2', '3.234'])),
})


Expand All @@ -63,7 +65,7 @@ def client(df, df1, df2):

@pytest.fixture
def t(client):
return client.table('df')
return client.table('df', schema={'decimal': dt.Decimal(4, 3)})


@pytest.fixture
Expand All @@ -86,6 +88,11 @@ def test_literal(client):
assert client.execute(ibis.literal(1)) == 1


def test_read_with_undiscoverable_type(client):
    """Loading ``'df'`` without an explicit schema must raise TypeError."""
    # Callable form of ``pytest.raises``: the table lookup itself is the
    # operation expected to fail.
    pytest.raises(TypeError, client.table, 'df')


@pytest.mark.parametrize('from_', ['plain_float64', 'plain_int64'])
@pytest.mark.parametrize(
('to', 'expected'),
Expand Down Expand Up @@ -711,3 +718,20 @@ def test_notnull(t, df):
result = expr.execute()
expected = df.strings_with_nulls.notnull()
tm.assert_series_equal(result, expected)


def test_cast_to_decimal(t, df):
    """Casting a string column to decimal yields ``decimal.Decimal`` values."""
    column = t.float64_as_strings
    result = column.cast('decimal(12, 3)').execute()
    expected = column.execute().apply(decimal.Decimal)
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize('places', [-2, 0, 1, 2, None])
def test_round(t, df, places):
    """Rounding an expression matches ``Series.round`` on the same data.

    A ``places`` of None is expected to behave like rounding to zero places.
    """
    result = t.float64_as_strings.cast('double').round(places).execute()

    digits = 0 if places is None else places
    as_float = t.execute().float64_as_strings.astype('float64')
    expected = as_float.round(digits)

    tm.assert_series_equal(result, expected)
31 changes: 31 additions & 0 deletions ibis/sql/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,34 @@ def _table_column(t, expr):
return out_expr


def _round(t, expr):
    """Translate an ibis ``Round`` expression into a SQLAlchemy expression.

    Without a digits argument this emits a plain ``round(arg)``. With digits,
    the argument is first cast to ``NUMERIC`` (presumably because the
    two-argument ``round`` is only available for numeric in PostgreSQL --
    confirm against the PostgreSQL docs). The rounded value is returned as is
    when the ibis argument type is ``Decimal`` and is otherwise cast back to
    double precision.
    """
    arg, digits = expr.op().args
    translated_arg = t.translate(arg)

    if digits is None:
        return sa.func.round(translated_arg)

    # ``digits`` is known to be non-None from here on.
    numeric_arg = sa.cast(translated_arg, sa.NUMERIC)
    rounded = sa.func.round(numeric_arg, t.translate(digits))

    if isinstance(arg.type(), dt.Decimal):
        return rounded
    return sa.cast(rounded, sa.dialects.postgresql.DOUBLE_PRECISION())


def _mod(t, expr):
    """Translate an ibis ``Modulus`` expression into ``left % right``.

    Unless both translated operands have an ``sa.Integer`` type, both are
    cast to ``NUMERIC`` before the modulus is taken.
    """
    left, right = (t.translate(operand) for operand in expr.op().args)

    both_integral = (
        isinstance(left.type, sa.Integer) and
        isinstance(right.type, sa.Integer)
    )
    if not both_integral:
        left = sa.cast(left, sa.NUMERIC)
        right = sa.cast(right, sa.NUMERIC)

    return left % right


def _floor_divide(t, expr):
    """Translate an ibis ``FloorDivide`` expression into ``floor(left / right)``."""
    numerator, denominator = (
        t.translate(operand) for operand in expr.op().args
    )
    quotient = numerator / denominator
    return sa.func.floor(quotient)


_operation_registry.update({
# We override this here to support time zones
ops.TableColumn: _table_column,
Expand Down Expand Up @@ -585,6 +613,7 @@ def _table_column(t, expr):

ops.Ceil: fixed_arity(sa.func.ceil, 1),
ops.Floor: fixed_arity(sa.func.floor, 1),
ops.FloorDivide: _floor_divide,
ops.Exp: fixed_arity(sa.func.exp, 1),
ops.Sign: fixed_arity(sa.func.sign, 1),
ops.Sqrt: fixed_arity(sa.func.sqrt, 1),
Expand All @@ -593,6 +622,8 @@ def _table_column(t, expr):
ops.Log2: fixed_arity(lambda x: sa.func.log(2, x), 1),
ops.Log10: fixed_arity(sa.func.log, 1),
ops.Power: fixed_arity(sa.func.power, 2),
ops.Round: _round,
ops.Modulus: _mod,

# dates and times
ops.Strftime: _strftime,
Expand Down
92 changes: 84 additions & 8 deletions ibis/sql/sqlite/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ class SQLiteDatabase(Database):
pass


# Callables collected by the ``udf`` / ``udaf`` decorators below.
_SQLITE_UDF_REGISTRY = set()
_SQLITE_UDAF_REGISTRY = set()


def udf(f):
    """Record ``f`` in the scalar-function registry and return it unchanged."""
    _SQLITE_UDF_REGISTRY.add(f)
    return f


def udaf(f):
    """Record ``f`` in the aggregate registry and return it unchanged."""
    _SQLITE_UDAF_REGISTRY.add(f)
    return f


@udf
def _ibis_sqlite_regex_search(string, regex):
"""Return whether `regex` exists in `string`.
Expand All @@ -51,6 +66,7 @@ def _ibis_sqlite_regex_search(string, regex):
return re.search(regex, string) is not None


@udf
def _ibis_sqlite_regex_replace(string, pattern, replacement):
"""Replace occurences of `pattern` in `string` with `replacement`.
Expand All @@ -69,6 +85,7 @@ def _ibis_sqlite_regex_replace(string, pattern, replacement):
return re.sub(pattern, replacement, string)


@udf
def _ibis_sqlite_regex_extract(string, pattern, index):
"""Extract match of regular expression `pattern` from `string` at `index`.
Expand All @@ -92,6 +109,68 @@ def _ibis_sqlite_regex_extract(string, pattern, index):
return None


@udf
def _ibis_sqlite_exp(arg):
    """Return ``e`` raised to the power `arg`.

    Parameters
    ----------
    arg : number
        Exponent to which ``e`` is raised.

    Returns
    -------
    result : Optional[number]
        None if the input is None
    """
    if arg is None:
        return None
    return math.exp(arg)


@udf
def _ibis_sqlite_log(arg, base):
    """Return the logarithm of `arg` in the given `base`.

    Parameters
    ----------
    arg : number
    base : number

    Returns
    -------
    result : Optional[number]
        None if either input is None or the logarithm is undefined for the
        inputs (non-positive `arg`, non-positive `base`, or a `base` of 1).
    """
    if arg is None or base is None:
        return None
    # math.log raises ValueError for 0 and ZeroDivisionError for base 1.
    # Return NULL for those inputs, as is already done for negative inputs,
    # instead of letting the exception escape the callback.
    if arg <= 0 or base <= 0 or base == 1:
        return None
    return math.log(arg, base)


@udf
def _ibis_sqlite_ln(arg):
    """Return the natural logarithm of `arg`.

    Returns
    -------
    result : Optional[number]
        None if the input is None or not strictly positive.
    """
    # Zero is excluded as well as negatives: math.log(0) raises ValueError.
    if arg is None or arg <= 0:
        return None
    return math.log(arg)


@udf
def _ibis_sqlite_log2(arg):
    """Return the base-2 logarithm of `arg` via ``_ibis_sqlite_log``."""
    base = 2
    return _ibis_sqlite_log(arg, base)


@udf
def _ibis_sqlite_log10(arg):
    """Return the base-10 logarithm of `arg` via ``_ibis_sqlite_log``."""
    base = 10
    return _ibis_sqlite_log(arg, base)


@udf
def _ibis_sqlite_floor(arg):
    """Return the floor of `arg`, or None if `arg` is None."""
    if arg is None:
        return None
    return math.floor(arg)


@udf
def _ibis_sqlite_ceil(arg):
    """Return the ceiling of `arg`, or None if `arg` is None."""
    if arg is None:
        return None
    return math.ceil(arg)


@udf
def _ibis_sqlite_sign(arg):
    """Return the sign of `arg`.

    Returns
    -------
    result : Optional[float]
        ``-1.0`` for negative input, ``1.0`` for positive input, ``0.0`` for
        zero, and None if the input is None.
    """
    if arg is None:
        return None
    # math.copysign(1, 0) is 1.0, which would report zero as positive.
    if not arg:
        return 0.0
    return math.copysign(1, arg)


@udf
def _ibis_sqlite_floordiv(left, right):
    """Return `left` floor-divided by `right`.

    Returns
    -------
    result : Optional[number]
        None if either input is None or if `right` is zero.
    """
    # Propagate NULL like the other scalar functions in this module instead
    # of failing with a TypeError on ``None // x``.
    if left is None or right is None:
        return None
    # SQLite's own division yields NULL on a zero divisor; do the same rather
    # than raising ZeroDivisionError inside the callback.
    if right == 0:
        return None
    return left // right


@udf
def _ibis_sqlite_power(arg, power):
"""Raise `arg` to the `power` power.
Expand All @@ -113,6 +192,7 @@ def _ibis_sqlite_power(arg, power):
return arg ** power


@udf
def _ibis_sqlite_sqrt(arg):
"""Square root of `arg`.
Expand Down Expand Up @@ -152,12 +232,14 @@ def finalize(self):
return self.sum_of_squares_of_differences / (self.count - self.offset)


@udaf
class _ibis_sqlite_var_pop(_ibis_sqlite_var):

def __init__(self):
super(_ibis_sqlite_var_pop, self).__init__(0)


@udaf
class _ibis_sqlite_var_samp(_ibis_sqlite_var):

def __init__(self):
Expand Down Expand Up @@ -228,16 +310,10 @@ def __init__(self, path=None, create=False):
if path is not None:
self.attach(self.database_name, path, create=create)

for func in (
_ibis_sqlite_regex_search,
_ibis_sqlite_regex_replace,
_ibis_sqlite_regex_extract,
_ibis_sqlite_power,
_ibis_sqlite_sqrt,
):
for func in _SQLITE_UDF_REGISTRY:
self.con.run_callable(functools.partial(_register_function, func))

for agg in (_ibis_sqlite_var_pop, _ibis_sqlite_var_samp):
for agg in _SQLITE_UDAF_REGISTRY:
self.con.run_callable(functools.partial(_register_aggregate, agg))

@property
Expand Down
Loading

0 comments on commit 0dc84c3

Please sign in to comment.