Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Progress Bar, Backoff, Batching #165

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .devcontainer/postInstall.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

PATH=/home/vscode/.cargo/bin:$PATH
cd dolma
source /home/vscode/miniforge3/bin/activate && pip install cmake "maturin[patchelf]>=1.1,<2.0"
source /home/vscode/miniforge3/bin/activate && pip install cmake "maturin>=1.5,<2.0"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙏

2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ setup:
$(shell "${PROTOBUF_SETUP}")
$(shell "${OPENSSL_SETUP}")
which cargo || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
which maturin || pip install maturin[patchelf]
which maturin || pip install 'maturin>=1.5,<2.0'

publish:
maturin publish
Expand Down
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"omegaconf>=2.3.0",
# "pycld2==0.41",
# "pycld3==0.22", # does not install correctly
"hyperscan>=0.7.0",
"platformdirs>=4.2.0",
"pyyaml",
"requests",
Expand All @@ -30,6 +31,8 @@ dependencies = [
"numpy",
"necessary>=0.4.3",
"charset-normalizer>=3.2.0",
"zstandard>=0.20.0",
"backoff>=2.0.0",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this minimum version pin required? There have been two minor releases since 2.0 — the latest is 2.2.1.

]
classifiers = [
"Development Status :: 5 - Production/Stable",
Expand Down Expand Up @@ -99,7 +102,7 @@ dolma = "dolma.cli.__main__:main"

[project.optional-dependencies]
dev = [
"black>=22.6.0",
"black[jupyter]>=22.6.0",
"flake8>=5.0",
"flake8-pyi>=22.8.1",
"Flake8-pyproject>=1.1.0",
Expand Down Expand Up @@ -127,7 +130,6 @@ warc = [
"fastwarc",
"w3lib",
"url-normalize",

]
trafilatura = [
# must include warc dependencies
Expand Down Expand Up @@ -159,7 +161,7 @@ all = [

[build-system]
requires = [
"maturin[patchelf]>=1.1,<2.0",
"maturin>=1.5,<2.0",
"setuptools >= 61.0.0",
"wheel"
]
Expand All @@ -175,7 +177,7 @@ features = ["pyo3/extension-module"]
where = ["src"]

[tool.setuptools.package-data]
dolma = ["py.typed", "data/*"]
dolma = ["py.typed", "data/*", "*.pyi"]

[tool.black]
line-length = 115
Expand Down
9 changes: 7 additions & 2 deletions python/dolma/core/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
DOLMA_PREFIX = "dolma"


def get_logger(name: str) -> logging.Logger:
def get_logger(name: str, level: Union[int, str] = logging.WARN) -> logging.Logger:
if (proc_name := multiprocessing.current_process().name) == "MainProcess":
proc_name = "main"
proc_name = proc_name.replace(" ", "_")

# set the log level
level = level if isinstance(level, int) else getattr(logging, level.strip().upper(), logging.WARN)

# set name
name = f"{proc_name}.dolma.{name}"
logger = logging.getLogger(name)
logger.setLevel(logging.WARN)
logger.setLevel(level)

# add handler
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter(
Expand Down
130 changes: 130 additions & 0 deletions python/dolma/core/mp_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import multiprocessing
import time
from contextlib import ExitStack
from multiprocessing.managers import SyncManager
from multiprocessing.pool import Pool
from queue import Queue
from typing import Any, Callable, Dict, Generic, Iterable, Optional, TypeVar, Union

T = TypeVar("T")
R = TypeVar("R")


def get_manager(pool: Union[Pool, "PoolWithDebug"]) -> Union[SyncManager, "ManagerWithDebug"]:
    """Return the manager matching *pool*'s execution mode.

    A mock :class:`ManagerWithDebug` when the pool runs in debug
    (single-process) mode, otherwise a real ``multiprocessing.Manager()``.
    Plain ``multiprocessing.Pool`` objects have no ``debug`` attribute and
    therefore always get a real manager.
    """
    debug_mode = getattr(pool, "debug", False)
    return ManagerWithDebug() if debug_mode else multiprocessing.Manager()


class ResultWithDebug(Generic[T]):
    """Drop-in stand-in for ``multiprocessing.pool.AsyncResult`` in debug mode.

    Wraps an already-computed value so callers written against the
    AsyncResult API (``get``/``wait``/``successful``/``ready``) work
    unchanged when the pool executes tasks synchronously.
    """

    def __init__(self, result: T, *args, **kwargs):
        # extra positional/keyword arguments are accepted (and ignored) so
        # the constructor stays signature-compatible with ApplyResult
        self.result = result

    def get(self, timeout: Optional[float] = None) -> T:
        # the value is already available, so the timeout is irrelevant
        return self.result

    def wait(self, timeout: Optional[float] = None) -> None:
        # The result is ready by construction, so return immediately.
        # (AsyncResult.wait unblocks as soon as the result is available;
        # the previous time.sleep(timeout) needlessly blocked for the
        # full timeout on every call.)
        return None

    def successful(self) -> bool:
        # the synchronous call already completed without raising
        return True

    def ready(self) -> bool:
        # always ready: the value was computed before construction
        return True


class ManagerWithDebug:
    """Minimal stand-in for ``multiprocessing.Manager`` in debug mode.

    Hands out plain in-process objects where a SyncManager would hand out
    proxies; ``shutdown`` is a no-op because no server process exists.
    """

    def Queue(self):
        # a plain queue.Queue suffices since debug mode is single-process
        return Queue()

    def shutdown(self) -> None:
        # nothing to tear down — no manager process was ever started
        return None


class PoolWithDebug:
    """A wrapper around multiprocessing.Pool that allows for debugging (i.e., running without multiprocessing).

    Supports creating a manager for shared memory objects (mock in case of debugging).
    When ``debug=True`` no worker processes are started: ``apply_async`` runs
    the function synchronously in the calling process and wraps the value in a
    ``ResultWithDebug`` so callers written against the Pool API work unchanged.
    """

    def __init__(
        self,
        processes: Optional[int] = None,
        initializer: Optional[Callable[..., Any]] = None,
        initargs: Iterable[Any] = (),
        maxtasksperchild: Optional[int] = None,
        debug: bool = False,
    ):
        self.processes = processes
        self.initializer = initializer
        self.initargs = initargs
        self.maxtasksperchild = maxtasksperchild
        self.debug = debug

        # we are gonna keep track of resources in stack; but also keeping them indexed
        # separately for easy access
        self.stack = ExitStack()
        self._manager: Optional[SyncManager] = None
        self._pool: Optional[Pool] = None

        # let's make sure that the start method is spawn for best performance
        try:
            multiprocessing.set_start_method("spawn")
        except RuntimeError:
            # a start method was already chosen elsewhere in the process;
            # that is acceptable only if it is already "spawn"
            assert multiprocessing.get_start_method() == "spawn", "Multiprocessing start method must be spawn"

    def __enter__(self):
        # lazily create the real pool; in debug mode we never create one,
        # which keeps everything in the current process
        if self._pool is None and not self.debug:
            self._pool = self.stack.enter_context(
                Pool(
                    processes=self.processes,
                    initializer=self.initializer,
                    initargs=self.initargs,
                    maxtasksperchild=self.maxtasksperchild,
                )
            )
        return self

    def Manager(self):
        """Return (and cache) a manager: a mock in debug mode, a real SyncManager otherwise."""
        if self._manager is None:
            self._manager = (
                ManagerWithDebug()  # pyright: ignore
                if self.debug
                else self.stack.enter_context(multiprocessing.Manager())
            )
        return self._manager

    def __exit__(self, *exc):
        # closes the pool and the manager (if any) in reverse creation order
        return self.stack.close()

    def apply_async(
        self,
        func: Callable[..., R],
        args: Iterable[Any] = (),
        kwds: Optional[Dict[str, Any]] = None,
        callback: Optional[Callable[[R], Any]] = None,
        error_callback: Optional[Callable[[Any], Any]] = None,
    ):
        """Schedule ``func(*args, **kwds)``.

        With a real pool this delegates to ``Pool.apply_async``. Without one
        (debug mode, or the pool was never entered) the call runs synchronously:
        ``callback``/``error_callback`` are invoked inline and exceptions are
        re-raised after ``error_callback`` runs.
        """
        # `kwds` previously defaulted to a shared mutable `{}` (a classic
        # Python pitfall); default to None and normalize here instead —
        # observable behavior is unchanged.
        kwds = {} if kwds is None else kwds
        if self._pool is None:
            if self.initializer:
                # run the initializer once by calling it with the initargs and then setting it to None
                self.initializer(*self.initargs)
                self.initializer = None
            try:
                resp = func(*args, **kwds)
                if callback is not None:
                    callback(resp)
                return ResultWithDebug(resp)
            except Exception as e:
                if error_callback is not None:
                    error_callback(e)
                raise e
        else:
            return self._pool.apply_async(
                func=func, args=args, kwds=kwds, callback=callback, error_callback=error_callback
            )

    def close(self):
        # no-op (returns None) when no pool was created, e.g. in debug mode
        return self._pool and self._pool.close()

    def join(self):
        # no-op (returns None) when no pool was created, e.g. in debug mode
        return self._pool and self._pool.join()
19 changes: 19 additions & 0 deletions python/dolma/core/mp_tools.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from collections.abc import Callable, Iterable
from multiprocessing.managers import SyncManager
from multiprocessing.pool import ApplyResult, Pool
from typing import Any

class ResultWithDebug(ApplyResult): ... # noqa: E701,E302
class ManagerWithDebug(SyncManager): ... # noqa: E701

# Pool wrapper that can run without multiprocessing; typed as Pool so call
# sites type-check. `debug=True` selects the synchronous (no-worker) mode.
class PoolWithDebug(Pool):  # noqa: E302
    def __init__(  # noqa: E704
        self,
        processes: int | None = None,
        initializer: Callable[..., Any] | None = None,
        initargs: Iterable[Any] = (),
        maxtasksperchild: int | None = None,
        debug: bool = False,
    ) -> None: ...

def get_manager(pool: Pool) -> SyncManager: ... # noqa: E701, E704, E302
Loading
Loading