Skip to content

Commit

Permalink
feat: Adding DatastoreOnlineStore 'database' argument. (#4180)
Browse files Browse the repository at this point in the history
* feat: adding database argument to DatastoreOnlineStore

Signed-off-by: pawel <paweel.drabczyk@gmail.com>

* feat: adding database argument to DatastoreOnlineStore

Signed-off-by: pawel <paweel.drabczyk@gmail.com>

* feat: adding database argument to DatastoreOnlineStore

Signed-off-by: pawel <paweel.drabczyk@gmail.com>

* formatting and linting sdk/python/tests/unit/diff/test_infra_diff.py

Signed-off-by: pawel <paweel.drabczyk@gmail.com>

---------

Signed-off-by: pawel <paweel.drabczyk@gmail.com>
  • Loading branch information
Pawel-Drabczyk authored May 8, 2024
1 parent 34d3635 commit e739745
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 9 deletions.
3 changes: 3 additions & 0 deletions protos/feast/core/DatastoreTable.proto
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,7 @@ message DatastoreTable {

// Datastore namespace
google.protobuf.StringValue namespace = 4;

// Firestore database
google.protobuf.StringValue database = 5;
}
30 changes: 24 additions & 6 deletions sdk/python/feast/infra/online_stores/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class DatastoreOnlineStoreConfig(FeastConfigBaseModel):
namespace: Optional[StrictStr] = None
""" (optional) Datastore namespace """

database: Optional[StrictStr] = None
""" (optional) Firestore database """

write_concurrency: Optional[PositiveInt] = 40
""" (optional) Amount of threads to use when writing batches of feature rows into Datastore"""

Expand Down Expand Up @@ -155,7 +158,9 @@ def teardown(
def _get_client(self, online_config: DatastoreOnlineStoreConfig):
if not self._client:
self._client = _initialize_client(
online_config.project_id, online_config.namespace
online_config.project_id,
online_config.namespace,
online_config.database,
)
return self._client

Expand Down Expand Up @@ -344,11 +349,14 @@ def worker(shared_counter):


def _initialize_client(
project_id: Optional[str], namespace: Optional[str]
project_id: Optional[str], namespace: Optional[str], database: Optional[str]
) -> datastore.Client:
try:
client = datastore.Client(
project=project_id, namespace=namespace, client_info=get_http_client_info()
project=project_id,
namespace=namespace,
database=database,
client_info=get_http_client_info(),
)
return client
except DefaultCredentialsError as e:
Expand All @@ -368,23 +376,27 @@ class DatastoreTable(InfraObject):
name: The name of the table.
project_id (optional): The GCP project id.
namespace (optional): Datastore namespace.
database (optional): Firestore database.
"""

project: str
project_id: Optional[str]
namespace: Optional[str]
database: Optional[str]

def __init__(
self,
project: str,
name: str,
project_id: Optional[str] = None,
namespace: Optional[str] = None,
database: Optional[str] = None,
):
super().__init__(name)
self.project = project
self.project_id = project_id
self.namespace = namespace
self.database = database

def to_infra_object_proto(self) -> InfraObjectProto:
datastore_table_proto = self.to_proto()
Expand All @@ -401,6 +413,8 @@ def to_proto(self) -> Any:
datastore_table_proto.project_id.value = self.project_id
if self.namespace:
datastore_table_proto.namespace.value = self.namespace
if self.database:
datastore_table_proto.database.value = self.database
return datastore_table_proto

@staticmethod
Expand All @@ -410,7 +424,7 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any:
name=infra_object_proto.datastore_table.name,
)

# Distinguish between null and empty string, since project_id and namespace are StringValues.
# Distinguish between null and empty string, since project_id, namespace and database are StringValues.
if infra_object_proto.datastore_table.HasField("project_id"):
datastore_table.project_id = (
infra_object_proto.datastore_table.project_id.value
Expand All @@ -419,6 +433,8 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any:
datastore_table.namespace = (
infra_object_proto.datastore_table.namespace.value
)
if infra_object_proto.datastore_table.HasField("database"):
datastore_table.database = infra_object_proto.datastore_table.database.value

return datastore_table

Expand All @@ -434,11 +450,13 @@ def from_proto(datastore_table_proto: DatastoreTableProto) -> Any:
datastore_table.project_id = datastore_table_proto.project_id.value
if datastore_table_proto.HasField("namespace"):
datastore_table.namespace = datastore_table_proto.namespace.value
if datastore_table_proto.HasField("database"):
datastore_table.database = datastore_table_proto.database.value

return datastore_table

def update(self):
client = _initialize_client(self.project_id, self.namespace)
client = _initialize_client(self.project_id, self.namespace, self.database)
key = client.key("Project", self.project, "Table", self.name)
entity = datastore.Entity(
key=key, exclude_from_indexes=("created_ts", "event_ts", "values")
Expand All @@ -447,7 +465,7 @@ def update(self):
client.put(entity)

def teardown(self):
client = _initialize_client(self.project_id, self.namespace)
client = _initialize_client(self.project_id, self.namespace, self.database)
key = client.key("Project", self.project, "Table", self.name)
_delete_all_values(client, key)

Expand Down
17 changes: 14 additions & 3 deletions sdk/python/tests/unit/diff/test_infra_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ def test_tag_infra_proto_objects_for_keep_delete_add():

def test_diff_between_datastore_tables():
pre_changed = DatastoreTable(
project="test", name="table", project_id="pre", namespace="pre"
project="test", name="table", project_id="pre", namespace="pre", database="pre"
).to_proto()
post_changed = DatastoreTable(
project="test", name="table", project_id="post", namespace="post"
project="test",
name="table",
project_id="post",
namespace="post",
database="post",
).to_proto()

infra_object_diff = diff_between(pre_changed, pre_changed, "datastore table")
Expand All @@ -51,7 +55,7 @@ def test_diff_between_datastore_tables():

infra_object_diff = diff_between(pre_changed, post_changed, "datastore table")
infra_object_property_diffs = infra_object_diff.infra_object_property_diffs
assert len(infra_object_property_diffs) == 2
assert len(infra_object_property_diffs) == 3

assert infra_object_property_diffs[0].property_name == "project_id"
assert infra_object_property_diffs[0].val_existing == wrappers.StringValue(
Expand All @@ -67,6 +71,13 @@ def test_diff_between_datastore_tables():
assert infra_object_property_diffs[1].val_declared == wrappers.StringValue(
value="post"
)
assert infra_object_property_diffs[2].property_name == "database"
assert infra_object_property_diffs[2].val_existing == wrappers.StringValue(
value="pre"
)
assert infra_object_property_diffs[2].val_declared == wrappers.StringValue(
value="post"
)


def test_diff_infra_protos():
Expand Down

0 comments on commit e739745

Please sign in to comment.