This is an automated email from the ASF dual-hosted git repository. kszucs pushed a commit to branch maint-1.0.x in repository https://gitbox.apache.org/repos/asf/arrow.git
commit f92558898ebd0e8dd4dbce6fbd64dcf6db8c340d Author: Benjamin Kietzman <[email protected]> AuthorDate: Mon Aug 10 10:41:37 2020 -0700 ARROW-9606: [C++][Dataset] Support `"a"_.In(<>).Assume(<compound>)` This enables predicate pushdown of `%in%` filters in the presence of compound partition information @mpjdem Closes #7911 from bkietz/9606-simplify-isin-query-nested-partitions Authored-by: Benjamin Kietzman <[email protected]> Signed-off-by: Neal Richardson <[email protected]> --- cpp/src/arrow/dataset/filter.cc | 73 ++++++++++++++++++++++++++++-------- cpp/src/arrow/dataset/filter.h | 10 +++++ cpp/src/arrow/dataset/filter_test.cc | 13 ++++++- cpp/src/arrow/dataset/partition.cc | 39 +++++++++---------- r/tests/testthat/test-dataset.R | 10 +++++ 5 files changed, 106 insertions(+), 39 deletions(-) diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index f9a36ee..c4120eb 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -260,22 +260,36 @@ std::shared_ptr<Expression> Invert(const Expression& expr) { } std::shared_ptr<Expression> Expression::Assume(const Expression& given) const { - if (given.type() == ExpressionType::COMPARISON) { + std::shared_ptr<Expression> out; + + DCHECK_OK(VisitConjunctionMembers(given, [&](const Expression& given) { + if (out != nullptr) { + return Status::OK(); + } + + if (given.type() != ExpressionType::COMPARISON) { + return Status::OK(); + } + const auto& given_cmp = checked_cast<const ComparisonExpression&>(given); - if (given_cmp.op() == CompareOperator::EQUAL) { - if (this->Equals(given_cmp.left_operand()) && - given_cmp.right_operand()->type() == ExpressionType::SCALAR) { - return given_cmp.right_operand(); - } + if (given_cmp.op() != CompareOperator::EQUAL) { + return Status::OK(); + } - if (this->Equals(given_cmp.right_operand()) && - given_cmp.left_operand()->type() == ExpressionType::SCALAR) { - return given_cmp.left_operand(); - } + if (this->Equals(given_cmp.left_operand())) { + out = given_cmp.right_operand(); + return Status::OK(); } - } - return Copy(); + if (this->Equals(given_cmp.right_operand())) { + out = given_cmp.left_operand(); + return Status::OK(); + } + + return Status::OK(); + })); + + return out ? out : Copy(); } std::shared_ptr<Expression> ComparisonExpression::Assume(const Expression& given) const { @@ -570,15 +584,30 @@ std::shared_ptr<Expression> InExpression::Assume(const Expression& given) const return scalar(set_->null_count() > 0); } - const auto& value = checked_cast<const ScalarExpression&>(*operand).value(); + Datum set, value; + if (set_->type_id() == Type::DICTIONARY) { + const auto& dict_set = checked_cast<const DictionaryArray&>(*set_); + auto maybe_decoded = compute::Take(dict_set.dictionary(), dict_set.indices()); + auto maybe_value = checked_cast<const DictionaryScalar&>( + *checked_cast<const ScalarExpression&>(*operand).value()) + .GetEncodedValue(); + if (!maybe_decoded.ok() || !maybe_value.ok()) { + return std::make_shared<InExpression>(std::move(operand), set_); + } + set = *maybe_decoded; + value = *maybe_value; + } else { + set = set_; + value = checked_cast<const ScalarExpression&>(*operand).value(); + } compute::CompareOptions eq(CompareOperator::EQUAL); - Result<Datum> out_result = compute::Compare(set_, value, eq); - if (!out_result.ok()) { + Result<Datum> maybe_out = compute::Compare(set, value, eq); + if (!maybe_out.ok()) { return std::make_shared<InExpression>(std::move(operand), set_); } - Datum out = out_result.ValueOrDie(); + Datum out = maybe_out.ValueOrDie(); DCHECK(out.is_array()); DCHECK_EQ(out.type()->id(), Type::BOOL); @@ -1045,6 +1074,18 @@ Result<std::shared_ptr<Expression>> InsertImplicitCasts(const Expression& expr, return VisitExpression(expr, InsertImplicitCastsImpl{schema}); } +Status VisitConjunctionMembers(const Expression& expr, + const std::function<Status(const Expression&)>& visitor) { + if (expr.type() == ExpressionType::AND) { + const auto& and_ = checked_cast<const AndExpression&>(expr); + RETURN_NOT_OK(VisitConjunctionMembers(*and_.left_operand(), visitor)); + RETURN_NOT_OK(VisitConjunctionMembers(*and_.right_operand(), visitor)); + return Status::OK(); + } + + return visitor(expr); +} + std::vector<std::string> FieldsInExpression(const Expression& expr) { struct { void operator()(const FieldExpression& expr) { fields.push_back(expr.name()); } diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h index b7d4655..70279e1 100644 --- a/cpp/src/arrow/dataset/filter.h +++ b/cpp/src/arrow/dataset/filter.h @@ -575,6 +575,16 @@ auto VisitExpression(const Expression& expr, Visitor&& visitor) return visitor(internal::checked_cast<const CustomExpression&>(expr)); } +/// \brief Visit each subexpression of an arbitrarily nested conjunction. +/// +/// | given | visit | +/// |--------------------------------|---------------------------------------------| +/// | a and b | visit(a), visit(b) | +/// | c | visit(c) | +/// | (a and b) and ((c or d) and e) | visit(a), visit(b), visit(c or d), visit(e) | +ARROW_DS_EXPORT Status VisitConjunctionMembers( + const Expression& expr, const std::function<Status(const Expression&)>& visitor); + /// \brief Insert CastExpressions where necessary to make a valid expression. ARROW_DS_EXPORT Result<std::shared_ptr<Expression>> InsertImplicitCasts( const Expression& expr, const Schema& schema); diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 7c3c1a2..f044c41 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -72,7 +72,8 @@ class ExpressionsTest : public ::testing::Test { std::shared_ptr<DataType> ns = timestamp(TimeUnit::NANO); std::shared_ptr<Schema> schema_ = schema({field("a", int32()), field("b", int32()), field("f", float64()), - field("s", utf8()), field("ts", ns)}); + field("s", utf8()), field("ts", ns), + field("dict_b", dictionary(int32(), int32()))}); std::shared_ptr<Expression> always = scalar(true); std::shared_ptr<Expression> never = scalar(false); }; @@ -131,6 +132,16 @@ TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) { AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5); AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never); AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10); + + auto set_123 = ArrayFromJSON(int32(), R"([1, 2, 3])"); + AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 3, *always); + AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 5, *never); + + auto dict_set_123 = + DictArrayFromJSON(dictionary(int32(), int32()), R"([1,2,0])", R"([1,2,3])"); + ASSERT_OK_AND_ASSIGN(auto b_dict, dict_set_123->GetScalar(0)); + AssertSimplifiesTo("b_dict"_.In(dict_set_123), "a"_ == 3 and "b_dict"_ == b_dict, + *always); } TEST_F(ExpressionsTest, SimplificationToNull) { diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 9f497a1..b4a38d2 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -96,32 +96,27 @@ Status KeyValuePartitioning::VisitKeys( const Expression& expr, const std::function<Status(const std::string& name, const std::shared_ptr<Scalar>& value)>& visitor) { - if (expr.type() == ExpressionType::AND) { - const auto& and_ = checked_cast<const AndExpression&>(expr); - RETURN_NOT_OK(VisitKeys(*and_.left_operand(), visitor)); - RETURN_NOT_OK(VisitKeys(*and_.right_operand(), visitor)); - return Status::OK(); - } - - if (expr.type() != ExpressionType::COMPARISON) { - return Status::OK(); - } + return VisitConjunctionMembers(expr, [visitor](const Expression& expr) { + if (expr.type() != ExpressionType::COMPARISON) { + return Status::OK(); + } - const auto& cmp = checked_cast<const ComparisonExpression&>(expr); - if (cmp.op() != compute::CompareOperator::EQUAL) { - return Status::OK(); - } + const auto& cmp = checked_cast<const ComparisonExpression&>(expr); + if (cmp.op() != compute::CompareOperator::EQUAL) { + return Status::OK(); + } - auto lhs = cmp.left_operand().get(); - auto rhs = cmp.right_operand().get(); - if (lhs->type() != ExpressionType::FIELD) std::swap(lhs, rhs); + auto lhs = cmp.left_operand().get(); + auto rhs = cmp.right_operand().get(); + if (lhs->type() != ExpressionType::FIELD) std::swap(lhs, rhs); - if (lhs->type() != ExpressionType::FIELD || rhs->type() != ExpressionType::SCALAR) { - return Status::OK(); - } + if (lhs->type() != ExpressionType::FIELD || rhs->type() != ExpressionType::SCALAR) { + return Status::OK(); + } - return visitor(checked_cast<const FieldExpression*>(lhs)->name(), - checked_cast<const ScalarExpression*>(rhs)->value()); + return visitor(checked_cast<const FieldExpression*>(lhs)->name(), + checked_cast<const ScalarExpression*>(rhs)->value()); + }); } Result<std::unordered_map<std::string, std::shared_ptr<Scalar>>> diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index e78dfd3..fe6da87 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -391,6 +391,16 @@ test_that("filter() with %in%", { collect(), tibble(int = df1$int[c(3, 4, 6)], part = 1) ) + +# ARROW-9606: bug in %in% filter on partition column with >1 partition columns + ds <- open_dataset(hive_dir) + expect_equivalent( + ds %>% + filter(group %in% 2) %>% + select(names(df2)) %>% + collect(), + df2 + ) }) test_that("filter() on timestamp columns", {
