Skip to content

Commit

Permalink
Fix tempfile bugs under distributed training
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed Jun 15, 2023
1 parent 5833ee8 commit 28c0240
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion supar/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from supar.utils.common import INF
from supar.utils.fn import binarize, debinarize, kmeans
from supar.utils.logging import get_logger, progress_bar
from supar.utils.parallel import is_dist, is_master
from supar.utils.parallel import gather, is_dist, is_master
from supar.utils.transform import Batch, Transform

logger = get_logger(__name__)
Expand Down Expand Up @@ -157,6 +157,7 @@ def build(
self.sentences = debinarize(self.fbin, meta=True)['sentences']
else:
with tempfile.TemporaryDirectory() as ftemp:
ftemp = gather(ftemp)[0] if is_dist() else ftemp
fbin = self.fbin if self.cache else os.path.join(ftemp, 'data.pt')

@contextmanager
Expand Down Expand Up @@ -188,6 +189,8 @@ def numericalize(sentences, fs, fb, max_len):
self.sentences = debinarize(fbin, meta=True)['sentences']
if not self.cache:
self.sentences = [debinarize(fbin, i) for i in progress_bar(self.sentences)]
if is_dist():
dist.barrier()
# NOTE: the final bucket count is roughly equal to n_buckets
self.buckets = dict(zip(*kmeans(self.sizes, n_buckets)))
self.loader = DataLoader(transform=self.transform,
Expand Down

0 comments on commit 28c0240

Please sign in to comment.