From 6a9eef2c2c1a7f8ed97bebb522454336f9c9a7c5 Mon Sep 17 00:00:00 2001 From: yzhangcs Date: Wed, 17 May 2023 22:01:56 +0800 Subject: [PATCH] Support back and forth conversion of sents/bytes --- supar/utils/transform.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/supar/utils/transform.py b/supar/utils/transform.py index 7d2b92e8..1db865ff 100644 --- a/supar/utils/transform.py +++ b/supar/utils/transform.py @@ -2,12 +2,15 @@ from __future__ import annotations -from typing import Any, Iterable, Optional, Tuple +import os +import pickle +import struct +from io import BytesIO +from typing import Any, Iterable, Optional import torch from torch.distributions.utils import lazy_property -from supar.utils.fn import debinarize from supar.utils.logging import get_logger, progress_bar logger = get_logger(__name__) @@ -212,6 +215,29 @@ def numericalize(self, fields): self.pad_index = fields[0].pad_index return self + def tobytes(self) -> bytes: + bufs, fields = [], {} + for name, value in self.fields.items(): + if isinstance(value, torch.Tensor): + fields[name] = value + buf, dtype = value.numpy().tobytes(), value.dtype + self.fields[name] = (len(buf), dtype) + bufs.append(buf) + buf, sentence = b''.join(bufs), pickle.dumps(self) + for name, value in fields.items(): + self.fields[name] = value + return buf + sentence + struct.pack('LL', len(buf), len(sentence)) + @classmethod - def from_cache(cls, fbin: str, pos: Tuple[int, int]) -> Sentence: - return debinarize(fbin, pos) + def frombuffer(cls, buf: bytes) -> Sentence: + mm = BytesIO(buf) + mm.seek(-len(struct.pack('LL', 0, 0)), os.SEEK_END) + offset, length = struct.unpack('LL', mm.read()) + mm.seek(offset) + sentence = pickle.loads(mm.read(length)) + mm.seek(0) + for name, value in sentence.fields.items(): + if isinstance(value, tuple) and isinstance(value[1], torch.dtype): + length, dtype = value + sentence.fields[name] = torch.frombuffer(bytearray(mm.read(length)), dtype=dtype) + return sentence