Skip to content

Commit

Permalink
fix: better detecting co-group keys.
Browse files Browse the repository at this point in the history
  • Loading branch information
ashigeru committed Aug 28, 2024
1 parent 0634ca3 commit e1d8e9a
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 15 deletions.
17 changes: 9 additions & 8 deletions src/yugawara/analyzer/details/collect_join_keys.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include <takatori/relation/intermediate/join.h>

#include <takatori/util/assertion.h>
#include <takatori/util/downcast.h>
#include <takatori/util/sequence_view.h>

Expand Down Expand Up @@ -184,13 +183,15 @@ class engine {

// key pairs
for (auto&& key : build_key) {
if (auto term = builder.find(key)) {
if (auto opposite = term->equivalent_key();
opposite
&& std::find(probe_key.begin(), probe_key.end(), *opposite) != probe_key.end()) {
results.emplace_back(std::addressof(key), term.get());
} else if (term->equivalent()) {
temp_term_buf_.emplace_back(std::addressof(key), term.get());
auto [begin, end] = builder.search(key);
for (auto iter = begin; iter != end; ++iter) {
auto&& term = iter->second;
if (auto opposite = term.equivalent_key();
opposite &&
std::find(probe_key.begin(), probe_key.end(), *opposite) != probe_key.end()) {
results.emplace_back(std::addressof(key), std::addressof(term));
} else if (term.equivalent()) {
temp_term_buf_.emplace_back(std::addressof(key), std::addressof(term));
}
}
}
Expand Down
19 changes: 14 additions & 5 deletions src/yugawara/analyzer/details/search_key_term_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace {

using variable_set = search_key_term_builder::variable_set;
using factor_info = search_key_term_builder::factor_info;
using term_map = search_key_term_builder::term_map;
using factor_info_map = search_key_term_builder::factor_info_map;

class key_variable_detector {
Expand Down Expand Up @@ -68,9 +69,9 @@ class factor_collector {
public:
explicit factor_collector(
variable_set const& keys,
factor_info_map& factors) noexcept
: keys_(keys)
, factor_info_map_(factors)
factor_info_map& factors) noexcept :
keys_ { keys },
factor_info_map_ { factors }
{}

void operator()(expression_ref&& condition) {
Expand Down Expand Up @@ -158,11 +159,19 @@ optional_ptr<search_key_term> search_key_term_builder::find(descriptor::variable
build_terms();
}
if (auto it = terms_.find(key); it != terms_.end()) {
return it.value();
return it->second;
}
return {};
}

std::pair<term_map::iterator, term_map::iterator> search_key_term_builder::search(descriptor::variable const& key) {
if (!factors_.empty()) {
build_terms();
}
auto [begin, end] = terms_.equal_range(key);
return std::make_pair(begin, end);
}

void search_key_term_builder::clear() {
keys_.clear();
factors_.clear();
Expand Down Expand Up @@ -204,7 +213,7 @@ void search_key_term_builder::build_term(factor_info_map::iterator begin, factor
std::move(info.term),
std::move(info.factor),
});
return;
break;

case kind::less:
case kind::less_equal:
Expand Down
9 changes: 8 additions & 1 deletion src/yugawara/analyzer/details/search_key_term_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class search_key_term_builder {
::takatori::util::ownership_reference<::takatori::scalar::expression> factor;
};

using term_map = ::tsl::hopscotch_map<
using term_map = std::unordered_multimap<
::takatori::descriptor::variable,
search_key_term,
std::hash<::takatori::descriptor::variable>,
Expand All @@ -73,6 +73,13 @@ class search_key_term_builder {
std::hash<::takatori::descriptor::variable>,
std::equal_to<>>;

/**
* @brief returns the search term for the given key variable.
* @param key the target key variable
* @return the corresponded search term
*/
[[nodiscard]] std::pair<term_map::iterator, term_map::iterator> search(::takatori::descriptor::variable const& key);

private:
variable_set keys_;
factor_info_map factors_;
Expand Down
5 changes: 4 additions & 1 deletion src/yugawara/analyzer/intermediate_plan_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ void intermediate_plan_optimizer::operator()(::takatori::relation::graph_type& g
flow_volume,
options_.runtime_features().contains(runtime_feature::index_join_scan));
}
details::collect_join_keys(graph, flow_volume, compute_join_keys_features(options_.runtime_features()));
details::collect_join_keys(
graph,
flow_volume,
compute_join_keys_features(options_.runtime_features()));
details::rewrite_scan(graph, options_.index_estimator());
details::remove_redundant_conditions(graph);
}
Expand Down
45 changes: 45 additions & 0 deletions test/yugawara/analyzer/details/collect_join_keys_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,4 +468,49 @@ TEST_F(collect_join_keys_test, left_outer) {
EXPECT_EQ(join.condition(), compare(constant(0), varref(cl1), cmp::less_equal));
}

TEST_F(collect_join_keys_test, cogroup_suppress_broadcast) {
relation::graph_type r;
auto cl0 = bindings.stream_variable("cl0");
auto&& inl = r.insert(relation::scan {
bindings(*i0),
{
{ bindings(t0c0), cl0 },
},
});
auto cr0 = bindings.stream_variable("cr0");
auto&& inr = r.insert(relation::scan {
bindings(*i1),
{
{ bindings(t1c0), cr0 },
},
});
auto&& join = r.insert(relation::intermediate::join {
relation::join_kind::left_outer,
land(
compare(cl0, cr0),
compare(varref(cr0), constant(0), cmp::equal)),
});

auto&& out = r.insert(relation::emit { cl0, cr0 });
inl.output() >> join.left();
inr.output() >> join.right();
join.output() >> out.input();

apply(r, { collect_join_keys_feature::cogroup });
EXPECT_GT(inl.output(), join.left());
EXPECT_GT(inr.output(), join.right());

EXPECT_EQ(join.lower().kind(), endpoint_kind::prefixed_inclusive);
ASSERT_EQ(join.lower().keys().size(), 1);
EXPECT_EQ(join.lower().keys()[0].variable(), cr0);
EXPECT_EQ(join.lower().keys()[0].value(), varref(cl0));

EXPECT_EQ(join.upper().kind(), endpoint_kind::prefixed_inclusive);
ASSERT_EQ(join.upper().keys().size(), 1);
EXPECT_EQ(join.upper().keys()[0].variable(), cr0);
EXPECT_EQ(join.upper().keys()[0].value(), varref(cl0));

EXPECT_EQ(join.condition(), compare(varref(cr0), constant(0), cmp::equal));
}

} // namespace yugawara::analyzer::details

0 comments on commit e1d8e9a

Please sign in to comment.