11#pragma once
22
3+ #include " knn-defines.h"
34#include " knn-enumerator.h"
45
56#include < ydb/library/yql/public/udf/udf_helpers.h>
1112using namespace NYql ;
1213using namespace NYql ::NUdf;
1314
14- enum EFormat : ui8 {
15- FloatVector = 1
16- };
17-
18- static constexpr size_t HeaderLen = sizeof (ui8);
19-
20- class TFloatVectorSerializer {
15+ template <typename T, EFormat Format>
16+ class TKnnVectorSerializer {
2117public:
2218 static TUnboxedValue Serialize (const IValueBuilder* valueBuilder, const TUnboxedValue x) {
2319 auto serialize = [&x] (IOutputStream& outStream) {
24- EnumerateVector (x, [&outStream] (float element) { outStream.Write (&element, sizeof (float )); });
25- const EFormat format = EFormat::FloatVector;
20+ EnumerateVector (x, [&outStream] (float floatElement) {
21+ T element = static_cast <T>(floatElement);
22+ outStream.Write (&element, sizeof (T));
23+ });
24+ const EFormat format = Format;
2625 outStream.Write (&format, HeaderLen);
2726 };
2827
2928 if (x.HasFastListLength ()) {
30- auto str = valueBuilder->NewStringNotFilled (HeaderLen + x.GetListLength () * sizeof (float ));
29+ auto str = valueBuilder->NewStringNotFilled (HeaderLen + x.GetListLength () * sizeof (T ));
3130 auto strRef = str.AsStringRef ();
3231 TMemoryOutput memoryOutput (strRef.Data (), strRef.Size ());
3332
@@ -46,45 +45,108 @@ class TFloatVectorSerializer {
4645 const char * buf = str.Data ();
4746 const size_t len = str.Size () - HeaderLen;
4847
49- if (len % sizeof (float ) != 0 )
48+ if (Y_UNLIKELY ( len % sizeof (T ) != 0 ))
5049 return {};
5150
52- const ui32 count = len / sizeof (float );
51+ const ui32 count = len / sizeof (T );
5352
5453 TUnboxedValue* items = nullptr ;
5554 auto res = valueBuilder->NewArray (count, items);
5655
5756 TMemoryInput inStr (buf, len);
5857 for (ui32 i = 0 ; i < count; ++i) {
59- float element;
60- if (inStr.Read (&element, sizeof (float )) != sizeof (float ))
58+ T element;
59+ if (Y_UNLIKELY ( inStr.Read (&element, sizeof (T )) != sizeof (T) ))
6160 return {};
62- *items++ = TUnboxedValuePod{element};
61+ *items++ = TUnboxedValuePod{static_cast < float >( element) };
6362 }
6463
6564 return res.Release ();
6665 }
6766
68- static const TArrayRef<const float > GetArray (const TStringRef& str) {
67+ static const TArrayRef<const T > GetArray (const TStringRef& str) {
6968 const char * buf = str.Data ();
7069 const size_t len = str.Size () - HeaderLen;
7170
72- if (len % sizeof (float ) != 0 )
71+ if (Y_UNLIKELY ( len % sizeof (T ) != 0 ))
7372 return {};
7473
75- const ui32 count = len / sizeof (float );
74+ const ui32 count = len / sizeof (T );
7675
77- return MakeArrayRef (reinterpret_cast <const float *>(buf), count);
76+ return MakeArrayRef (reinterpret_cast <const T *>(buf), count);
7877 }
7978};
8079
80+ // Encode all positive floats as bit 1, negative floats as bit 0.
81+ // So 1024 float vector is serialized in 1024/8=128 bytes.
82+ // Place all bits in ui64. So, only vector sizes divisible by 64 are supported.
83+ class TKnnBitVectorSerializer {
84+ public:
85+ static TUnboxedValue Serialize (const IValueBuilder* valueBuilder, const TUnboxedValue x) {
86+ auto serialize = [&x] (IOutputStream& outStream) {
87+ ui64 accumulator = 0 ;
88+ ui8 filledBits = 0 ;
89+
90+ EnumerateVector (x, [&] (float element) {
91+ if (element > 0 )
92+ accumulator |= 1ll << filledBits;
93+
94+ ++filledBits;
95+ if (filledBits == 64 ) {
96+ outStream.Write (&accumulator, sizeof (ui64));
97+ accumulator = 0 ;
98+ filledBits = 0 ;
99+ }
100+ });
101+
102+ // only vector sizes divisible by 64 are supported
103+ if (Y_UNLIKELY (filledBits))
104+ return false ;
105+
106+ const EFormat format = EFormat::BitVector;
107+ outStream.Write (&format, HeaderLen);
108+
109+ return true ;
110+ };
111+
112+ if (x.HasFastListLength ()) {
113+ auto str = valueBuilder->NewStringNotFilled (HeaderLen + x.GetListLength () / 8 );
114+ auto strRef = str.AsStringRef ();
115+ TMemoryOutput memoryOutput (strRef.Data (), strRef.Size ());
81116
82- class TSerializerFacade {
117+ if (Y_UNLIKELY (!serialize (memoryOutput)))
118+ return {};
119+
120+ return str;
121+ } else {
122+ TString str;
123+ TStringOutput stringOutput (str);
124+
125+ if (Y_UNLIKELY (!serialize (stringOutput)))
126+ return {};
127+
128+ return valueBuilder->NewString (str);
129+ }
130+ }
131+
132+ static const TArrayRef<const ui64> GetArray64 (const TStringRef& str) {
133+ const char * buf = str.Data ();
134+ const size_t len = (str.Size () - HeaderLen) / sizeof (ui64);
135+
136+ return MakeArrayRef (reinterpret_cast <const ui64*>(buf), len);
137+ }
138+ };
139+
140+ class TKnnSerializerFacade {
83141public:
84142 static TUnboxedValue Serialize (EFormat format, const IValueBuilder* valueBuilder, const TUnboxedValue x) {
85143 switch (format) {
86144 case EFormat::FloatVector:
87- return TFloatVectorSerializer::Serialize (valueBuilder, x);
145+ return TKnnVectorSerializer<float , EFormat::FloatVector>::Serialize (valueBuilder, x);
146+ case EFormat::ByteVector:
147+ return TKnnVectorSerializer<ui8, EFormat::ByteVector>::Serialize (valueBuilder, x);
148+ case EFormat::BitVector:
149+ return TKnnBitVectorSerializer::Serialize (valueBuilder, x);
88150 default :
89151 return {};
90152 }
@@ -97,20 +159,29 @@ class TSerializerFacade {
97159 const ui8 format = str.Data ()[str.Size () - HeaderLen];
98160 switch (format) {
99161 case EFormat::FloatVector:
100- return TFloatVectorSerializer::Deserialize (valueBuilder, str);
162+ return TKnnVectorSerializer<float , EFormat::FloatVector>::Deserialize (valueBuilder, str);
163+ case EFormat::ByteVector:
164+ return TKnnVectorSerializer<ui8, EFormat::ByteVector>::Deserialize (valueBuilder, str);
165+ case EFormat::BitVector:
166+ return {};
101167 default :
102168 return {};
103169 }
104170 }
105171
106- static const TArrayRef<const float > GetArray (const TStringRef& str) {
107- if (str.Size () == 0 )
172+ template <typename T>
173+ static const TArrayRef<const T> GetArray (const TStringRef& str) {
174+ if (Y_UNLIKELY (str.Size () == 0 ))
108175 return {};
109176
110177 const ui8 format = str.Data ()[str.Size () - HeaderLen];
111178 switch (format) {
112179 case EFormat::FloatVector:
113- return TFloatVectorSerializer::GetArray (str);
180+ return TKnnVectorSerializer<T, EFormat::FloatVector>::GetArray (str);
181+ case EFormat::ByteVector:
182+ return TKnnVectorSerializer<T, EFormat::ByteVector>::GetArray (str);
183+ case EFormat::BitVector:
184+ return {};
114185 default :
115186 return {};
116187 }
0 commit comments