Jefffrey commented on code in PR #20039:
URL: https://github.com/apache/datafusion/pull/20039#discussion_r2744928835
##########
datafusion/physical-plan/src/async_func.rs:
##########
@@ -59,10 +61,14 @@ impl AsyncFuncExec {
async_exprs: Vec<Arc<AsyncFuncExpr>>,
input: Arc<dyn ExecutionPlan>,
) -> Result<Self> {
- let async_fields = async_exprs
- .iter()
- .map(|async_expr| async_expr.field(input.schema().as_ref()))
- .collect::<Result<Vec<_>>>()?;
+ let mut current_fields = input.schema().fields().to_vec();
+ let mut async_fields = Vec::with_capacity(async_exprs.len());
+ for async_expr in &async_exprs {
+ let current_schema = Schema::new(current_fields.clone());
+ let field = async_expr.field(&current_schema)?;
+ current_fields.push(Arc::new(field.clone()));
+ async_fields.push(field);
+ }
// compute the output schema: input schema then async expressions
let fields: Fields = input
Review Comment:
I think we can remove `fields` in favour of `current_fields`
##########
datafusion/physical-plan/src/async_func.rs:
##########
@@ -315,58 +363,49 @@ impl AsyncMapper {
}
/// Finds any references to async functions in the expression and adds them to the map
- pub fn find_references(
+ /// AND rewrites the expression to use the mapped columns.
+ pub fn find_and_map(
&mut self,
physical_expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
- ) -> Result<()> {
- // recursively look for references to async functions
- physical_expr.apply(|expr| {
+ ) -> Result<Arc<dyn PhysicalExpr>> {
+ let transformed = Arc::clone(physical_expr).transform_up(|expr| {
if let Some(scalar_func_expr) =
expr.as_any().downcast_ref::<ScalarFunctionExpr>()
&& scalar_func_expr.fun().as_async().is_some()
{
let next_name = self.next_column_name();
- self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
- next_name,
- Arc::clone(expr),
- schema,
- )?));
+
+ // Construct extended schema including previously mapped async fields
+ let mut current_fields = schema.fields().to_vec();
+ current_fields.extend(
+ self.output_fields
+ .iter()
+ .map(|f: &Field| Arc::new(f.clone())),
Review Comment:
Since we have to `Arc` each field every time here, we might be better off changing
`output_fields` to store `FieldRef`s (aka `Arc<Field>`s) to simplify things
##########
datafusion/core/src/physical_planner.rs:
##########
@@ -2686,47 +2686,41 @@ impl DefaultPhysicalPlanner {
schema: &Schema,
) -> Result<PlanAsyncExpr> {
let mut async_map = AsyncMapper::new(num_input_columns);
- match &physical_expr {
+ let new_physical_expr = match physical_expr {
PlannedExprResult::ExprWithName(exprs) => {
- exprs
+ let new_exprs = exprs
.iter()
- .try_for_each(|(expr, _)| async_map.find_references(expr, schema))?;
+ .map(|(expr, name)| {
+ // find_and_map will:
+ // 1. Identify nested async UDFs bottom-up
+ // 2. Rewrite them to Columns referencing the async_map
+ // 3. Return the fully rewritten expression
+ let new_expr = async_map.find_and_map(expr, schema)?;
+ Ok((new_expr, name.clone()))
Review Comment:
I'd suggest using `into_iter` to avoid the clone, for example:
```rust
let new_exprs = exprs
.into_iter()
.map(|(expr, name)| {
// find_and_map will:
// 1. Identify nested async UDFs bottom-up
// 2. Rewrite them to Columns referencing the async_map
// 3. Return the fully rewritten expression
let new_expr = async_map.find_and_map(expr, schema)?;
Ok((new_expr, name))
})
.collect::<Result<Vec<_>>>()?;
```
##########
datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs:
##########
@@ -113,6 +114,43 @@ async fn test_async_udf_metrics() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn test_nested_async_udf() -> Result<()> {
Review Comment:
We can probably omit the changes to this file; SLT is sufficient for testing
purposes
##########
datafusion/physical-plan/src/async_func.rs:
##########
@@ -315,58 +363,49 @@ impl AsyncMapper {
}
/// Finds any references to async functions in the expression and adds them to the map
- pub fn find_references(
+ /// AND rewrites the expression to use the mapped columns.
+ pub fn find_and_map(
&mut self,
physical_expr: &Arc<dyn PhysicalExpr>,
Review Comment:
We can probably change this to take by value
```rust
physical_expr: Arc<dyn PhysicalExpr>,
```
Since we clone it internally anyway
##########
datafusion/physical-plan/src/async_func.rs:
##########
@@ -74,10 +80,25 @@ impl AsyncFuncExec {
.collect();
let schema = Arc::new(Schema::new(fields));
- let tuples = async_exprs
- .iter()
- .map(|expr| (Arc::clone(&expr.func), expr.name().to_string()))
- .collect::<Vec<_>>();
+
+ // Only include expressions that map to input columns in the ProjectionMapping
+ // Expressions referencing newly created async columns cannot be verified against input schema
+ let input_len = input.schema().fields().len();
+ let mut tuples = Vec::new();
+ for expr in &async_exprs {
+ let mut refers_to_new_cols = false;
+ expr.func.apply(&mut |e: &Arc<dyn PhysicalExpr>| {
+ if let Some(col) = e.as_any().downcast_ref::<Column>() {
+ refers_to_new_cols |= col.index() >= input_len;
+ }
+ Ok(TreeNodeRecursion::Continue)
+ })?;
Review Comment:
```suggestion
let refers_to_new_cols = expr.func.exists(|e| {
if let Some(col) = e.as_any().downcast_ref::<Column>()
&& col.index() >= input_len
{
Ok(true)
} else {
Ok(false)
}
})?;
```
Simpler to use `exists()`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]