|
9 | 9 | namespace NKikimr::NArrow {
|
10 | 10 |
|
11 | 11 | namespace {
|
12 |
| -template <class TDataContainer, class TStringImpl> |
| 12 | + |
| 13 | +template <class T> |
| 14 | +class TColumnNameAccessor { |
| 15 | +public: |
| 16 | + static const std::string& GetFieldName(const T& val) { |
| 17 | + return val; |
| 18 | + } |
| 19 | + static TString DebugString(const std::vector<T>& items) { |
| 20 | + return JoinSeq(",", items); |
| 21 | + } |
| 22 | +}; |
| 23 | + |
| 24 | +template <> |
| 25 | +class TColumnNameAccessor<std::shared_ptr<arrow::Field>> { |
| 26 | +public: |
| 27 | + static const std::string& GetFieldName(const std::shared_ptr<arrow::Field>& val) { |
| 28 | + return val->name(); |
| 29 | + } |
| 30 | + static TString DebugString(const std::vector<std::shared_ptr<arrow::Field>>& items) { |
| 31 | + TStringBuilder sb; |
| 32 | + for (auto&& i : items) { |
| 33 | + sb << i->name() << ","; |
| 34 | + } |
| 35 | + return sb; |
| 36 | + } |
| 37 | +}; |
| 38 | + |
| 39 | +template <class TDataContainer, class TStringContainer> |
13 | 40 | std::shared_ptr<TDataContainer> ExtractColumnsValidateImpl(
|
14 |
| - const std::shared_ptr<TDataContainer>& srcBatch, const std::vector<TStringImpl>& columnNames) { |
| 41 | + const std::shared_ptr<TDataContainer>& srcBatch, const std::vector<TStringContainer>& columnNames) { |
15 | 42 | std::vector<std::shared_ptr<arrow::Field>> fields;
|
16 | 43 | fields.reserve(columnNames.size());
|
17 | 44 | std::vector<std::shared_ptr<typename NAdapter::TDataBuilderPolicy<TDataContainer>::TColumn>> columns;
|
18 | 45 | columns.reserve(columnNames.size());
|
19 | 46 |
|
20 | 47 | auto srcSchema = srcBatch->schema();
|
21 | 48 | for (auto& name : columnNames) {
|
22 |
| - const int pos = srcSchema->GetFieldIndex(name); |
| 49 | + const int pos = srcSchema->GetFieldIndex(TColumnNameAccessor<TStringContainer>::GetFieldName(name)); |
23 | 50 | if (Y_LIKELY(pos > -1)) {
|
24 | 51 | fields.push_back(srcSchema->field(pos));
|
25 | 52 | columns.push_back(srcBatch->column(pos));
|
@@ -70,16 +97,16 @@ TConclusion<std::shared_ptr<TDataContainer>> AdaptColumnsImpl(
|
70 | 97 | return NAdapter::TDataBuilderPolicy<TDataContainer>::Build(std::make_shared<arrow::Schema>(fields), std::move(columns), srcBatch->num_rows());
|
71 | 98 | }
|
72 | 99 |
|
73 |
| -template <class TDataContainer, class TStringType> |
| 100 | +template <class TDataContainer, class TStringContainer> |
74 | 101 | std::shared_ptr<TDataContainer> ExtractImpl(const TColumnOperator::EExtractProblemsPolicy& policy,
|
75 |
| - const std::shared_ptr<TDataContainer>& incoming, const std::vector<TStringType>& columnNames) { |
| 102 | + const std::shared_ptr<TDataContainer>& incoming, const std::vector<TStringContainer>& columnNames) { |
76 | 103 | AFL_VERIFY(incoming);
|
77 | 104 | AFL_VERIFY(columnNames.size());
|
78 | 105 | auto result = ExtractColumnsValidateImpl(incoming, columnNames);
|
79 | 106 | switch (policy) {
|
80 | 107 | case TColumnOperator::EExtractProblemsPolicy::Verify:
|
81 | 108 | AFL_VERIFY((ui32)result->num_columns() == columnNames.size())("schema", incoming->schema()->ToString())(
|
82 |
| - "required", JoinSeq(",", columnNames)); |
| 109 | + "required", TColumnNameAccessor<TStringContainer>::DebugString(columnNames)); |
83 | 110 | break;
|
84 | 111 | case TColumnOperator::EExtractProblemsPolicy::Null:
|
85 | 112 | if ((ui32)result->num_columns() != columnNames.size()) {
|
@@ -123,6 +150,16 @@ std::shared_ptr<arrow::Table> TColumnOperator::Extract(
|
123 | 150 | return ExtractImpl(AbsentColumnPolicy, incoming, columnNames);
|
124 | 151 | }
|
125 | 152 |
|
| 153 | +std::shared_ptr<arrow::Table> TColumnOperator::Extract( |
| 154 | + const std::shared_ptr<arrow::Table>& incoming, const std::vector<std::shared_ptr<arrow::Field>>& columns) { |
| 155 | + return ExtractImpl(AbsentColumnPolicy, incoming, columns); |
| 156 | +} |
| 157 | + |
| 158 | +std::shared_ptr<arrow::RecordBatch> TColumnOperator::Extract( |
| 159 | + const std::shared_ptr<arrow::RecordBatch>& incoming, const std::vector<std::shared_ptr<arrow::Field>>& columns) { |
| 160 | + return ExtractImpl(AbsentColumnPolicy, incoming, columns); |
| 161 | +} |
| 162 | + |
126 | 163 | std::shared_ptr<arrow::RecordBatch> TColumnOperator::Extract(
|
127 | 164 | const std::shared_ptr<arrow::RecordBatch>& incoming, const std::vector<TString>& columnNames) {
|
128 | 165 | return ExtractImpl(AbsentColumnPolicy, incoming, columnNames);
|
@@ -171,5 +208,47 @@ NKikimr::TConclusion<std::shared_ptr<arrow::Table>> TColumnOperator::Reorder(
|
171 | 208 | const std::shared_ptr<arrow::Table>& incoming, const std::vector<TString>& columnNames) {
|
172 | 209 | return ReorderImpl(incoming, columnNames);
|
173 | 210 | }
|
| 211 | +namespace { |
| 212 | +template <class TDataContainer, class TSchemaImpl> |
| 213 | +TConclusion<TSchemaSubset> BuildSequentialSubsetImpl( |
| 214 | + const std::shared_ptr<TDataContainer>& srcBatch, const std::shared_ptr<TSchemaImpl>& dstSchema) { |
| 215 | + AFL_VERIFY(srcBatch); |
| 216 | + AFL_VERIFY(dstSchema); |
| 217 | + if (dstSchema->num_fields() < srcBatch->schema()->num_fields()) { |
| 218 | + AFL_ERROR(NKikimrServices::ARROW_HELPER)("event", "incorrect columns set: destination must been wider than source")( |
| 219 | + "source", srcBatch->schema()->ToString())("destination", dstSchema->ToString()); |
| 220 | + return TConclusionStatus::Fail("incorrect columns set: destination must been wider than source"); |
| 221 | + } |
| 222 | + std::set<ui32> fieldIdx; |
| 223 | + auto itSrc = srcBatch->schema()->fields().begin(); |
| 224 | + auto itDst = dstSchema->fields().begin(); |
| 225 | + while (itSrc != srcBatch->schema()->fields().end() && itDst != dstSchema->fields().end()) { |
| 226 | + if ((*itSrc)->name() != (*itDst)->name()) { |
| 227 | + ++itDst; |
| 228 | + } else { |
| 229 | + fieldIdx.emplace(itDst - dstSchema->fields().begin()); |
| 230 | + if (!(*itDst)->Equals(*itSrc)) { |
| 231 | + AFL_ERROR(NKikimrServices::ARROW_HELPER)("event", "cannot_use_incoming_batch")("reason", "invalid_column_type")( |
| 232 | + "column_type", (*itDst)->ToString(true))("incoming_type", (*itSrc)->ToString(true)); |
| 233 | + return TConclusionStatus::Fail("incompatible column types"); |
| 234 | + } |
| 235 | + |
| 236 | + ++itDst; |
| 237 | + ++itSrc; |
| 238 | + } |
| 239 | + } |
| 240 | + if (itDst == dstSchema->fields().end() && itSrc != srcBatch->schema()->fields().end()) { |
| 241 | + AFL_ERROR(NKikimrServices::ARROW_HELPER)("event", "incorrect columns order in source set")("source", srcBatch->schema()->ToString())( |
| 242 | + "destination", dstSchema->ToString()); |
| 243 | + return TConclusionStatus::Fail("incorrect columns order in source set"); |
| 244 | + } |
| 245 | + return TSchemaSubset(fieldIdx, dstSchema->num_fields()); |
| 246 | +} |
| 247 | +} // namespace |
| 248 | + |
| 249 | +TConclusion<TSchemaSubset> TColumnOperator::BuildSequentialSubset( |
| 250 | + const std::shared_ptr<arrow::RecordBatch>& incoming, const std::shared_ptr<NArrow::TSchemaLite>& dstSchema) { |
| 251 | + return BuildSequentialSubsetImpl(incoming, dstSchema); |
| 252 | +} |
174 | 253 |
|
175 | 254 | } // namespace NKikimr::NArrow
|
0 commit comments