This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new e980ef8 ARROW-10817: [Rust] [DataFusion] Implement TypedString and DATE coercion
e980ef8 is described below
commit e980ef843922d8a2a07f0150b4a4ca54b23f280a
Author: Mike Seddon <[email protected]>
AuthorDate: Sun Dec 13 07:34:46 2020 -0500
ARROW-10817: [Rust] [DataFusion] Implement TypedString and DATE coercion
This PR adds support for what the `sqlparser` crate calls `TypedString`,
which is essentially syntactic sugar for an inline cast. As this was an effort to
get the `TPC-H` queries behaving correctly, I then went a step further and added
support for `Date` (temporal) coercion. I can split this PR if needed.
```sql
where
l_shipdate <= date '1998-09-02'
```
is equivalent to
```sql
where
l_shipdate <= CAST('1998-09-02' AS DATE)
```
FYI, I am planning to tackle `INTERVAL` next.
Closes #8892 from seddonm1/typed_string
Authored-by: Mike Seddon <[email protected]>
Signed-off-by: Andrew Lamb <[email protected]>
---
rust/arrow/src/compute/kernels/cast.rs | 46 +++++++++++++++
rust/benchmarks/src/bin/tpch.rs | 75 +++++++++++++++++-------
rust/datafusion/src/physical_plan/expressions.rs | 72 ++++++++++++++++++++++-
rust/datafusion/src/sql/planner.rs | 16 +++++
4 files changed, 185 insertions(+), 24 deletions(-)
diff --git a/rust/arrow/src/compute/kernels/cast.rs
b/rust/arrow/src/compute/kernels/cast.rs
index 7b0c6bc..70acf5a 100644
--- a/rust/arrow/src/compute/kernels/cast.rs
+++ b/rust/arrow/src/compute/kernels/cast.rs
@@ -72,6 +72,7 @@ pub fn can_cast_types(from_type: &DataType, to_type:
&DataType) -> bool {
(Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,
(Utf8, Date32(DateUnit::Day)) => true,
+ (Utf8, Date64(DateUnit::Millisecond)) => true,
(Utf8, _) => DataType::is_numeric(to_type),
(_, Utf8) => DataType::is_numeric(from_type) || from_type == &Binary,
@@ -399,6 +400,26 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) ->
Result<ArrayRef> {
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}
+ Date64(DateUnit::Millisecond) => {
+ use chrono::{NaiveDate, NaiveTime};
+ let zero_time = NaiveTime::from_hms(0, 0, 0);
+ let string_array =
array.as_any().downcast_ref::<StringArray>().unwrap();
+ let mut builder =
PrimitiveBuilder::<Date64Type>::new(string_array.len());
+ for i in 0..string_array.len() {
+ if string_array.is_null(i) {
+ builder.append_null()?;
+ } else {
+ match NaiveDate::parse_from_str(string_array.value(i),
"%Y-%m-%d")
+ {
+ Ok(date) => builder.append_value(
+ date.and_time(zero_time).timestamp_millis() as
i64,
+ )?,
+ Err(_) => builder.append_null()?, // not a valid
date
+ };
+ }
+ }
+ Ok(Arc::new(builder.finish()) as ArrayRef)
+ }
_ => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
@@ -2781,6 +2802,31 @@ mod tests {
}
#[test]
+ fn test_cast_utf8_to_date64() {
+ let a = StringArray::from(vec![
+ "2000-01-01", // valid date with leading 0s
+ "2000-2-2", // valid date without leading 0s
+ "2000-00-00", // invalid month and day
+ "2000-01-01T12:00:00", // date + time is invalid
+ "2000", // just a year is invalid
+ ]);
+ let array = Arc::new(a) as ArrayRef;
+ let b = cast(&array,
&DataType::Date64(DateUnit::Millisecond)).unwrap();
+ let c = b.as_any().downcast_ref::<Date64Array>().unwrap();
+
+ // test valid inputs
+ assert_eq!(true, c.is_valid(0)); // "2000-01-01"
+ assert_eq!(946684800000, c.value(0));
+ assert_eq!(true, c.is_valid(1)); // "2000-2-2"
+ assert_eq!(949449600000, c.value(1));
+
+ // test invalid inputs
+ assert_eq!(false, c.is_valid(2)); // "2000-00-00"
+ assert_eq!(false, c.is_valid(3)); // "2000-01-01T12:00:00"
+ assert_eq!(false, c.is_valid(4)); // "2000"
+ }
+
+ #[test]
fn test_can_cast_types() {
// this function attempts to ensure that can_cast_types stays
// in sync with cast. It simply tries all combinations of
diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs
index 2ed9ab0..cd3d9d8 100644
--- a/rust/benchmarks/src/bin/tpch.rs
+++ b/rust/benchmarks/src/bin/tpch.rs
@@ -21,7 +21,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
-use arrow::datatypes::{DataType, Field, Schema};
+use arrow::datatypes::{DataType, DateUnit, Field, Schema};
use arrow::util::pretty;
use datafusion::datasource::parquet::ParquetTable;
use datafusion::datasource::{CsvFile, MemTable, TableProvider};
@@ -187,7 +187,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query:
usize) -> Result<Logic
from
lineitem
where
- l_shipdate <= '1998-09-02'
+ l_shipdate <= date '1998-09-02'
group by
l_returnflag,
l_linestatus
@@ -256,8 +256,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query:
usize) -> Result<Logic
c_mktsegment = 'BUILDING'
and c_custkey = o_custkey
and l_orderkey = o_orderkey
- and o_orderdate < '1995-03-15'
- and l_shipdate > '1995-03-15'
+ and o_orderdate < date '1995-03-15'
+ and l_shipdate > date '1995-03-15'
group by
l_orderkey,
o_orderdate,
@@ -337,8 +337,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query:
usize) -> Result<Logic
and s_nationkey = n_nationkey
and n_regionkey = r_regionkey
and r_name = 'ASIA'
- and o_orderdate >= '1994-01-01'
- and o_orderdate < '1995-01-01'
+ and o_orderdate >= date '1994-01-01'
+ and o_orderdate < date '1995-01-01'
group by
n_name
order by
@@ -363,9 +363,9 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query:
usize) -> Result<Logic
from
lineitem
where
- l_shipdate >= '1994-01-01'
- and l_shipdate < '1995-01-01'
- and l_discount between 0.06 - 0.01 and 0.06 + 0.01
+ l_shipdate >= date '1994-01-01'
+ and l_shipdate < date '1995-01-01'
+ and l_discount > 0.06 - 0.01 and l_discount < 0.06 + 0.01
and l_quantity < 24;"
),
@@ -399,7 +399,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query:
usize) -> Result<Logic
(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')
)
- and l_shipdate > '1995-01-01' and l_shipdate <
'1996-12-31'
+ and l_shipdate > date '1995-01-01' and l_shipdate <
date '1996-12-31'
) as shipping
group by
supp_nation,
@@ -442,7 +442,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query:
usize) -> Result<Logic
and n1.n_regionkey = r_regionkey
and r_name = 'AMERICA'
and s_nationkey = n2.n_nationkey
- and o_orderdate between '1995-01-01' and '1996-12-31'
+ and o_orderdate between date '1995-01-01' and date
'1996-12-31'
and p_type = 'ECONOMY ANODIZED STEEL'
) as all_nations
group by
@@ -486,6 +486,39 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query:
usize) -> Result<Logic
o_year desc;"
),
+ // 10 => ctx.create_logical_plan(
+ // "select
+ // c_custkey,
+ // c_name,
+ // sum(l_extendedprice * (1 - l_discount)) as revenue,
+ // c_acctbal,
+ // n_name,
+ // c_address,
+ // c_phone,
+ // c_comment
+ // from
+ // customer,
+ // orders,
+ // lineitem,
+ // nation
+ // where
+ // c_custkey = o_custkey
+ // and l_orderkey = o_orderkey
+ // and o_orderdate >= date '1993-10-01'
+ // and o_orderdate < date '1993-10-01' + interval '3' month
+ // and l_returnflag = 'R'
+ // and c_nationkey = n_nationkey
+ // group by
+ // c_custkey,
+ // c_name,
+ // c_acctbal,
+ // c_phone,
+ // n_name,
+ // c_address,
+ // c_comment
+ // order by
+ // revenue desc;"
+ // ),
10 => ctx.create_logical_plan(
"select
c_custkey,
@@ -504,8 +537,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query:
usize) -> Result<Logic
where
c_custkey = o_custkey
and l_orderkey = o_orderkey
- and o_orderdate >= '1993-10-01'
- and o_orderdate < '1994-01-01'
+ and o_orderdate >= date '1993-10-01'
+ and o_orderdate < date '1994-01-01'
and l_returnflag = 'R'
and c_nationkey = n_nationkey
group by
@@ -606,8 +639,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query:
usize) -> Result<Logic
(l_shipmode = 'MAIL' or l_shipmode = 'SHIP')
and l_commitdate < l_receiptdate
and l_shipdate < l_commitdate
- and l_receiptdate >= '1994-01-01'
- and l_receiptdate < '1995-01-01'
+ and l_receiptdate >= date '1994-01-01'
+ and l_receiptdate < date '1995-01-01'
group by
l_shipmode
order by
@@ -649,8 +682,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query:
usize) -> Result<Logic
part
where
l_partkey = p_partkey
- and l_shipdate >= '1995-09-01'
- and l_shipdate < '1995-10-01';"
+ and l_shipdate >= date '1995-09-01'
+ and l_shipdate < date '1995-10-01';"
),
15 => ctx.create_logical_plan(
@@ -1072,7 +1105,7 @@ fn get_schema(table: &str) -> Schema {
Field::new("o_custkey", DataType::UInt32, false),
Field::new("o_orderstatus", DataType::Utf8, false),
Field::new("o_totalprice", DataType::Float64, false), // decimal
- Field::new("o_orderdate", DataType::Utf8, false),
+ Field::new("o_orderdate", DataType::Date32(DateUnit::Day), false),
Field::new("o_orderpriority", DataType::Utf8, false),
Field::new("o_clerk", DataType::Utf8, false),
Field::new("o_shippriority", DataType::UInt32, false),
@@ -1090,9 +1123,9 @@ fn get_schema(table: &str) -> Schema {
Field::new("l_tax", DataType::Float64, false), // decimal
Field::new("l_returnflag", DataType::Utf8, false),
Field::new("l_linestatus", DataType::Utf8, false),
- Field::new("l_shipdate", DataType::Utf8, false),
- Field::new("l_commitdate", DataType::Utf8, false),
- Field::new("l_receiptdate", DataType::Utf8, false),
+ Field::new("l_shipdate", DataType::Date32(DateUnit::Day), false),
+ Field::new("l_commitdate", DataType::Date32(DateUnit::Day), false),
+ Field::new("l_receiptdate", DataType::Date32(DateUnit::Day),
false),
Field::new("l_shipinstruct", DataType::Utf8, false),
Field::new("l_shipmode", DataType::Utf8, false),
Field::new("l_comment", DataType::Utf8, false),
diff --git a/rust/datafusion/src/physical_plan/expressions.rs
b/rust/datafusion/src/physical_plan/expressions.rs
index 79045d9..ffac95b 100644
--- a/rust/datafusion/src/physical_plan/expressions.rs
+++ b/rust/datafusion/src/physical_plan/expressions.rs
@@ -48,9 +48,9 @@ use arrow::datatypes::{DataType, DateUnit, Schema, TimeUnit};
use arrow::record_batch::RecordBatch;
use arrow::{
array::{
- ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array,
Int16Array,
- Int32Array, Int64Array, Int8Array, StringArray,
TimestampNanosecondArray,
- UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+ ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array,
Float64Array,
+ Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
+ TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array,
UInt8Array,
},
datatypes::Field,
};
@@ -1135,6 +1135,9 @@ macro_rules! binary_array_op {
DataType::Date32(DateUnit::Day) => {
compute_op!($LEFT, $RIGHT, $OP, Date32Array)
}
+ DataType::Date64(DateUnit::Millisecond) => {
+ compute_op!($LEFT, $RIGHT, $OP, Date64Array)
+ }
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?}",
other
@@ -1227,6 +1230,19 @@ fn string_coercion(lhs_type: &DataType, rhs_type:
&DataType) -> Option<DataType>
}
}
+/// Coercion rules for Temporal columns: the type that both lhs and rhs can be
+/// casted to for the purpose of a date computation
+fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) ->
Option<DataType> {
+ use arrow::datatypes::DataType::*;
+ match (lhs_type, rhs_type) {
+ (Utf8, Date32(DateUnit::Day)) => Some(Date32(DateUnit::Day)),
+ (Date32(DateUnit::Day), Utf8) => Some(Date32(DateUnit::Day)),
+ (Utf8, Date64(DateUnit::Millisecond)) =>
Some(Date64(DateUnit::Millisecond)),
+ (Date64(DateUnit::Millisecond), Utf8) =>
Some(Date64(DateUnit::Millisecond)),
+ _ => None,
+ }
+}
+
/// Coercion rule for numerical types: The type that both lhs and rhs
/// can be casted to for numerical calculation, while maintaining
/// maximum precision
@@ -1288,6 +1304,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType)
-> Option<DataType> {
}
numerical_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
+ .or_else(|| temporal_coercion(lhs_type, rhs_type))
}
// coercion rules that assume an ordered set, such as "less than".
@@ -1301,6 +1318,7 @@ fn order_coercion(lhs_type: &DataType, rhs_type:
&DataType) -> Option<DataType>
numerical_coercion(lhs_type, rhs_type)
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
+ .or_else(|| temporal_coercion(lhs_type, rhs_type))
}
/// Coercion rules for all binary operators. Returns the output type
@@ -2638,6 +2656,54 @@ mod tests {
DataType::Boolean,
vec![true, false]
);
+ test_coercion!(
+ StringArray,
+ DataType::Utf8,
+ vec!["1994-12-13", "1995-01-26"],
+ Date32Array,
+ DataType::Date32(DateUnit::Day),
+ vec![9112, 9156],
+ Operator::Eq,
+ BooleanArray,
+ DataType::Boolean,
+ vec![true, true]
+ );
+ test_coercion!(
+ StringArray,
+ DataType::Utf8,
+ vec!["1994-12-13", "1995-01-26"],
+ Date32Array,
+ DataType::Date32(DateUnit::Day),
+ vec![9113, 9154],
+ Operator::Lt,
+ BooleanArray,
+ DataType::Boolean,
+ vec![true, false]
+ );
+ test_coercion!(
+ StringArray,
+ DataType::Utf8,
+ vec!["1994-12-13", "1995-01-26"],
+ Date64Array,
+ DataType::Date64(DateUnit::Millisecond),
+ vec![787276800000, 791078400000],
+ Operator::Eq,
+ BooleanArray,
+ DataType::Boolean,
+ vec![true, true]
+ );
+ test_coercion!(
+ StringArray,
+ DataType::Utf8,
+ vec!["1994-12-13", "1995-01-26"],
+ Date64Array,
+ DataType::Date64(DateUnit::Millisecond),
+ vec![787276800001, 791078399999],
+ Operator::Lt,
+ BooleanArray,
+ DataType::Boolean,
+ vec![true, false]
+ );
Ok(())
}
diff --git a/rust/datafusion/src/sql/planner.rs
b/rust/datafusion/src/sql/planner.rs
index 562e580..eb9a1a5 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -629,6 +629,14 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
data_type: convert_data_type(data_type)?,
}),
+ SQLExpr::TypedString {
+ ref data_type,
+ ref value,
+ } => Ok(Expr::Cast {
+ expr: Box::new(lit(&**value)),
+ data_type: convert_data_type(data_type)?,
+ }),
+
SQLExpr::IsNull(ref expr) => {
Ok(Expr::IsNull(Box::new(self.sql_to_rex(expr, schema)?)))
}
@@ -1311,6 +1319,14 @@ mod tests {
quick_test(sql, expected);
}
+ #[test]
+ fn select_typedstring() {
+ let sql = "SELECT date '2020-12-10' AS date FROM person";
+ let expected = "Projection: CAST(Utf8(\"2020-12-10\") AS Date32(Day))
AS date\
+ \n TableScan: person projection=None";
+ quick_test(sql, expected);
+ }
+
fn logical_plan(sql: &str) -> Result<LogicalPlan> {
let planner = SqlToRel::new(&MockSchemaProvider {});
let result = DFParser::parse_sql(&sql);