Cherry pick zstd compressor #180

Merged
merged 7 commits into from
Jul 26, 2024
77 changes: 37 additions & 40 deletions .github/workflows/CI.yml
@@ -1,18 +1,17 @@
name: CI

on:
push:
branches:
- main
- master
tags:
- '*'
pull_request:
branches:
- main
- master
workflow_dispatch:

push:
branches:
- main
- master
tags:
- "*"
pull_request:
branches:
- main
- master
workflow_dispatch:

permissions:
contents: read
@@ -22,7 +21,6 @@ env:
DOLMA_TEST_S3_PREFIX: s3://dolma-tests
RUST_CHANNEL: stable


jobs:
info:
name: Run info
@@ -40,32 +38,31 @@ jobs:
echo "PR base repo: ${{ github.event.pull_request.base.repo.full_name }}/tree/${{ github.event.pull_request.base.ref }}"
echo "PR head repo: ${{ github.event.pull_request.head.repo.full_name }}/tree/${{ github.event.pull_request.head.ref }}"


should_build:
name: "Check if build"
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
ref: ${{ github.ref }}
- name: List branches and tags
run: |
git branch -a
git tag -l
git log | head -n 1000
- id: check_version
run: |
set +e
has_updated=$(git diff --name-only '${{ github.event.pull_request.base.sha }}' | grep -E 'pyproject.toml|Cargo.toml')
is_main_or_release='${{ github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/tags/') }}'
if [[ -n "${has_updated}" ]] || [[ "${is_main_or_release}" == 'true' ]]; then
echo "should_build=true" >> $GITHUB_OUTPUT
else
echo "should_build=false" >> $GITHUB_OUTPUT
fi
shell: bash
- name: checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
ref: ${{ github.ref }}
- name: List branches and tags
run: |
git branch -a
git tag -l
git log | head -n 1000
- id: check_version
run: |
set +e
has_updated=$(git diff --name-only '${{ github.event.pull_request.base.sha }}' | grep -E 'pyproject.toml|Cargo.toml')
is_main_or_release='${{ github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/tags/') }}'
if [[ -n "${has_updated}" ]] || [[ "${is_main_or_release}" == 'true' ]]; then
echo "should_build=true" >> $GITHUB_OUTPUT
else
echo "should_build=false" >> $GITHUB_OUTPUT
fi
shell: bash
outputs:
should_build: ${{ steps.check_version.outputs.should_build }}
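
The `check_version` step above decides whether to build from two inputs: the files changed against the PR base, and the Git ref. The decision can be sketched in Python as follows. This is only an illustration of the logic, not part of the workflow; `should_build` as a function, its parameters, and the `WATCHED` constant are names assumed here.

```python
import re

# Same pattern the workflow passes to `grep -E`; the unescaped dots are kept
# on purpose so the sketch matches exactly what the shell step matches.
WATCHED = re.compile(r"pyproject.toml|Cargo.toml")


def should_build(changed_files, ref):
    """Return True when a build should run (hypothetical helper)."""
    # Mirrors `has_updated`: any changed path that mentions a manifest file.
    has_updated = any(WATCHED.search(name) for name in changed_files)
    # Mirrors `is_main_or_release`: main, master, or any tag.
    # Simplification: comparisons in this sketch are case-sensitive.
    is_main_or_release = (
        ref in ("refs/heads/main", "refs/heads/master")
        or ref.startswith("refs/tags/")
    )
    return has_updated or is_main_or_release


print(should_build(["python/dolma/core/utils.py"], "refs/heads/feature"))
print(should_build(["Cargo.toml"], "refs/heads/feature"))
```

Under these assumptions, a feature-branch change that touches neither manifest skips the build, while any tag or any push to `main`/`master` always builds.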

@@ -88,7 +85,7 @@ jobs:

- name: Setup system libraries
if: steps.cache-venv.outputs.cache-hit != 'true'
run: |
run: |
sudo apt-get update
sudo apt-get install --yes --upgrade build-essential cmake protobuf-compiler libssl-dev glibc-source musl-tools

@@ -103,7 +100,7 @@ jobs:
if: steps.cache-venv.outputs.cache-hit != 'true'
uses: actions/setup-python@v4
with:
python-version: '3.8'
python-version: "3.8"
architecture: "x64"

- name: Create a new Python environment & install maturin
@@ -188,7 +185,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: "3.10"
- name: Install 32bit version of libc
if: ${{ matrix.target == 'x86' || contains(matrix.target, 'i686') }}
run: |
@@ -222,7 +219,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: "3.10"
architecture: ${{ matrix.target }}
- name: Build wheels
uses: PyO3/maturin-action@v1
@@ -247,7 +244,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: "3.10"
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -47,6 +47,7 @@ threadpool = "1.8.1"
tokenizers = { version = "0.15.0", features = ["http"] }
tokio = { version = "1.27.0", features = ["full"] }
tokio-util = "0.7.7"
time = "0.3.36"
unicode-segmentation = "1.7"
openssl = { version = "0.10.63", features = ["vendored"] }
adblock = { version = "0.8.6", features = ["content-blocking"] }
31 changes: 10 additions & 21 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "dolma"
version = "1.0.4"
version = "1.0.5"
description = "Data filters"
license = { text = "Apache-2.0" }
readme = "README.md"
@@ -30,6 +30,7 @@ dependencies = [
"numpy",
"necessary>=0.4.3",
"charset-normalizer>=3.2.0",
"zstandard>=0.23.0",
]
classifiers = [
"Development Status :: 5 - Production/Stable",
@@ -117,35 +118,27 @@ pii = ["presidio_analyzer==2.2.32", "regex"]
# language detection; by default, we use fasttext, everything else is optional
lang = [
"fasttext-wheel==0.9.2",
"LTpycld2==0.42", # fork of pycld2 that works on Apple Silicon
"LTpycld2==0.42", # fork of pycld2 that works on Apple Silicon
"lingua-language-detector>=2.0.0",
"langdetect>=1.0.9"
"langdetect>=1.0.9",
]

# extension to parse warc files
warc = [
"fastwarc",
"w3lib",
"url-normalize",

]
warc = ["fastwarc", "w3lib", "url-normalize"]
trafilatura = [
# must include warc dependencies
"dolma[warc]",
# core package
"trafilatura>=1.6.1",
# following are all for speeding up trafilatura
"brotli",
"cchardet >= 2.1.7; python_version < '3.11'", # build issue
"faust-cchardet >= 2.1.18; python_version >= '3.11'", # fix for build
"cchardet >= 2.1.7; python_version < '3.11'", # build issue
"faust-cchardet >= 2.1.18; python_version >= '3.11'", # fix for build
"htmldate[speed] >= 1.4.3",
"py3langid >= 0.2.2",
]

resiliparse = [
"dolma[warc]",
"resiliparse",
]
resiliparse = ["dolma[warc]", "resiliparse"]

# all extensions
all = [
@@ -154,15 +147,11 @@ all = [
"dolma[pii]",
"dolma[trafilatura]",
"dolma[resiliparse]",
"dolma[lang]"
"dolma[lang]",
]

[build-system]
requires = [
"maturin[patchelf]>=1.1,<2.0",
"setuptools >= 61.0.0",
"wheel"
]
requires = ["maturin[patchelf]>=1.1,<2.0", "setuptools >= 61.0.0", "wheel"]
build-backend = "maturin"


38 changes: 38 additions & 0 deletions python/dolma/core/utils.py
@@ -1,4 +1,5 @@
import importlib
import io
import os
import re
import string
@@ -14,8 +15,11 @@

import nltk
import uniseg.wordbreak
import zstandard
from necessary import necessary
from nltk.tokenize.punkt import PunktSentenceTokenizer
from omegaconf import OmegaConf as om
from smart_open import register_compressor

try:
nltk.data.find("tokenizers/punkt")
@@ -148,3 +152,37 @@ def dataclass_to_dict(dataclass_instance) -> dict:

# force typecasting because a dataclass instance will always be a dict
return cast(dict, om.to_object(om.structured(dataclass_instance)))


def add_compression():
    """
    Adds support for the zstandard compression format to the smart_open library.

    This function registers a custom compressor for the .zst and .zstd file
    extensions in the smart_open library. The compressor uses the zstandard
    library to handle zstandard compression.
    """

    def _handle_zstd(file_obj, mode):
        result = zstandard.open(filename=file_obj, mode=mode)
        # zstandard.open returns an io.TextIOWrapper in text mode, but otherwise
        # returns a raw stream reader/writer, and we need the `io` wrapper
        # to make FileLikeProxy work correctly.
        if "b" in mode and "w" in mode:
            result = io.BufferedWriter(result)
        elif "b" in mode and "r" in mode:
            result = io.BufferedReader(result)
        return result

    register_compressor(".zst", _handle_zstd)
    register_compressor(".zstd", _handle_zstd)
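
The mode-based wrapping inside `_handle_zstd` can be tried in isolation with only the standard library. The sketch below is an illustration, not the PR's code: `wrap_for_mode` and `roundtrip` are hypothetical names, and `io.FileIO` (a raw, unbuffered stream) stands in for the raw reader/writer that `zstandard.open` returns in binary mode, so nothing here is actually compressed.

```python
import io
import os
import tempfile


def wrap_for_mode(raw_stream, mode):
    """Pick a buffered wrapper from the mode string (hypothetical helper)."""
    if "b" in mode and "w" in mode:
        return io.BufferedWriter(raw_stream)
    if "b" in mode and "r" in mode:
        return io.BufferedReader(raw_stream)
    # Any other mode is returned untouched, as in the handler above.
    return raw_stream


def roundtrip(payload):
    """Write bytes through a wrapped raw stream, then read them back."""
    fd, path = tempfile.mkstemp()
    os.close(fd)
    try:
        # io.FileIO is only a stand-in for a raw zstandard stream.
        with wrap_for_mode(io.FileIO(path, "w"), "wb") as writer:
            writer.write(payload)
        with wrap_for_mode(io.FileIO(path, "r"), "rb") as reader:
            # BufferedReader adds methods such as peek() that raw streams lack.
            assert hasattr(reader, "peek")
            return reader.read()
    finally:
        os.remove(path)


print(roundtrip(b"hello zstd"))
```

Closing the buffered wrapper flushes pending data and closes the underlying raw stream, which is why a single `with` block per direction is enough in this sketch.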


with necessary(("smart_open", "7.0.4"), soft=True) as SMART_OPEN_HAS_ZSTD:
    if SMART_OPEN_HAS_ZSTD:
        # add additional extension for smart_open
        from smart_open.compression import _handle_zstd

        register_compressor(".zstd", _handle_zstd)
    else:
        # add zstd compression
        add_compression()
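
The block above appears to gate on whether `smart_open` 7.0.4 or newer is installed: if so it reuses that library's own zstd handler, otherwise it falls back to `add_compression()`. A minimal standard-library sketch of such a soft version gate is shown below. It is not how `necessary` is implemented; `parse_version`, `installed_at_least`, and `choose_zstd_setup` are hypothetical names, and the version parsing is deliberately naive.

```python
from importlib import metadata


def parse_version(text):
    """Turn '7.0.4' into (7, 0, 4); stops at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def installed_at_least(package, minimum):
    """Soft check: report False instead of raising when the package is absent."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return False
    return parse_version(installed) >= parse_version(minimum)


def choose_zstd_setup(package="smart_open", minimum="7.0.4"):
    """Name the branch the gate would take (labels are illustrative only)."""
    if installed_at_least(package, minimum):
        return "reuse-builtin-handler"
    return "register-custom-handler"


print(choose_zstd_setup())
```

Which label is printed depends on the environment, so no expected output is given here.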