This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 002ca5d Lead/lag window function with offset and default value arguments (#687)
002ca5d is described below
commit 002ca5d1e81dd45247b90f0f6e340ff5fec3a747
Author: Javier Goday <[email protected]>
AuthorDate: Wed Jul 14 22:22:17 2021 +0200
Lead/lag window function with offset and default value arguments (#687)
---
.../src/physical_plan/expressions/lead_lag.rs | 94 +++++++++++++++++++++-
datafusion/src/physical_plan/type_coercion.rs | 35 ++++++--
datafusion/src/physical_plan/window_functions.rs | 14 +++-
datafusion/src/physical_plan/windows.rs | 67 ++++++++++++++-
.../sqls/simple_window_lead_built_in_functions.sql | 27 +++++++
integration-tests/test_psql_parity.py | 2 +-
6 files changed, 221 insertions(+), 18 deletions(-)
diff --git a/datafusion/src/physical_plan/expressions/lead_lag.rs b/datafusion/src/physical_plan/expressions/lead_lag.rs
index 352d97c..d1f6c19 100644
--- a/datafusion/src/physical_plan/expressions/lead_lag.rs
+++ b/datafusion/src/physical_plan/expressions/lead_lag.rs
@@ -21,11 +21,13 @@
use crate::error::{DataFusionError, Result};
use crate::physical_plan::window_functions::PartitionEvaluator;
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr,
PhysicalExpr};
+use crate::scalar::ScalarValue;
use arrow::array::ArrayRef;
-use arrow::compute::kernels::window::shift;
+use arrow::compute::cast;
use arrow::datatypes::{DataType, Field};
use arrow::record_batch::RecordBatch;
use std::any::Any;
+use std::ops::Neg;
use std::ops::Range;
use std::sync::Arc;
@@ -36,6 +38,7 @@ pub struct WindowShift {
data_type: DataType,
shift_offset: i64,
expr: Arc<dyn PhysicalExpr>,
+ default_value: Option<ScalarValue>,
}
/// lead() window function
@@ -43,12 +46,15 @@ pub fn lead(
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
+ shift_offset: Option<i64>,
+ default_value: Option<ScalarValue>,
) -> WindowShift {
WindowShift {
name,
data_type,
- shift_offset: -1,
+ shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1),
expr,
+ default_value,
}
}
@@ -57,12 +63,15 @@ pub fn lag(
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
+ shift_offset: Option<i64>,
+ default_value: Option<ScalarValue>,
) -> WindowShift {
WindowShift {
name,
data_type,
- shift_offset: 1,
+ shift_offset: shift_offset.unwrap_or(1),
expr,
+ default_value,
}
}
@@ -98,6 +107,7 @@ impl BuiltInWindowFunctionExpr for WindowShift {
Ok(Box::new(WindowShiftEvaluator {
shift_offset: self.shift_offset,
values,
+ default_value: self.default_value.clone(),
}))
}
}
@@ -105,13 +115,63 @@ impl BuiltInWindowFunctionExpr for WindowShift {
pub(crate) struct WindowShiftEvaluator {
shift_offset: i64,
values: Vec<ArrayRef>,
+ default_value: Option<ScalarValue>,
+}
+
+fn create_empty_array(
+ value: &Option<ScalarValue>,
+ data_type: &DataType,
+ size: usize,
+) -> Result<ArrayRef> {
+ use arrow::array::new_null_array;
+ let array = value
+ .as_ref()
+ .map(|scalar| scalar.to_array_of_size(size))
+ .unwrap_or_else(|| new_null_array(data_type, size));
+ if array.data_type() != data_type {
+ cast(&array, data_type).map_err(DataFusionError::ArrowError)
+ } else {
+ Ok(array)
+ }
+}
+
+// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
+fn shift_with_default_value(
+ array: &ArrayRef,
+ offset: i64,
+ value: &Option<ScalarValue>,
+) -> Result<ArrayRef> {
+ use arrow::compute::concat;
+
+ let value_len = array.len() as i64;
+ if offset == 0 {
+ Ok(arrow::array::make_array(array.data_ref().clone()))
+ } else if offset == i64::MIN || offset.abs() >= value_len {
+ create_empty_array(value, array.data_type(), array.len())
+ } else {
+ let slice_offset = (-offset).clamp(0, value_len) as usize;
+ let length = array.len() - offset.abs() as usize;
+ let slice = array.slice(slice_offset, length);
+
+ // Generate array with remaining `null` items
+ let nulls = offset.abs() as usize;
+ let default_values = create_empty_array(value, slice.data_type(),
nulls)?;
+ // Concatenate both arrays, add nulls after if shift > 0 else before
+ if offset > 0 {
+ concat(&[default_values.as_ref(), slice.as_ref()])
+ .map_err(DataFusionError::ArrowError)
+ } else {
+ concat(&[slice.as_ref(), default_values.as_ref()])
+ .map_err(DataFusionError::ArrowError)
+ }
+ }
}
impl PartitionEvaluator for WindowShiftEvaluator {
fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
let value = &self.values[0];
let value = value.slice(partition.start, partition.end -
partition.start);
- shift(value.as_ref(),
self.shift_offset).map_err(DataFusionError::ArrowError)
+ shift_with_default_value(&value, self.shift_offset,
&self.default_value)
}
}
@@ -142,6 +202,8 @@ mod tests {
"lead".to_owned(),
DataType::Float32,
Arc::new(Column::new("c3", 0)),
+ None,
+ None,
),
vec![
Some(-2),
@@ -162,6 +224,8 @@ mod tests {
"lead".to_owned(),
DataType::Float32,
Arc::new(Column::new("c3", 0)),
+ None,
+ None,
),
vec![
None,
@@ -176,6 +240,28 @@ mod tests {
.iter()
.collect::<Int32Array>(),
)?;
+
+ test_i32_result(
+ lag(
+ "lead".to_owned(),
+ DataType::Int32,
+ Arc::new(Column::new("c3", 0)),
+ None,
+ Some(ScalarValue::Int32(Some(100))),
+ ),
+ vec![
+ Some(100),
+ Some(1),
+ Some(-2),
+ Some(3),
+ Some(-4),
+ Some(5),
+ Some(-6),
+ Some(7),
+ ]
+ .iter()
+ .collect::<Int32Array>(),
+ )?;
Ok(())
}
}
diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs
index ffd8f20..c8387bb 100644
--- a/datafusion/src/physical_plan/type_coercion.rs
+++ b/datafusion/src/physical_plan/type_coercion.rs
@@ -128,13 +128,11 @@ fn get_valid_types(
}
vec![(0..*number).map(|i| current_types[i].clone()).collect()]
}
- Signature::OneOf(types) => {
- let mut r = vec![];
- for s in types {
- r.extend(get_valid_types(s, current_types)?);
- }
- r
- }
+ Signature::OneOf(types) => types
+ .iter()
+ .filter_map(|t| get_valid_types(t, current_types).ok())
+ .flatten()
+ .collect::<Vec<_>>(),
};
Ok(valid_types)
@@ -367,4 +365,27 @@ mod tests {
Ok(())
}
+
+ #[test]
+ fn test_get_valid_types_one_of() -> Result<()> {
+ let signature = Signature::OneOf(vec![Signature::Any(1),
Signature::Any(2)]);
+
+ let invalid_types = get_valid_types(
+ &signature,
+ &[DataType::Int32, DataType::Int32, DataType::Int32],
+ )?;
+ assert_eq!(invalid_types.len(), 0);
+
+ let args = vec![DataType::Int32, DataType::Int32];
+ let valid_types = get_valid_types(&signature, &args)?;
+ assert_eq!(valid_types.len(), 1);
+ assert_eq!(valid_types[0], args);
+
+ let args = vec![DataType::Int32];
+ let valid_types = get_valid_types(&signature, &args)?;
+ assert_eq!(valid_types.len(), 1);
+ assert_eq!(valid_types[0], args);
+
+ Ok(())
+ }
}
diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs
index 99805b6..e2b4606 100644
--- a/datafusion/src/physical_plan/window_functions.rs
+++ b/datafusion/src/physical_plan/window_functions.rs
@@ -201,10 +201,16 @@ pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
| BuiltInWindowFunction::DenseRank
| BuiltInWindowFunction::PercentRank
| BuiltInWindowFunction::CumeDist => Signature::Any(0),
- BuiltInWindowFunction::Lag
- | BuiltInWindowFunction::Lead
- | BuiltInWindowFunction::FirstValue
- | BuiltInWindowFunction::LastValue => Signature::Any(1),
+ BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
+ Signature::OneOf(vec![
+ Signature::Any(1),
+ Signature::Any(2),
+ Signature::Any(3),
+ ])
+ }
+ BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue
=> {
+ Signature::Any(1)
+ }
BuiltInWindowFunction::Ntile =>
Signature::Exact(vec![DataType::UInt64]),
BuiltInWindowFunction::NthValue => Signature::Any(2),
}
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index 1b78378..a1f4b7a 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -32,6 +32,7 @@ use crate::physical_plan::{
Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning,
PhysicalExpr,
RecordBatchStream, SendableRecordBatchStream, WindowExpr,
};
+use crate::scalar::ScalarValue;
use arrow::compute::concat;
use arrow::{
array::ArrayRef,
@@ -96,6 +97,19 @@ pub fn create_window_expr(
})
}
+fn get_scalar_value_from_args(
+ args: &[Arc<dyn PhysicalExpr>],
+ index: usize,
+) -> Option<ScalarValue> {
+ args.get(index).map(|v| {
+ v.as_any()
+ .downcast_ref::<Literal>()
+ .unwrap()
+ .value()
+ .clone()
+ })
+}
+
fn create_built_in_window_expr(
fun: &BuiltInWindowFunction,
args: &[Arc<dyn PhysicalExpr>],
@@ -110,13 +124,21 @@ fn create_built_in_window_expr(
let coerced_args = coerce(args, input_schema,
&signature_for_built_in(fun))?;
let arg = coerced_args[0].clone();
let data_type = args[0].data_type(input_schema)?;
- Arc::new(lag(name, data_type, arg))
+ let shift_offset = get_scalar_value_from_args(&coerced_args, 1)
+ .map(|v| v.try_into())
+ .and_then(|v| v.ok());
+ let default_value = get_scalar_value_from_args(&coerced_args, 2);
+ Arc::new(lag(name, data_type, arg, shift_offset, default_value))
}
BuiltInWindowFunction::Lead => {
let coerced_args = coerce(args, input_schema,
&signature_for_built_in(fun))?;
let arg = coerced_args[0].clone();
let data_type = args[0].data_type(input_schema)?;
- Arc::new(lead(name, data_type, arg))
+ let shift_offset = get_scalar_value_from_args(&coerced_args, 1)
+ .map(|v| v.try_into())
+ .and_then(|v| v.ok());
+ let default_value = get_scalar_value_from_args(&coerced_args, 2);
+ Arc::new(lead(name, data_type, arg, shift_offset, default_value))
}
BuiltInWindowFunction::NthValue => {
let coerced_args = coerce(args, input_schema,
&signature_for_built_in(fun))?;
@@ -592,6 +614,47 @@ mod tests {
Ok((input, schema))
}
+ #[test]
+ fn test_create_window_exp_lead_no_args() -> Result<()> {
+ let (_, schema) = create_test_schema(1)?;
+
+ let expr = create_window_expr(
+
&WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
+ "prev".to_owned(),
+ &[col("c2", &schema)?],
+ &[],
+ &[],
+ Some(WindowFrame::default()),
+ schema.as_ref(),
+ )?;
+
+ assert_eq!(expr.name(), "prev");
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_create_window_exp_lead_with_args() -> Result<()> {
+ let (_, schema) = create_test_schema(1)?;
+
+ let expr = create_window_expr(
+
&WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
+ "prev".to_owned(),
+ &[
+ col("c2", &schema)?,
+ Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
+ ],
+ &[],
+ &[],
+ Some(WindowFrame::default()),
+ schema.as_ref(),
+ )?;
+
+ assert_eq!(expr.name(), "prev");
+
+ Ok(())
+ }
+
#[tokio::test]
async fn window_function() -> Result<()> {
let (input, schema) = create_test_schema(1)?;
diff --git a/integration-tests/sqls/simple_window_lead_built_in_functions.sql b/integration-tests/sqls/simple_window_lead_built_in_functions.sql
new file mode 100644
index 0000000..67df05b
--- /dev/null
+++ b/integration-tests/sqls/simple_window_lead_built_in_functions.sql
@@ -0,0 +1,27 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+
+-- http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+ c8,
+ LEAD(c8) OVER () next_c8,
+ LEAD(c8, 10, 10) OVER() next_10_c8,
+ LEAD(c8, 100, 10) OVER() next_out_of_bounds_c8,
+ LAG(c8) OVER() prev_c8,
+ LAG(c8, -2, 0) OVER() AS prev_2_c8,
+ LAG(c8, -200, 10) OVER() AS prev_out_of_bounds_c8
+
+FROM test
+ORDER BY c8;
diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py
index a160d3e..a85a2c2 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -77,7 +77,7 @@ test_files = set(root.glob("*.sql"))
class TestPsqlParity:
def test_tests_count(self):
- assert len(test_files) == 14, "tests are missed"
+ assert len(test_files) == 15, "tests are missed"
@pytest.mark.parametrize("fname", test_files)
def test_sql_file(self, fname):