Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel recursive breadth-first w.workspace.list(..., recursive=True, threads=os.cpu_count()) to iterate over 10K notebooks faster #284

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions databricks/sdk/mixins/workspace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
from typing import BinaryIO, Iterator, Optional
from queue import Queue

from ..core import DatabricksError
from ..service.workspace import (ExportFormat, ImportFormat, Language,
Expand All @@ -8,6 +10,54 @@
def _fqcn(x: any) -> str:
return f'{x.__module__}.{x.__name__}'

# Module-level logger, obtained under the 'databricks.sdk' logger name.
_LOG = logging.getLogger('databricks.sdk')

class _ParallelRecursiveListing:
    """Iterator over every non-directory object below ``path``, listed by several threads.

    Directories still to be listed are kept in the ``directories`` queue; worker
    threads take one path at a time, call ``listing`` on it, push any
    sub-directory back onto ``directories`` and every other object onto
    ``results``. A supervisor waits until all queued directories have been
    processed, then stops the workers and marks the end of ``results``.

    Construction returns immediately; listing happens in the background and
    objects become available through iteration as they are found. The order
    in which objects are produced is not deterministic.

    :param path: root directory to start listing from.
    :param listing: callable invoked as
        ``listing(path, notebooks_modified_after=...)`` that returns an
        iterable of objects with ``object_type`` and ``path`` attributes.
    :param threads: number of worker threads; must be at least 1.
    :param notebooks_modified_after: passed through unchanged to ``listing``.
    :raises ValueError: if ``threads`` is lower than 1.
    """

    # Unique marker put on ``results`` once every directory has been listed.
    _DONE = object()

    def __init__(self, path, listing, threads, notebooks_modified_after):
        if threads < 1:
            raise ValueError('threads must be greater than 0')
        self.path = path
        self.listing = listing
        self.threads = threads
        self.notebooks_modified_after = notebooks_modified_after
        self.directories = Queue()
        self.results = Queue()
        # First exception raised by ``listing`` in any worker, if any.
        self._error = None
        # Set once the end-of-results marker has been consumed.
        self._exhausted = False
        self.directories.put_nowait(path)
        self._start()

    def _worker(self):
        """Process paths from ``directories`` until a poison pill (``None``) arrives."""
        while True:
            path = self.directories.get()
            try:
                if path is None:
                    _LOG.debug('stopping thread')
                    break  # poison pill
                if self._error is not None:
                    # Another listing already failed: drain the queue without
                    # issuing further calls so the supervisor can finish.
                    continue
                for object_info in self.listing(
                        path, notebooks_modified_after=self.notebooks_modified_after):
                    if object_info.object_type == ObjectType.DIRECTORY:
                        # Queued before task_done() for the parent, so the
                        # count of unfinished directories never hits zero
                        # while work is still outstanding.
                        self.directories.put(object_info.path)
                        continue
                    _LOG.debug('found: %s', object_info.path)
                    self.results.put_nowait(object_info)
            except Exception as err:
                # Remember only the first failure; it is re-raised to the
                # consumer once the results gathered so far are drained.
                if self._error is None:
                    self._error = err
            finally:
                # Every get() needs a matching task_done(), pills included,
                # otherwise directories.join() would never return.
                self.directories.task_done()

    def _supervise(self):
        """Wait for all directories to be listed, then stop workers and close ``results``."""
        self.directories.join()
        _LOG.debug('done iterating')
        for _ in range(self.threads):
            self.directories.put(None)
        self.results.put(self._DONE)

    def _start(self):
        """Launch the worker threads and the supervisor without blocking the caller."""
        import concurrent.futures

        # One extra slot so the supervisor always gets its own thread while
        # all workers are blocked waiting on the queue.
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads + 1)
        for _ in range(self.threads):
            pool.submit(self._worker)
        pool.submit(self._supervise)
        # Do not wait here: already submitted tasks keep running, and the
        # pool's resources are released once they complete.
        pool.shutdown(wait=False)

    def __iter__(self) -> Iterator[ObjectInfo]:
        """Return the iterator itself; results can be consumed only once."""
        return self

    def __next__(self) -> ObjectInfo:
        """Return the next listed object, blocking until one is available.

        :raises StopIteration: when every directory has been listed and all
            results have been consumed.
        :raises Exception: the first error raised by ``listing`` in a worker,
            after the results found before the failure have been returned.
        """
        if self._exhausted:
            raise StopIteration
        item = self.results.get()
        if item is self._DONE:
            self._exhausted = True
            if self._error is not None:
                raise self._error
            raise StopIteration
        return item


class WorkspaceExt(WorkspaceAPI):
__doc__ = WorkspaceAPI.__doc__
Expand All @@ -17,6 +67,7 @@ def list(self,
*,
notebooks_modified_after: Optional[int] = None,
recursive: Optional[bool] = False,
threads: Optional[int] = None,
**kwargs) -> Iterator[ObjectInfo]:
"""List workspace objects

Expand All @@ -26,6 +77,8 @@ def list(self,
:returns: Iterator of workspaceObjectInfo
"""
parent_list = super().list
if threads is not None:
return _ParallelRecursiveListing(path, parent_list,threads, notebooks_modified_after)
queue = [path]
while queue:
path, queue = queue[0], queue[1:]
Expand Down
15 changes: 15 additions & 0 deletions tests/integration/test_workspace.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
import io
import logging

from databricks.sdk.service.workspace import ImportFormat, Language

# Logger used by the tests below, obtained under the 'databricks.sdk' logger name.
_LOG = logging.getLogger('databricks.sdk')

def test_workspace_recursive_list_parallel_xx():
    """List one user's folder recursively with 20 threads and log each entry.

    Builds its own client from the 'demo' profile instead of using a fixture.
    """
    from databricks.sdk import WorkspaceClient

    client = WorkspaceClient(profile='demo')
    listing = client.workspace.list('/Users/serge.smertin@databricks.com', threads=20, recursive=True)
    for entry in listing:
        _LOG.info(f'FOUND: {entry}')
    _LOG.info('DONE')


def test_workspace_recursive_list_parallel(w):
    """Print every entry found below '/Users' using 20 listing threads.

    :param w: workspace client supplied by the test fixtures.
    """
    # NOTE(review): confirm that the workspace API actually exposes
    # `_parallel_recursive_listing`; it is not defined in the code seen here.
    entries = w.workspace._parallel_recursive_listing('/Users', threads=20)
    for entry in entries:
        print(f'FOUND: {entry}')
    print('DONE')

def test_workspace_recursive_list(w, random):
names = []
Expand Down
Loading