Skip to content

Commit

Permalink
test(datatypes): handle more cases and improve coverage of default in…
Browse files Browse the repository at this point in the history
…terval precision
  • Loading branch information
cpcloud authored and kszucs committed Dec 26, 2023
1 parent d22f97a commit 53c936f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
19 changes: 7 additions & 12 deletions ibis/backends/base/sqlglot/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import abc
from functools import partial

import sqlglot as sg
Expand Down Expand Up @@ -123,10 +122,8 @@


class SqlglotType(TypeMapper):
@property
@abc.abstractmethod
def dialect(self) -> str:
"""The dialect this parser is for."""
dialect: str | None = None
"""The dialect this parser is for."""

default_nullable = True
"""Default nullability when not specified."""
Expand Down Expand Up @@ -231,18 +228,18 @@ def _from_sqlglot_TIMESTAMPLTZ(cls, scale=None) -> dt.Timestamp:

@classmethod
def _from_sqlglot_INTERVAL(
cls, precision_or_span: sge.DataTypeParam | sge.IntervalSpan | None = None
cls, precision_or_span: sge.IntervalSpan | None = None
) -> dt.Interval:
nullable = cls.default_nullable
if precision_or_span is None:
precision_or_span = cls.default_interval_precision

if isinstance(precision_or_span, str):
return dt.Interval(precision_or_span, nullable=nullable)
elif isinstance(precision_or_span, sge.DataTypeParam):
return dt.Interval(str(precision_or_span), nullable=nullable)
elif isinstance(precision_or_span, sge.IntervalSpan):
return dt.Interval(unit=precision_or_span.this.this, nullable=nullable)
elif precision_or_span is None:
raise com.IbisTypeError("Interval precision is None")
else:
raise com.IbisTypeError(precision_or_span)

Expand Down Expand Up @@ -274,12 +271,10 @@ def _from_sqlglot_GEOGRAPHY(cls) -> sge.DataType:

@classmethod
def _from_ibis_Interval(cls, dtype: dt.Interval) -> sge.DataType:
if (unit := dtype.unit) is None:
return sge.DataType(this=typecode.INTERVAL)

assert dtype.unit is not None, "interval unit cannot be None"
return sge.DataType(
this=typecode.INTERVAL,
expressions=[sge.IntervalSpan(this=sge.Var(this=unit.name))],
expressions=[sge.IntervalSpan(this=sge.Var(this=dtype.unit.name))],
)

@classmethod
Expand Down
11 changes: 10 additions & 1 deletion ibis/backends/base/sqlglot/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import hypothesis as h
import hypothesis.strategies as st
import pytest
import sqlglot.expressions as sge

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.tests.strategies as its
from ibis.backends.base.sqlglot.datatypes import SqlglotType
from ibis.backends.base.sqlglot.datatypes import DuckDBType, PostgresType, SqlglotType


def assert_dtype_roundtrip(ibis_type, sqlglot_expected=None):
Expand Down Expand Up @@ -65,3 +67,10 @@ def test_specific_geometry_types(ibis_type):
assert SqlglotType.to_ibis(sqlglot_result) == dt.GeoSpatial(
geotype="geometry", nullable=ibis_type.nullable
)


def test_interval_without_unit():
with pytest.raises(com.IbisTypeError, match="precision is None"):
SqlglotType.from_string("INTERVAL")
assert PostgresType.from_string("INTERVAL") == dt.Interval("s")
assert DuckDBType.from_string("INTERVAL") == dt.Interval("us")

0 comments on commit 53c936f

Please sign in to comment.