diff --git a/sdk/python/feast/infra/online_stores/datastore.py b/sdk/python/feast/infra/online_stores/datastore.py
index f8964129cd..f788f1bc74 100644
--- a/sdk/python/feast/infra/online_stores/datastore.py
+++ b/sdk/python/feast/infra/online_stores/datastore.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
+import logging
 from datetime import datetime
 from multiprocessing.pool import ThreadPool
+from queue import Queue
+from threading import Lock, Thread
 from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple
 
 from pydantic import PositiveInt, StrictStr
@@ -33,6 +36,8 @@
 from feast.repo_config import FeastConfigBaseModel, RepoConfig
 from feast.usage import log_exceptions_and_usage, tracing_span
 
+LOGGER = logging.getLogger(__name__)
+
 try:
     from google.auth.exceptions import DefaultCredentialsError
     from google.cloud import datastore
@@ -262,15 +267,46 @@
 def _delete_all_values(client, key):
     """
     Delete all data under the key path in datastore.
+
+    Creates and uses a queue of lists of entity keys, which are batch deleted
+    by multiple threads.
     """
+
+    class AtomicCounter(object):
+        # for tracking how many deletions have already occurred; not used outside this method
+        def __init__(self):
+            self.value = 0
+            self.lock = Lock()
+
+        def increment(self):
+            with self.lock:
+                self.value += 1
+
+    BATCH_SIZE = 500  # Dec 2021: delete_multi has a max size of 500: https://cloud.google.com/datastore/docs/concepts/limits
+    NUM_THREADS = 3
+    deletion_queue = Queue()
+    status_info_counter = AtomicCounter()
+
+    def worker(shared_counter):
+        while True:
+            client.delete_multi(deletion_queue.get())
+            shared_counter.increment()
+            LOGGER.debug(
+                f"batch deletions completed: {shared_counter.value} ({shared_counter.value * BATCH_SIZE} total entries) & outstanding queue size: {deletion_queue.qsize()}"
+            )
+            deletion_queue.task_done()
+
+    for _ in range(NUM_THREADS):
+        Thread(target=worker, args=(status_info_counter,), daemon=True).start()
+
+    query = client.query(kind="Row", ancestor=key)
     while True:
-        query = client.query(kind="Row", ancestor=key)
-        entities = list(query.fetch(limit=1000))
+        entities = list(query.fetch(limit=BATCH_SIZE))
         if not entities:
-            return
+            break
+        deletion_queue.put([entity.key for entity in entities])
 
-        for entity in entities:
-            client.delete(entity.key)
+    deletion_queue.join()
 
 
 def _initialize_client(
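
For context, here is a minimal, self-contained sketch of the producer/consumer pattern the patch introduces: daemon worker threads pull batches of keys off a `Queue` and issue batch deletes, while the main thread enqueues batches and blocks on `join()` until all of them have been processed. `FakeClient`, `drain`, and the key strings are hypothetical stand-ins for illustration only; they are not part of the Feast codebase or the Datastore API.

```python
import logging
from queue import Queue
from threading import Lock, Thread

logging.basicConfig(level=logging.DEBUG)
LOGGER = logging.getLogger(__name__)

BATCH_SIZE = 500   # Datastore's delete_multi cap as of Dec 2021
NUM_THREADS = 3


class FakeClient:
    """Hypothetical stand-in for google.cloud.datastore.Client."""

    def delete_multi(self, keys):
        # A real client would issue one RPC deleting up to 500 keys here.
        pass


def drain(client, batches):
    deletion_queue = Queue()
    completed = 0
    lock = Lock()  # same role as the patch's AtomicCounter

    def worker():
        nonlocal completed
        while True:
            client.delete_multi(deletion_queue.get())
            with lock:
                completed += 1
                LOGGER.debug(
                    "batches deleted: %d, queued: %d", completed, deletion_queue.qsize()
                )
            deletion_queue.task_done()

    # Daemon threads die with the main thread, so the endless
    # `while True` loops never keep the process alive.
    for _ in range(NUM_THREADS):
        Thread(target=worker, daemon=True).start()

    for batch in batches:
        deletion_queue.put(batch)
    deletion_queue.join()  # block until every queued batch has been deleted


drain(FakeClient(), [[f"key-{i}-{j}" for j in range(BATCH_SIZE)] for i in range(5)])
```

The `join()`/`task_done()` pairing is what lets `_delete_all_values` return only after every batch handed to the workers has actually been deleted, without joining the threads themselves.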