Build dataloaders with multi-processes
yzhangcs committed May 23, 2023
1 parent faa7a56 commit ab9ba00
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions supar/utils/data.py
@@ -5,7 +5,6 @@
import itertools
import os
import queue
import shutil
import tempfile
import threading
from contextlib import contextmanager
@@ -14,12 +13,13 @@
import pathos.multiprocessing as mp
import torch
import torch.distributed as dist
from torch.distributions.utils import lazy_property

from supar.utils.common import INF
from supar.utils.fn import binarize, debinarize, kmeans
from supar.utils.logging import get_logger, progress_bar
from supar.utils.parallel import is_dist, is_master
from supar.utils.transform import Batch, Transform
from torch.distributions.utils import lazy_property

logger = get_logger(__name__)

@@ -81,7 +81,7 @@ def __init__(

if cache:
if not isinstance(data, str) or not os.path.exists(data):
raise FileNotFoundError("Only files are allowed for binarization, but not found")
raise FileNotFoundError("Please specify a valid file path for caching!")
if self.bin is None:
self.fbin = data + '.pt'
else:
@@ -150,21 +150,19 @@ def build(
even: bool = True,
n_workers: int = 0,
pin_memory: bool = True,
chunk_size: int = 1000,
chunk_size: int = 10000,
) -> Dataset:
# numericalize all fields
if not self.cache:
self.sentences = [i for i in self.transform(self.sentences) if len(i) < self.max_len]
# if not forced and the binarized file already exists, directly load the meta file
if self.cache and os.path.exists(self.fbin) and not self.binarize:
self.sentences = debinarize(self.fbin, meta=True)['sentences']
else:
# if not forced to do binarization and the binarized file already exists, directly load the meta file
if os.path.exists(self.fbin) and not self.binarize:
self.sentences = debinarize(self.fbin, meta=True)['sentences']
else:
with tempfile.TemporaryDirectory() as ftemp:
fbin = self.fbin if self.cache else os.path.join(ftemp, 'data.pt')

@contextmanager
def cache(sentences):
ftemp = tempfile.mkdtemp()
fs = os.path.join(ftemp, 'sentences')
fb = os.path.join(ftemp, os.path.basename(self.fbin))
fb = os.path.join(ftemp, os.path.basename(fbin))
global global_transform
global_transform = self.transform
sentences = binarize({'sentences': progress_bar(sentences)}, fs)[1]['sentences']
@@ -173,23 +171,23 @@ def cache(sentences):
for i, s in enumerate(range(0, len(sentences), chunk_size)))
finally:
del global_transform
shutil.rmtree(ftemp)

def numericalize(sentences, fs, fb, max_len):
sentences = global_transform((debinarize(fs, sentence) for sentence in sentences))
sentences = [i for i in sentences if len(i) < max_len]
return binarize({'sentences': sentences, 'sizes': [sentence.size for sentence in sentences]}, fb)[0]

logger.info(f"Seeking to cache the data to {self.fbin} first")
logger.info(f"Caching the data to {fbin}")
# numericalize the fields of each sentence
if is_master():
with cache(self.transform.load(self.data, **self.kwargs)) as chunks, mp.Pool(32) as pool:
results = [pool.apply_async(numericalize, chunk) for chunk in chunks]
self.sentences = binarize((r.get() for r in results), self.fbin, merge=True)[1]['sentences']
self.sentences = binarize((r.get() for r in results), fbin, merge=True)[1]['sentences']
if is_dist():
dist.barrier()
if not is_master():
self.sentences = debinarize(self.fbin, meta=True)['sentences']
self.sentences = debinarize(fbin, meta=True)['sentences']
if not self.cache:
self.sentences = [debinarize(fbin, i) for i in progress_bar(self.sentences)]
# NOTE: the final bucket count is roughly equal to n_buckets
self.buckets = dict(zip(*kmeans(self.sizes, n_buckets)))
self.loader = DataLoader(transform=self.transform,
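For readers skimming the diff, the sketch below shows the pattern the reworked build follows: split the sentences into chunks, numericalize each chunk in a worker process drawn from a pathos.multiprocessing pool, write every chunk to its own file inside a tempfile.TemporaryDirectory, and merge the per-chunk results before the directory is cleaned up. It is not supar's implementation: pickle stands in for supar.utils.fn.binarize/debinarize, and the toy token-length "numericalization" is purely illustrative.

# A minimal sketch of chunked multi-process numericalization, mirroring the
# control flow of the new build(); the helpers are illustrative stand-ins.
import os
import pickle
import tempfile

import pathos.multiprocessing as mp


def numericalize(sentences, fout):
    # numericalize one chunk in a worker and "binarize" it to its own file;
    # here numericalization just records token lengths
    data = [[len(token) for token in sentence.split()] for sentence in sentences]
    with open(fout, 'wb') as f:
        pickle.dump(data, f)
    return fout


def build(sentences, chunk_size=10000, n_workers=4):
    # TemporaryDirectory replaces the old tempfile.mkdtemp() + shutil.rmtree()
    # pair: the directory is removed automatically when the block exits
    with tempfile.TemporaryDirectory() as ftemp:
        chunks = [(sentences[i:i + chunk_size], os.path.join(ftemp, f'chunk.{n}'))
                  for n, i in enumerate(range(0, len(sentences), chunk_size))]
        with mp.Pool(n_workers) as pool:
            results = [pool.apply_async(numericalize, chunk) for chunk in chunks]
            files = [r.get() for r in results]
        # merge the per-chunk files while the temporary directory still exists
        merged = []
        for fb in files:
            with open(fb, 'rb') as f:
                merged.extend(pickle.load(f))
    return merged


if __name__ == '__main__':
    print(build(["a toy sentence", "another longer toy sentence"], chunk_size=1))

As in the diff, merging happens on the master process only; in the real code the merged file doubles as the on-disk cache (self.fbin) when cache=True, and otherwise lives in the temporary directory and is loaded back into memory before the directory is discarded.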
