1 change: 1 addition & 0 deletions .github/workflows/ci_code.yml
@@ -103,6 +103,7 @@ jobs:

- name: Unit Testing
run: |
sqlite3 test.db "create table t(f int); drop table t;"
make unit_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/${{ matrix.config }}

- name: Usecase Testing
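The added `sqlite3 test.db "create table t(f int); drop table t;"` step only exists to make sure a valid `test.db` file is already on disk before the suite connects to it through the file-backed `sqlite://./test.db` URI introduced in `plugins/sql/plugin_test/config.yaml` below. A minimal sketch of the same idea using only the Python standard library (the file name mirrors the workflow; everything else is illustrative):

```python
import sqlite3

# Pre-create the SQLite database file so that a later file-backed
# connection finds an existing, valid database instead of an absent file.
conn = sqlite3.connect("test.db")
conn.execute("CREATE TABLE t (f INT)")  # running DDL forces a real SQLite header to be written
conn.execute("DROP TABLE t")            # the throwaway table itself is not needed afterwards
conn.commit()
conn.close()
```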
6 changes: 6 additions & 0 deletions .github/workflows/ci_plugins.yaml
@@ -95,6 +95,9 @@ jobs:
- name: Plugin Testing
run: |
export PYTHONPATH=./
if [ "${{ matrix.plugin }}" = "sql" ]; then
sqlite3 test.db "create table t(f int); drop table t;"
fi
if [ -d "plugins/${{ matrix.plugin }}/plugin_test" ]; then
pytest --cov=superduper --cov-report=xml plugins/${{ matrix.plugin }}/plugin_test
else
@@ -103,6 +106,9 @@

- name: Optionally run the base testing
run: |
if [ "${{ matrix.plugin }}" = "sql" ]; then
sqlite3 test.db "create table t(f int); drop table t;"
fi
SUPERDUPER_CONFIG="plugins/${{ matrix.plugin }}/plugin_test/config.yaml"
if [ -f "$SUPERDUPER_CONFIG" ]; then
echo "Running the base testing..."
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add assertion to verify directory copy in FileSystemArtifactStore
- Batch the Qdrant requests and add a retry to the config of Qdrant
- Add use_component_cache to config
- Save data in the `Component` table instead of in individual tables

### Bug fixes

42 changes: 24 additions & 18 deletions plugins/mongodb/superduper_mongodb/data_backend.py
@@ -6,6 +6,7 @@
from bson.objectid import ObjectId
from superduper import CFG, logging
from superduper.backends.base.data_backend import BaseDataBackend
from superduper.base.base import REGISTRY
from superduper.base.query import Query
from superduper.base.schema import Schema

@@ -171,24 +172,29 @@ def missing_outputs(self, query, predict_id: str):
def select(self, query: Query):
"""Select data from the table."""
if query.decomposition.outputs:
return self._outputs(query)

collection = self._database[query.table]

logging.debug(str(query))

limit = self._get_limit(query)
if limit:
native_query = collection.find(
self._mongo_filter(query), self._get_project(query)
).limit(limit)
if skip := self._get_skip(query):
native_query = native_query.skip(skip)
return list(native_query)

return list(
collection.find(self._mongo_filter(query), self._get_project(query))
)
output = self._outputs(query)
else:
collection = self._database[query.table]

logging.debug(str(query))

limit = self._get_limit(query)
if limit:
native_query = collection.find(
self._mongo_filter(query), self._get_project(query)
).limit(limit)
if skip := self._get_skip(query):
native_query = native_query.skip(skip)
output = list(native_query)
else:
output = list(
collection.find(self._mongo_filter(query), self._get_project(query))
)
if query.table in REGISTRY and REGISTRY[query.table].primary_id != '_id':
for o in output:
if '_id' in o:
del o['_id']
return output

def to_id(self, id):
"""Convert the ID to the correct format."""
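The reshaped `select` funnels every branch into a single `output` list and then, using the newly imported `REGISTRY`, drops MongoDB's implicit `_id` field whenever the table is registered with a different primary key. A rough standalone sketch of that post-processing step (the mapping and table name below are stand-ins for illustration, not superduper's actual `REGISTRY` objects):

```python
from typing import Any

# Hypothetical stand-in for REGISTRY: table name -> declared primary id.
PRIMARY_IDS: dict[str, str] = {"documentz": "id"}

def strip_mongo_id(table: str, rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Drop MongoDB's '_id' from query results when the table's
    declared primary key is something other than '_id'."""
    if PRIMARY_IDS.get(table, "_id") != "_id":
        for row in rows:
            row.pop("_id", None)
    return rows

rows = strip_mongo_id("documentz", [{"_id": "65f0…", "id": "1", "x": 0.3}])
assert "_id" not in rows[0]
```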
1 change: 0 additions & 1 deletion plugins/qdrant/superduper_qdrant/qdrant.py
@@ -160,7 +160,6 @@ def _do_scroll(offset):

return ids


def _create_collection(self):
measure = (
self.measure.name
2 changes: 1 addition & 1 deletion plugins/sql/plugin_test/config.yaml
@@ -1,4 +1,4 @@
data_backend: sqlite://
data_backend: sqlite://./test.db
auto_schema: false
force_apply: true
json_native: false
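Switching the test config from the bare in-memory `sqlite://` URI to the file-backed `sqlite://./test.db` means the database outlives any single connection, which matters once more than one connection is opened against the same backend during a test run (and is why the CI steps above pre-create `test.db`). For reference, this is how the distinction plays out with plain SQLAlchemy URLs (a sketch; superduper and ibis parse their own URI flavour, so the exact slash count differs from the config value):

```python
from sqlalchemy import create_engine, text

# "sqlite://" is an in-memory database: every new connection starts empty.
mem = create_engine("sqlite://")
with mem.connect() as c:
    c.execute(text("CREATE TABLE t (f INT)"))  # gone once this connection is released

# "sqlite:///test.db" is a file in the working directory: state written
# through one engine is visible to any later engine pointed at the same file.
disk = create_engine("sqlite:///test.db")
with disk.begin() as c:
    c.execute(text("CREATE TABLE IF NOT EXISTS t (f INT)"))
    c.execute(text("INSERT INTO t VALUES (1)"))

with create_engine("sqlite:///test.db").connect() as c:
    assert c.execute(text("SELECT count(*) FROM t")).scalar() >= 1
```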
26 changes: 14 additions & 12 deletions plugins/sql/plugin_test/test_query.py
@@ -1,3 +1,4 @@
import typing as t
from test.utils.setup.fake_data import add_listeners, add_models, add_random_data

import numpy as np
@@ -44,7 +45,7 @@ def test_renamings(db):
add_random_data(db, n=5)
add_models(db)
add_listeners(db)
t = db["documents"]
t = db["documentz"]
listener_uuid = [db.load('Listener', k).outputs for k in db.show("Listener")][0]
q = t.select("id", "x", "y").outputs(listener_uuid.split('__', 1)[-1])
data = q.execute()
@@ -62,13 +63,13 @@ def test_serialize_query(db):

def test_get_data(db):
add_random_data(db, n=5)
db["documents"].limit(2)
db.metadata.get_component("Table", "documents")
db["documentz"].limit(2)
db.metadata.get_component("Table", "documentz")


def test_insert_select(db):
add_random_data(db, n=5)
q = db["documents"].select("id", "x", "y").limit(2)
q = db["documentz"].select("id", "x", "y").limit(2)
r = q.execute()

assert len(r) == 2
@@ -77,7 +78,7 @@ def test_insert_select(db):

def test_filter(db):
add_random_data(db, n=5)
t = db["documents"]
t = db["documentz"]
q = t.select("id", "y")
r = q.execute()
ys = [x["y"] for x in r]
@@ -88,17 +89,18 @@ def test_filter(db):
assert len(r) == uq[1][0]


class documents(Base):
class documents_plugin(Base):
primary_id: t.ClassVar[str] = 'id'
this: 'str'


def test_select_using_ids(db):
db.create(documents)
db.create(documents_plugin)

table = db["documents"]
table = db["documents_plugin"]
table.insert([{"this": f"is a test {i}", "id": str(i)} for i in range(4)])

basic_select = db['documents'].select()
basic_select = db['documents_plugin'].select()

assert len(basic_select.execute()) == 4
assert len(basic_select.subset(['1', '2'])) == 2
@@ -112,16 +114,16 @@ def my_func(this: str):

my_func = ObjectModel('my_func', object=my_func)

db.create(documents)
db.create(documents_plugin)

table = db["documents"]
table = db["documents_plugin"]
table.insert([{"this": f"is a test {i}", "id": str(i)} for i in range(4)])

listener = Listener(
'test',
model=my_func,
key='this',
select=db['documents'].select(),
select=db['documents_plugin'].select(),
)
db.apply(listener)

13 changes: 9 additions & 4 deletions plugins/sql/superduper_sql/data_backend.py
@@ -359,7 +359,10 @@ class SQLDatabackend(IbisDataBackend):

def __init__(self, uri, plugin, flavour=None):
super().__init__(uri, plugin, flavour)
self._create_sqlalchemy_engine()
if 'sqlite://./' in uri:
self._create_sqlalchemy_engine(uri.replace('./', '//'))
else:
self._create_sqlalchemy_engine(uri)
self.sm = sessionmaker(bind=self.alchemy_engine)

@property
@@ -374,6 +377,8 @@ def update(self, table, condition, key, value):
with self.sm() as session:
metadata = MetaData()

assert table in self.list_tables()

metadata.reflect(bind=session.bind)
table = Table(table, metadata, autoload_with=session.bind)

@@ -422,16 +427,16 @@ def delete(self, table, condition):
except NoSuchTableError:
raise exceptions.NotFound("Table", table)

def _create_sqlalchemy_engine(self):
def _create_sqlalchemy_engine(self, uri):
with self.connection_manager.get_connection() as conn:
self.alchemy_engine = create_engine(self.uri, creator=lambda: conn.con)
self.alchemy_engine = create_engine(uri, creator=lambda: conn.con)
if not self._test_engine():
logging.warn(
"Unable to reuse the ibis connection "
"to create the SQLAlchemy engine. "
"Creating a new connection with the URI."
)
self.alchemy_engine = create_engine(self.uri)
self.alchemy_engine = create_engine(uri)

def _test_engine(self):
"""Test the engine."""
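`_create_sqlalchemy_engine` now receives the (possibly rewritten) URI as an argument instead of always reading `self.uri`: it first builds the engine on top of the already-open ibis DBAPI connection via `creator=`, and only falls back to a fresh engine created from the URI if that probe fails. A condensed sketch of that reuse-then-fallback pattern with a plain `sqlite3` connection standing in for the ibis handle (names and the probe query are illustrative, not superduper's):

```python
import sqlite3
from sqlalchemy import create_engine, text

existing = sqlite3.connect("test.db")  # stands in for the ibis connection's DBAPI handle

def make_engine(uri: str):
    # First choice: let SQLAlchemy reuse the existing DBAPI connection.
    engine = create_engine(uri, creator=lambda: existing)
    try:
        with engine.connect() as c:      # cheap probe, mirrors _test_engine()
            c.execute(text("SELECT 1"))
        return engine
    except Exception:
        # Reuse failed; fall back to a brand-new connection built from the URI.
        return create_engine(uri)

engine = make_engine("sqlite:///test.db")
```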