reaching 99.9% parity with sentence piece xlnet model
Sergei Alonichau committed Apr 24, 2020
1 parent afc8fdc commit 3eb6ff8
Showing 21 changed files with 37,084 additions and 21 deletions.
12 changes: 6 additions & 6 deletions blingfireclient.library/inc/FATokenSegmentationTools_1best_t.h
@@ -63,9 +63,9 @@ class FATokenSegmentationTools_1best_t {
// to keep track of arc data
struct _TArc {

- int _Begin; // the begging position of the ssegment
- int _Id; // ID of a segment from the vocab
- float _Score; // cumulative score
+ int _Begin; // the begging position of the ssegment
+ int _Id; // ID of a segment from the vocab
+ double _Score; // cumulative score

public:
_TArc ():
@@ -90,7 +90,7 @@ FATokenSegmentationTools_1best_t < Ty >::
m_pMealy (NULL),
m_pK2I (NULL),
m_pI2Info (NULL),
- m_UnkScore (-100000.0f) // this is guaranteed lower than any of the segment scores
+ m_UnkScore (-100000.0) // this is guaranteed lower than any of the segment scores
{}


@@ -122,7 +122,7 @@ inline void FATokenSegmentationTools_1best_t < Ty >::
const float Score = *((const float*) &(pValues [1]));

// compute previous score given the start
- const float prevScore = 0 < start ? pArcs [start - 1]._Score : 0.0f;
+ const double prevScore = 0 < start ? pArcs [start - 1]._Score : 0;

// get a pointer to the arc object
_TArc * pA = pArcs + end;
@@ -149,7 +149,7 @@ inline void FATokenSegmentationTools_1best_t < Ty >::
_TArc * pPrevA = pA - 1;

// compute previous score given the start
- const float prevScore = 0 < start ? pPrevA->_Score : 0.0f;
+ const double prevScore = 0 < start ? pPrevA->_Score : 0;

// set the arc, if it was never set then it has smallest negative float number
// so the condition is always true
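The change in this file widens the cumulative segmentation score from float to double. A self-contained sketch of why that can matter (not BlingFire code; the per-segment scores below are invented): summing many small negative log-scores in single precision drifts once the running total grows, and even a tiny drift can flip which segmentation path ends up with the best score.

#include <cstdio>

// Accumulate the same made-up per-segment log-scores in float and in double;
// the two running totals diverge because float keeps only ~7 significant
// decimal digits while the total grows into the millions.
int main()
{
    float  sumF = 0.0f;
    double sumD = 0.0;

    for (int i = 0; i < 1000000; ++i) {
        const double score = -9.537 - (i % 7) * 1e-4;   // invented scores
        sumF += (float) score;
        sumD += score;
    }

    std::printf("float  total: %.3f\n", sumF);
    std::printf("double total: %.3f\n", sumD);
    std::printf("drift:        %.6f\n", (double) sumF - sumD);
    return 0;
}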
4 changes: 2 additions & 2 deletions blingfiretools/any_test/any_test.cpp
@@ -80,14 +80,14 @@ int __cdecl main (int argc, char ** argv)

// tests

void* hModel = (*g_LoadModelPtr)("bert_base_tok.bin");
void* hModel = (*g_LoadModelPtr)("xlnet.bin");

const int MaxIdCount = 128;
int Ids [MaxIdCount];
int Starts [MaxIdCount];
int Ends [MaxIdCount];

std::string in1 ("Sergei Alonichau I saw a girl with a telescope.");
std::string in1 ("1 ½ fruit per day .");

int IdCount = (*g_TextToIdsPtr)(hModel, in1.c_str(), in1.length(), Ids, MaxIdCount, 100);
for(int i = 0; i < IdCount; ++i) {
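For readers who want to reproduce this test outside of any_test.cpp, here is a minimal standalone sketch of the same calls. It assumes the shared library exports LoadModel and TextToIds with the signatures implied by the call sites above (any_test resolves them into g_LoadModelPtr and g_TextToIdsPtr); the unknown-token id of 100 simply mirrors the test and is model dependent.

// build: g++ demo.cpp -ldl -o demo   (run next to libblingfiretokdll.so and xlnet.bin)
#include <dlfcn.h>
#include <cstdio>
#include <string>

typedef void* (*LoadModelFn)(const char*);
typedef int   (*TextToIdsFn)(void*, const char*, int, int*, int, int);

int main()
{
    void* hLib = dlopen("./libblingfiretokdll.so", RTLD_NOW);
    if (NULL == hLib) { std::fprintf(stderr, "dlopen failed: %s\n", dlerror()); return 1; }

    LoadModelFn pLoadModel = (LoadModelFn) dlsym(hLib, "LoadModel");
    TextToIdsFn pTextToIds = (TextToIdsFn) dlsym(hLib, "TextToIds");
    if (NULL == pLoadModel || NULL == pTextToIds) { std::fprintf(stderr, "missing exports\n"); return 1; }

    void* hModel = pLoadModel("xlnet.bin");

    const std::string in1("1 \xC2\xBD fruit per day .");   // "1 ½ fruit per day ." in UTF-8
    const int MaxIdCount = 128;
    int Ids [MaxIdCount];

    const int IdCount = pTextToIds(hModel, in1.c_str(), (int) in1.length(), Ids, MaxIdCount, 100);
    for (int i = 0; i < IdCount; ++i) {
        std::printf("%d ", Ids [i]);
    }
    std::printf("\n");

    dlclose(hLib);
    return 0;
}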
26 changes: 16 additions & 10 deletions blingfiretools/blingfiretokdll/blingfiretokdll.cpp
@@ -104,7 +104,7 @@ void InitializeWbdSbd()

// WHITESPACE [\x0004-\x0020\x007F-\x009F\x00A0\x2000-\x200B\x200E\x200F\x202F\x205F\x2060\x2420\x2424\x3000\xFEFF]
#define __FAIsWhiteSpace__(C) ( \
- (C <= 0x20 || (C >= 0x7f && C <= 0x9f) || C == 0xa0 || (C >= 0x2000 && C <= 0x200b) || \
+ (C <= 0x20 || (C >= 0x7f && C <= 0xa0) || (C >= 0x2000 && C <= 0x200c) || \
C == 0x200e || C == 0x200f || C == 0x202f || C == 0x205f || C == 0x2060 || C == 0x2420 || \
C == 0x2424 || C == 0x3000 || C == 0xfeff) \
)
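To see what the widened predicate accepts, here is a small spot-check with the new macro body pasted in; the code points are hand-picked examples, not an exhaustive test. Compared to the old version, 0x7F-0xA0 is now one contiguous range and the general-punctuation range grew from 0x200B to 0x200C (zero width non-joiner).

#include <cassert>

// Whitespace predicate exactly as in the new code above.
#define __FAIsWhiteSpace__(C) ( \
    (C <= 0x20 || (C >= 0x7f && C <= 0xa0) || (C >= 0x2000 && C <= 0x200c) || \
     C == 0x200e || C == 0x200f || C == 0x202f || C == 0x205f || C == 0x2060 || C == 0x2420 || \
     C == 0x2424 || C == 0x3000 || C == 0xfeff) \
    )

int main()
{
    assert(__FAIsWhiteSpace__(0x0020));   // ASCII space
    assert(__FAIsWhiteSpace__(0x00a0));   // NO-BREAK SPACE
    assert(__FAIsWhiteSpace__(0x200c));   // ZERO WIDTH NON-JOINER, newly covered
    assert(!__FAIsWhiteSpace__(0x0041));  // 'A'
    assert(!__FAIsWhiteSpace__(0x2581));  // LOWER ONE EIGHTH BLOCK, the sentencepiece space marker
    return 0;
}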
@@ -1072,30 +1072,36 @@ const int TextToIdsWithOffsets_sp(
std::vector< int > utf32norm_offsets;
int * pNormOffsets = NULL;

- // do normalization if needed
+ // do normalization, if needed
if (NULL != pCharMap) {

- utf32input_norm.resize(InUtf8StrByteCount + 1);
+ const int MaxNormBuffSize = (InUtf8StrByteCount + 1) * 2;
+ utf32input_norm.resize(MaxNormBuffSize);
pNormBuff = utf32input_norm.data();
if (NULL == pNormBuff) {
return 0;
}
if (fNeedOffsets) {
- utf32norm_offsets.resize(InUtf8StrByteCount + 1);
+ utf32norm_offsets.resize(MaxNormBuffSize);
pNormOffsets = utf32norm_offsets.data();
if (NULL == pNormOffsets) {
return 0;
}
}

// do the normalization for the entire input
- BuffSize = fNeedOffsets ?
- ::FANormalize(pBuff, BuffSize, pNormBuff, pNormOffsets, InUtf8StrByteCount + 1, pCharMap) :
- ::FANormalize(pBuff, BuffSize, pNormBuff, InUtf8StrByteCount + 1, pCharMap);
- if (BuffSize <= 0 || BuffSize > InUtf8StrByteCount + 1) {
+ const int ActualNormBuffSize = fNeedOffsets ?
+ ::FANormalize(pBuff, BuffSize, pNormBuff, pNormOffsets, MaxNormBuffSize, pCharMap) :
+ ::FANormalize(pBuff, BuffSize, pNormBuff, MaxNormBuffSize, pCharMap);
+
+ if (ActualNormBuffSize <= 0 || ActualNormBuffSize > MaxNormBuffSize) {
pCharMap = NULL;
+ // don't proceed without normalization, TODO: 99% times it does not change anything... so it is ok to proceed
+ return 0;
+ } else {
+ BuffSize = ActualNormBuffSize;
+ pBuff = pNormBuff;
}
- pBuff = pNormBuff;
}

// Replace every space sequence with U+2581 in-place
@@ -1140,7 +1146,7 @@ const int TextToIdsWithOffsets_sp(
BuffSize = j;

// do the segmentation
- const int WbdResMaxSize = InUtf8StrByteCount * 3;
+ const int WbdResMaxSize = BuffSize * 3;
std::vector< int > WbdResults(WbdResMaxSize);
int * pWbdResults = WbdResults.data ();
const int WbdOutSize = pModelData->m_SegEngine.Process (pBuff, BuffSize, pWbdResults, WbdResMaxSize, 0);
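The (InUtf8StrByteCount + 1) * 2 sizing above exists because a charmap entry can map one code point to several, so the normalized text can be longer than the input; for example, an NFKC-style charmap folds U+00BD (½) into the three code points 1, U+2044, 2. Below is a toy illustration of that contract with a hand-written map and normalize loop — it is not the FANormalize API, just the same sizing idea.

#include <map>
#include <string>
#include <vector>

// Toy one-to-many character map; the generated charmap.utf8 for xlnet can
// likewise expand a single input code point into several output code points.
static const std::map<char32_t, std::u32string> g_toyCharMap = {
    { U'\u00BD', U"1\u20442" },   // ½ -> "1⁄2", one code point becomes three
};

// Writes the mapped text into pOut; returns the number of code points written
// or -1 if the output buffer is too small, which is why callers oversize it.
static int ToyNormalize(const char32_t* pIn, int InSize, char32_t* pOut, int MaxOutSize)
{
    int j = 0;
    for (int i = 0; i < InSize; ++i) {
        const auto it = g_toyCharMap.find(pIn [i]);
        const std::u32string rep = (it != g_toyCharMap.end()) ? it->second : std::u32string(1, pIn [i]);
        for (const char32_t c : rep) {
            if (j >= MaxOutSize) {
                return -1;
            }
            pOut [j++] = c;
        }
    }
    return j;
}

int main()
{
    const std::u32string in = U"1 \u00BD fruit per day .";
    // Same sizing rule as the change above: output may be up to twice the input.
    std::vector< char32_t > out((in.size() + 1) * 2);
    const int n = ToyNormalize(in.data(), (int) in.size(), out.data(), (int) out.size());
    return 0 < n ? 0 : 1;
}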
Binary file modified dist-pypi/blingfire/libblingfiretokdll.so
Binary file added dist-pypi/blingfire/xlnet.bin
Binary file added dist-pypi/blingfire/xlnet_nonorm.bin
Binary file modified ldbsrc/ldb/xlnet.bin
Binary file added ldbsrc/ldb/xlnet_nonorm.bin
32 changes: 31 additions & 1 deletion ldbsrc/xlnet/README.TXT
@@ -1,3 +1,7 @@
# build everything and call set_env from the BlingFire directory to set the PATH (see wiki)

# change directory to the directory with source files
cd <BlingFire>/ldbsrc/xlnet

# export the model:
spm_export_vocab --model spiece.model --output spiece.model.exportvocab.txt --output_format txt
@@ -8,5 +12,31 @@ cat spiece.model.exportvocab.txt | awk 'BEGIN {FS="\t"} NF == 2 { if (NR > 1) {
# zip it:
zip pos.dict.utf8.zip pos.dict.utf8

- # build as usual
# optional step: create a charmap for NFC --> NFKC normalization or anything else
python ./generate_charmap.py > charmap.utf8

# create options.small and ldb.conf.small by example

# build all as usual
cd <BlingFire>/ldbsrc
make -f Makefile.gnu lang=xlnet all

# after the successful compilation there should be a new file xlnet.bin inside the ldb directory
ls -l ldb/xlnet.bin

# let's run the parity verification
cat input.utf8 | python ../scripts/test_bling_with_offsets.py -m ldb/xlnet.bin -p xlnet/spiece.model > output.utf8

# count number of mismatches
> cat output.utf8 | awk '/ERROR:/' | wc -l
2133

# count total number of input texts
> wc -l input.utf8
2052515 input.utf8

Input-level parity is close to 99.9% (2133 mismatching inputs out of 2,052,515, i.e. about 0.1%); token-level parity is much higher than this. Upon inspection,
most of the disparities are due to SentencePiece treating some Unicode control characters as letters and others as spaces,
whereas Bling Fire treats them all as spaces.
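The parity check itself is the test_bling_with_offsets.py script used above; for illustration, the comparison it performs can be sketched in C++ against the public sentencepiece API and the blingfiretokdll exports (LoadModel/TextToIds, with signatures assumed as in any_test). Counting the inputs whose id sequences disagree and dividing by the total is what yields the ~99.9% figure.

// Sketch only: link against libsentencepiece and libblingfiretokdll.
#include <sentencepiece_processor.h>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>

extern "C" {
    // Assumed blingfiretokdll exports, mirroring the any_test call pattern.
    void* LoadModel(const char* pszLdbFileName);
    int TextToIds(void* hModel, const char* pInUtf8Str, int InUtf8StrByteCount,
                  int* pIdsArr, int MaxIdsArrLength, int UnkId);
}

int main()
{
    sentencepiece::SentencePieceProcessor sp;
    if (!sp.Load("xlnet/spiece.model").ok()) return 1;

    void* hModel = LoadModel("ldb/xlnet.bin");
    if (NULL == hModel) return 1;

    int mismatched = 0;
    int total = 0;
    std::string line;

    while (std::getline(std::cin, line)) {
        ++total;

        std::vector< int > spIds;
        sp.Encode(line, &spIds);

        std::vector< int > bfIds(line.size() + 3);           // generous upper bound
        const int n = TextToIds(hModel, line.c_str(), (int) line.size(),
                                bfIds.data(), (int) bfIds.size(), 0);   // 0 assumed for <unk>
        bfIds.resize(0 < n ? n : 0);

        if (bfIds != spIds) {
            ++mismatched;
            std::printf("ERROR: id mismatch on input %d\n", total);
        }
    }

    std::printf("input-level parity: %.4f\n", total ? 1.0 - (double) mismatched / total : 1.0);
    return 0;
}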

