This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 002ca5d  Lead/lag window function with offset and default value arguments (#687)
002ca5d is described below

commit 002ca5d1e81dd45247b90f0f6e340ff5fec3a747
Author: Javier Goday <[email protected]>
AuthorDate: Wed Jul 14 22:22:17 2021 +0200

    Lead/lag window function with offset and default value arguments (#687)
---
 .../src/physical_plan/expressions/lead_lag.rs      | 94 +++++++++++++++++++++-
 datafusion/src/physical_plan/type_coercion.rs      | 35 ++++++--
 datafusion/src/physical_plan/window_functions.rs   | 14 +++-
 datafusion/src/physical_plan/windows.rs            | 67 ++++++++++++++-
 .../sqls/simple_window_lead_built_in_functions.sql | 27 +++++++
 integration-tests/test_psql_parity.py              |  2 +-
 6 files changed, 221 insertions(+), 18 deletions(-)

diff --git a/datafusion/src/physical_plan/expressions/lead_lag.rs b/datafusion/src/physical_plan/expressions/lead_lag.rs
index 352d97c..d1f6c19 100644
--- a/datafusion/src/physical_plan/expressions/lead_lag.rs
+++ b/datafusion/src/physical_plan/expressions/lead_lag.rs
@@ -21,11 +21,13 @@
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::window_functions::PartitionEvaluator;
 use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
+use crate::scalar::ScalarValue;
 use arrow::array::ArrayRef;
-use arrow::compute::kernels::window::shift;
+use arrow::compute::cast;
 use arrow::datatypes::{DataType, Field};
 use arrow::record_batch::RecordBatch;
 use std::any::Any;
+use std::ops::Neg;
 use std::ops::Range;
 use std::sync::Arc;
 
@@ -36,6 +38,7 @@ pub struct WindowShift {
     data_type: DataType,
     shift_offset: i64,
     expr: Arc<dyn PhysicalExpr>,
+    default_value: Option<ScalarValue>,
 }
 
 /// lead() window function
@@ -43,12 +46,15 @@ pub fn lead(
     name: String,
     data_type: DataType,
     expr: Arc<dyn PhysicalExpr>,
+    shift_offset: Option<i64>,
+    default_value: Option<ScalarValue>,
 ) -> WindowShift {
     WindowShift {
         name,
         data_type,
-        shift_offset: -1,
+        shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1),
         expr,
+        default_value,
     }
 }
 
@@ -57,12 +63,15 @@ pub fn lag(
     name: String,
     data_type: DataType,
     expr: Arc<dyn PhysicalExpr>,
+    shift_offset: Option<i64>,
+    default_value: Option<ScalarValue>,
 ) -> WindowShift {
     WindowShift {
         name,
         data_type,
-        shift_offset: 1,
+        shift_offset: shift_offset.unwrap_or(1),
         expr,
+        default_value,
     }
 }
 
@@ -98,6 +107,7 @@ impl BuiltInWindowFunctionExpr for WindowShift {
         Ok(Box::new(WindowShiftEvaluator {
             shift_offset: self.shift_offset,
             values,
+            default_value: self.default_value.clone(),
         }))
     }
 }
@@ -105,13 +115,63 @@ impl BuiltInWindowFunctionExpr for WindowShift {
 pub(crate) struct WindowShiftEvaluator {
     shift_offset: i64,
     values: Vec<ArrayRef>,
+    default_value: Option<ScalarValue>,
+}
+
+fn create_empty_array(
+    value: &Option<ScalarValue>,
+    data_type: &DataType,
+    size: usize,
+) -> Result<ArrayRef> {
+    use arrow::array::new_null_array;
+    let array = value
+        .as_ref()
+        .map(|scalar| scalar.to_array_of_size(size))
+        .unwrap_or_else(|| new_null_array(data_type, size));
+    if array.data_type() != data_type {
+        cast(&array, data_type).map_err(DataFusionError::ArrowError)
+    } else {
+        Ok(array)
+    }
+}
+
+// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
+fn shift_with_default_value(
+    array: &ArrayRef,
+    offset: i64,
+    value: &Option<ScalarValue>,
+) -> Result<ArrayRef> {
+    use arrow::compute::concat;
+
+    let value_len = array.len() as i64;
+    if offset == 0 {
+        Ok(arrow::array::make_array(array.data_ref().clone()))
+    } else if offset == i64::MIN || offset.abs() >= value_len {
+        create_empty_array(value, array.data_type(), array.len())
+    } else {
+        let slice_offset = (-offset).clamp(0, value_len) as usize;
+        let length = array.len() - offset.abs() as usize;
+        let slice = array.slice(slice_offset, length);
+
+        // Generate array with remaining `null` items
+        let nulls = offset.abs() as usize;
+        let default_values = create_empty_array(value, slice.data_type(), nulls)?;
+        // Concatenate both arrays, add nulls after if shift > 0 else before
+        if offset > 0 {
+            concat(&[default_values.as_ref(), slice.as_ref()])
+                .map_err(DataFusionError::ArrowError)
+        } else {
+            concat(&[slice.as_ref(), default_values.as_ref()])
+                .map_err(DataFusionError::ArrowError)
+        }
+    }
 }
 
 impl PartitionEvaluator for WindowShiftEvaluator {
     fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
         let value = &self.values[0];
         let value = value.slice(partition.start, partition.end - partition.start);
-        shift(value.as_ref(), self.shift_offset).map_err(DataFusionError::ArrowError)
+        shift_with_default_value(&value, self.shift_offset, &self.default_value)
     }
 }
 
@@ -142,6 +202,8 @@ mod tests {
                 "lead".to_owned(),
                 DataType::Float32,
                 Arc::new(Column::new("c3", 0)),
+                None,
+                None,
             ),
             vec![
                 Some(-2),
@@ -162,6 +224,8 @@ mod tests {
                 "lead".to_owned(),
                 DataType::Float32,
                 Arc::new(Column::new("c3", 0)),
+                None,
+                None,
             ),
             vec![
                 None,
@@ -176,6 +240,28 @@ mod tests {
             .iter()
             .collect::<Int32Array>(),
         )?;
+
+        test_i32_result(
+            lag(
+                "lead".to_owned(),
+                DataType::Int32,
+                Arc::new(Column::new("c3", 0)),
+                None,
+                Some(ScalarValue::Int32(Some(100))),
+            ),
+            vec![
+                Some(100),
+                Some(1),
+                Some(-2),
+                Some(3),
+                Some(-4),
+                Some(5),
+                Some(-6),
+                Some(7),
+            ]
+            .iter()
+            .collect::<Int32Array>(),
+        )?;
         Ok(())
     }
 }
diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs
index ffd8f20..c8387bb 100644
--- a/datafusion/src/physical_plan/type_coercion.rs
+++ b/datafusion/src/physical_plan/type_coercion.rs
@@ -128,13 +128,11 @@ fn get_valid_types(
             }
             vec![(0..*number).map(|i| current_types[i].clone()).collect()]
         }
-        Signature::OneOf(types) => {
-            let mut r = vec![];
-            for s in types {
-                r.extend(get_valid_types(s, current_types)?);
-            }
-            r
-        }
+        Signature::OneOf(types) => types
+            .iter()
+            .filter_map(|t| get_valid_types(t, current_types).ok())
+            .flatten()
+            .collect::<Vec<_>>(),
     };
 
     Ok(valid_types)
@@ -367,4 +365,27 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_get_valid_types_one_of() -> Result<()> {
+        let signature = Signature::OneOf(vec![Signature::Any(1), Signature::Any(2)]);
+
+        let invalid_types = get_valid_types(
+            &signature,
+            &[DataType::Int32, DataType::Int32, DataType::Int32],
+        )?;
+        assert_eq!(invalid_types.len(), 0);
+
+        let args = vec![DataType::Int32, DataType::Int32];
+        let valid_types = get_valid_types(&signature, &args)?;
+        assert_eq!(valid_types.len(), 1);
+        assert_eq!(valid_types[0], args);
+
+        let args = vec![DataType::Int32];
+        let valid_types = get_valid_types(&signature, &args)?;
+        assert_eq!(valid_types.len(), 1);
+        assert_eq!(valid_types[0], args);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs
index 99805b6..e2b4606 100644
--- a/datafusion/src/physical_plan/window_functions.rs
+++ b/datafusion/src/physical_plan/window_functions.rs
@@ -201,10 +201,16 @@ pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
         | BuiltInWindowFunction::DenseRank
         | BuiltInWindowFunction::PercentRank
         | BuiltInWindowFunction::CumeDist => Signature::Any(0),
-        BuiltInWindowFunction::Lag
-        | BuiltInWindowFunction::Lead
-        | BuiltInWindowFunction::FirstValue
-        | BuiltInWindowFunction::LastValue => Signature::Any(1),
+        BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
+            Signature::OneOf(vec![
+                Signature::Any(1),
+                Signature::Any(2),
+                Signature::Any(3),
+            ])
+        }
+        BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => {
+            Signature::Any(1)
+        }
         BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]),
         BuiltInWindowFunction::NthValue => Signature::Any(2),
     }
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index 1b78378..a1f4b7a 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -32,6 +32,7 @@ use crate::physical_plan::{
     Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
     RecordBatchStream, SendableRecordBatchStream, WindowExpr,
 };
+use crate::scalar::ScalarValue;
 use arrow::compute::concat;
 use arrow::{
     array::ArrayRef,
@@ -96,6 +97,19 @@ pub fn create_window_expr(
     })
 }
 
+fn get_scalar_value_from_args(
+    args: &[Arc<dyn PhysicalExpr>],
+    index: usize,
+) -> Option<ScalarValue> {
+    args.get(index).map(|v| {
+        v.as_any()
+            .downcast_ref::<Literal>()
+            .unwrap()
+            .value()
+            .clone()
+    })
+}
+
 fn create_built_in_window_expr(
     fun: &BuiltInWindowFunction,
     args: &[Arc<dyn PhysicalExpr>],
@@ -110,13 +124,21 @@ fn create_built_in_window_expr(
             let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
             let arg = coerced_args[0].clone();
             let data_type = args[0].data_type(input_schema)?;
-            Arc::new(lag(name, data_type, arg))
+            let shift_offset = get_scalar_value_from_args(&coerced_args, 1)
+                .map(|v| v.try_into())
+                .and_then(|v| v.ok());
+            let default_value = get_scalar_value_from_args(&coerced_args, 2);
+            Arc::new(lag(name, data_type, arg, shift_offset, default_value))
         }
         BuiltInWindowFunction::Lead => {
             let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
             let arg = coerced_args[0].clone();
             let data_type = args[0].data_type(input_schema)?;
-            Arc::new(lead(name, data_type, arg))
+            let shift_offset = get_scalar_value_from_args(&coerced_args, 1)
+                .map(|v| v.try_into())
+                .and_then(|v| v.ok());
+            let default_value = get_scalar_value_from_args(&coerced_args, 2);
+            Arc::new(lead(name, data_type, arg, shift_offset, default_value))
         }
         BuiltInWindowFunction::NthValue => {
             let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
@@ -592,6 +614,47 @@ mod tests {
         Ok((input, schema))
     }
 
+    #[test]
+    fn test_create_window_exp_lead_no_args() -> Result<()> {
+        let (_, schema) = create_test_schema(1)?;
+
+        let expr = create_window_expr(
+            &WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
+            "prev".to_owned(),
+            &[col("c2", &schema)?],
+            &[],
+            &[],
+            Some(WindowFrame::default()),
+            schema.as_ref(),
+        )?;
+
+        assert_eq!(expr.name(), "prev");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_create_window_exp_lead_with_args() -> Result<()> {
+        let (_, schema) = create_test_schema(1)?;
+
+        let expr = create_window_expr(
+            &WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
+            "prev".to_owned(),
+            &[
+                col("c2", &schema)?,
+                Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
+            ],
+            &[],
+            &[],
+            Some(WindowFrame::default()),
+            schema.as_ref(),
+        )?;
+
+        assert_eq!(expr.name(), "prev");
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn window_function() -> Result<()> {
         let (input, schema) = create_test_schema(1)?;
diff --git a/integration-tests/sqls/simple_window_lead_built_in_functions.sql b/integration-tests/sqls/simple_window_lead_built_in_functions.sql
new file mode 100644
index 0000000..67df05b
--- /dev/null
+++ b/integration-tests/sqls/simple_window_lead_built_in_functions.sql
@@ -0,0 +1,27 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+
+-- http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+  c8,
+  LEAD(c8) OVER () next_c8,
+  LEAD(c8, 10, 10) OVER() next_10_c8,
+  LEAD(c8, 100, 10) OVER() next_out_of_bounds_c8,
+  LAG(c8) OVER() prev_c8,
+  LAG(c8, -2, 0) OVER() AS prev_2_c8,
+  LAG(c8, -200, 10) OVER() AS prev_out_of_bounds_c8
+
+FROM test
+ORDER BY c8;
diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py
index a160d3e..a85a2c2 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -77,7 +77,7 @@ test_files = set(root.glob("*.sql"))
 
 class TestPsqlParity:
     def test_tests_count(self):
-        assert len(test_files) == 14, "tests are missed"
+        assert len(test_files) == 15, "tests are missed"
 
     @pytest.mark.parametrize("fname", test_files)
     def test_sql_file(self, fname):

Reply via email to