Skip to content

Commit

Permalink
Support back and forth conversion of sents/bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed May 17, 2023
1 parent e233c20 commit 6a9eef2
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions supar/utils/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

from __future__ import annotations

from typing import Any, Iterable, Optional, Tuple
import os
import pickle
import struct
from io import BytesIO
from typing import Any, Iterable, Optional

import torch
from torch.distributions.utils import lazy_property

from supar.utils.fn import debinarize
from supar.utils.logging import get_logger, progress_bar

logger = get_logger(__name__)
Expand Down Expand Up @@ -212,6 +215,29 @@ def numericalize(self, fields):
self.pad_index = fields[0].pad_index
return self

def tobytes(self) -> bytes:
bufs, fields = [], {}
for name, value in self.fields.items():
if isinstance(value, torch.Tensor):
fields[name] = value
buf, dtype = value.numpy().tobytes(), value.dtype
self.fields[name] = (len(buf), dtype)
bufs.append(buf)
buf, sentence = b''.join(bufs), pickle.dumps(self)
for name, value in fields.items():
self.fields[name] = value
return buf + sentence + struct.pack('LL', len(buf), len(sentence))

@classmethod
def from_cache(cls, fbin: str, pos: Tuple[int, int]) -> Sentence:
return debinarize(fbin, pos)
def frombuffer(cls, buf: bytes) -> Sentence:
mm = BytesIO(buf)
mm.seek(-len(struct.pack('LL', 0, 0)), os.SEEK_END)
offset, length = struct.unpack('LL', mm.read())
mm.seek(offset)
sentence = pickle.loads(mm.read(length))
mm.seek(0)
for name, value in sentence.fields.items():
if isinstance(value, tuple) and isinstance(value[1], torch.dtype):
length, dtype = value
sentence.fields[name] = torch.frombuffer(bytearray(mm.read(length)), dtype=dtype)
return sentence

0 comments on commit 6a9eef2

Please sign in to comment.