Adds ZST support in Deduper and Mixer #170

Merged
merged 8 commits on Jun 6, 2024
43 changes: 43 additions & 0 deletions Cargo.lock

Generated file; diff not rendered by default.

4 changes: 4 additions & 0 deletions Cargo.toml
@@ -54,6 +54,10 @@ jaq-core = "1.2.1"
jaq-std = "1.2.1"
jaq-parse = "1.0.2"
jaq-interpret = { version = "1.2.1", features = ["serde_json"] }
zstd = "0.13.1"

[dev-dependencies]
tempfile = "3.10.1"

# [target.'cfg(target_arch = "aarch64")'.dependencies]
# openssl = { version = "0.10.63", features = ["vendored"] }
19 changes: 18 additions & 1 deletion python/dolma/cli/deduper.py
@@ -8,7 +8,12 @@

from dolma import deduper
from dolma.cli import BaseCli, field, print_config
from dolma.cli.shared import WorkDirConfig, get_path_to_temp_file, make_workdirs
from dolma.cli.shared import (
CompressionConfig,
WorkDirConfig,
get_path_to_temp_file,
make_workdirs,
)
from dolma.core.errors import DolmaConfigError
from dolma.core.loggers import get_logger
from dolma.core.paths import glob_path, is_local
@@ -99,6 +104,13 @@ class DeduperConfig:
processes: int = field(
default=1, help="Number of processes to use for deduplication. If 1, no multiprocessing will be used."
)
compression: CompressionConfig = field(
default=CompressionConfig(),
help=(
"Configuration for input/output compression. By default, compression of files is inferred "
"from the file extension."
),
)
dryrun: bool = field(
default=False,
help="If true, only print the configuration and exit without running the deduper.",
@@ -209,6 +221,11 @@ def run(cls, parsed_config: DeduperConfig):
dict_config["work_dir"] = {"input": str(work_dirs.input), "output": str(work_dirs.output)}
dict_config["processes"] = int(parsed_config.processes)

dict_config["compression"] = {
"input": str(i) if (i := parsed_config.compression.input) is not None else None,
"output": str(o) if (o := parsed_config.compression.output) is not None else None,
}

if len(dict_config["documents"]) == 0:
raise ValueError("At least one document must be specified")

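For context, a minimal sketch of how the new compression block might look in a deduper config dict. The "zst" and "gz" identifiers, the process count, and the omission of the other required sections (documents, dedupe, bloom_filter, work_dir) are illustrative assumptions, not part of this diff.

```python
# Illustrative fragment of a deduper config dict (assumed values; other required
# sections such as "documents", "dedupe", and "bloom_filter" are omitted).
partial_config = {
    "processes": 8,
    "compression": {
        # read inputs as zstd; None would infer from the file extension
        "input": "zst",
        # write gzip output; None would fall back to the input compression
        "output": "gz",
    },
}
print(partial_config)
```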
15 changes: 14 additions & 1 deletion python/dolma/cli/mixer.py
@@ -3,7 +3,7 @@

from dolma import mixer
from dolma.cli import BaseCli, field, print_config
from dolma.cli.shared import WorkDirConfig, make_workdirs
from dolma.cli.shared import CompressionConfig, WorkDirConfig, make_workdirs
from dolma.core.errors import DolmaConfigError
from dolma.core.loggers import get_logger
from dolma.core.paths import glob_path
@@ -59,6 +59,13 @@ class StreamConfig:
default=None, help="Configuration for filtering documents."
)
span_replacement: List[SpanReplacementConfig] = field(default=[], help="Configuration for replacing spans.")
compression: CompressionConfig = field(
default=CompressionConfig(),
help=(
"Configuration for input/output compression. By default, compression of files is inferred "
"from the file extension."
),
)


@dataclass
@@ -159,6 +166,12 @@ def run(cls, parsed_config: MixerConfig):
"max_size_in_bytes": int(stream_config.output.max_size_in_bytes),
}

# add compression config to the stream config dict
stream_config_dict["compression"] = {
"input": str(i) if (i := stream_config.compression.input) is not None else None,
"output": str(o) if (o := stream_config.compression.output) is not None else None,
}

if stream_config.output.min_text_length:
stream_config_dict["output"]["min_text_length"] = int(stream_config.output.min_text_length)
if stream_config.output.min_text_length < 0:
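In the mixer, compression is configured per stream rather than globally. A hedged sketch of the per-stream fragment, with all other stream keys elided and "zst" assumed to be a valid identifier:

```python
# Sketch of one mixer stream's compression block (other stream keys elided).
stream_fragment = {
    "compression": {
        "input": None,    # infer from each input file's extension
        "output": "zst",  # force zstd-compressed output (assumed identifier)
    },
}
print(stream_fragment)
```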
6 changes: 6 additions & 0 deletions python/dolma/cli/shared.py
@@ -15,6 +15,12 @@ class WorkDirConfig:
output: Optional[str] = field(default=None, help="Path to the output directory.")


@dataclass
class CompressionConfig:
input: Optional[str] = field(default=None, help="Compression algorithm to use for input files")
output: Optional[str] = field(default=None, help="Compression algorithm to use for output files")


@contextmanager
def get_path_to_temp_file(prefix="dolma-", suffix=None) -> Generator[Path, None, None]:
with tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix, delete=True) as f:
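A short sketch of how the new CompressionConfig dataclass behaves, assuming standard dataclass keyword construction; both fields default to None, which the deduper and mixer translate to "infer from the file extension".

```python
# Sketch: default and explicit CompressionConfig values (identifiers are assumptions).
from dolma.cli.shared import CompressionConfig

default_cfg = CompressionConfig()                           # infer both from extensions
explicit_cfg = CompressionConfig(input="gz", output="zst")  # explicit identifiers

assert default_cfg.input is None and default_cfg.output is None
print(explicit_cfg.input, explicit_cfg.output)
```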
80 changes: 50 additions & 30 deletions src/deduper.rs
@@ -1,20 +1,17 @@
use std::collections::VecDeque;
use std::fs::OpenOptions;
use std::io;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::io::{BufRead, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

use flate2::read::MultiGzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use serde_json::{json, Value};
use threadpool::ThreadPool;

use crate::bloom_filter::BloomFilter;
use crate::io::MultiStream;
use crate::s3_util;
use crate::shard::shard_config::WorkDirConfig;
use crate::shard::shard_config::{CompressionConfig, WorkDirConfig};
use crate::shard::{find_objects_matching_patterns, FileCache};
use crate::wimbd::tokens::tokenize;

@@ -42,8 +39,12 @@ pub fn run(config: DeduperConfig) -> Result<u32, u32> {
let dedupe = config.dedupe.clone();
let bloom_filter = bloom_filter.clone();
let failed_shard_count_ref = failed_shard_count_ref.clone();
let compression = match config.compression.clone() {
Some(c) => c,
None => CompressionConfig::infer(),
};
threadpool.execute(move || {
let result = write_attributes(path, work_dirs, dedupe, bloom_filter);
let result = write_attributes(path, work_dirs, dedupe, compression, bloom_filter);
if let Err(e) = result {
log::error!("Failed to process {:?}: {}", p, e);
failed_shard_count_ref.fetch_add(1, Ordering::Relaxed);
@@ -79,6 +80,7 @@ fn write_attributes(
docs_location: String,
work_dirs: WorkDirConfig,
dedupe_config: DedupeConfig,
compression: CompressionConfig,
bloom_filter: Arc<BloomFilter>,
) -> Result<(), io::Error> {
let cache = FileCache {
@@ -110,24 +112,40 @@
{
let local_input = cache.prepare_input(&docs_location)?;

let input_file = OpenOptions::new()
.read(true)
.write(false)
.create(false)
.open(local_input.clone())?;
let reader = BufReader::with_capacity(1024 * 1024, MultiGzDecoder::new(input_file));

let tmp_output = OpenOptions::new()
.read(false)
.write(true)
.create(true)
.truncate(true)
.open(&local_output)?;

let mut writer = BufWriter::with_capacity(
1024 * 1024,
GzEncoder::new(tmp_output, Compression::default()),
);
// The input_compression is either provided by the user or inferred from the file extension.
// We use `infer_compression_from_temp` to handle local files that may end in `.tmp` when
// they are cached versions of S3 files.
let input_compression: String = match compression.input {
Some(ref input) => input.clone(),
None => MultiStream::infer_compression_from_temp(local_input.clone()),
[Review comment from a Contributor] I think the comment contradicts what you're doing here - you say you'll use docs_location but use local_input

};

// The output_compression is either provided by the user or defaults to
// the same compression type as the input.
let output_compression = match compression.output {
Some(ref output) => output.clone(),
None => input_compression.clone(),
};

// let's open a stream to read the input file
let reader = MultiStream::new(
local_input.clone(),
Some(input_compression),
Some(1024 * 1024),
None,
None,
)
.reader()?;

// this is the stream we use to write the output file
let mut writer_stream = MultiStream::new(
local_output.clone(),
Some(output_compression),
Some(1024 * 1024),
None,
None,
)
.writer()?;

let min_content_length = dedupe_config.min_length.unwrap_or(0);
let min_word_count = dedupe_config.min_words.unwrap_or(0);
@@ -346,8 +364,8 @@ fn write_attributes(
let mut output_object = json!({});
output_object["id"] = data["id"].clone();
output_object["attributes"] = attributes;
serde_json::to_writer(&mut writer, &output_object)?;
writer.write_all(b"\n")?;
serde_json::to_writer(&mut writer_stream, &output_object)?;
writer_stream.write_all(b"\n")?;
}

// only remove the local_input file if it is different from docs_location
@@ -370,10 +388,11 @@

pub mod deduper_config {
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io;
use std::path::PathBuf;

use crate::bloom_filter::BloomFilterConfig;
use crate::io::MultiStream;
use crate::shard::shard_config::*;

#[derive(Serialize, Deserialize, Clone)]
@@ -430,12 +449,13 @@ pub mod deduper_config {
pub dedupe: DedupeConfig,
pub bloom_filter: BloomFilterConfig,
pub processes: usize,
pub compression: Option<CompressionConfig>,
}

impl DeduperConfig {
pub fn read_from_file(path: &str) -> Result<DeduperConfig, io::Error> {
let file = File::open(path)?;
let reader = io::BufReader::new(file);
let config_path = PathBuf::from(path);
let reader = MultiStream::with_default(config_path).reader()?;
let config: DeduperConfig = serde_json::from_reader(reader)?;
Ok(config)
}