This is an automated email from the ASF dual-hosted git repository. michaelsmith pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/impala.git
commit 20a9d2669c69f8e5b0a5c0b9487fa0212a00ad9c Author: pranav.lodha <[email protected]> AuthorDate: Mon Feb 27 02:39:20 2023 +0530 IMPALA-11957: Implement Regression functions: regr_slope(), regr_intercept() and regr_r2() The linear regression functions fit an ordinary-least-squares regression line to a set of number pairs. They can be used both as aggregate and analytic functions. regr_slope() takes two arguments of numeric type and returns the slope of the line. regr_intercept() takes two arguments of numeric type and returns the y-intercept of the regression line. regr_r2() takes two arguments of numeric type and returns the coefficient of determination (also called R-squared or goodness of fit) for the regression. Testing: The functions are extensively tested and cross-checked with Hive. The tests can be found in aggregation.test. Change-Id: Iab6bd84ae3e0c02ec924c30183308123b951caa3 Reviewed-on: http://gerrit.cloudera.org:8080/19569 Reviewed-by: Impala Public Jenkins <[email protected]> Tested-by: Impala Public Jenkins <[email protected]> --- be/src/exprs/aggregate-functions-ir.cc | 272 ++++++++- be/src/exprs/aggregate-functions.h | 22 + .../java/org/apache/impala/catalog/BuiltinsDb.java | 71 +++ .../queries/QueryTest/aggregation.test | 631 ++++++++++++++++++++- 4 files changed, 988 insertions(+), 8 deletions(-) diff --git a/be/src/exprs/aggregate-functions-ir.cc b/be/src/exprs/aggregate-functions-ir.cc index 0bb0b9cc6..bf41acb04 100644 --- a/be/src/exprs/aggregate-functions-ir.cc +++ b/be/src/exprs/aggregate-functions-ir.cc @@ -288,6 +288,202 @@ void AggregateFunctions::CountMerge(FunctionContext*, const BigIntVal& src, dst->val += src.val; } +// Implementation of regr_slope() and regr_intercept(): +// RegrSlopeState is used for implementing regr_slope() and regr_intercept(). +// regr_slope() and regr_intercept() take two arguments of numeric type and return the +// regression slope of the line and the y-intercept of the regression line respectively. 
+// The linear regression functions fit an ordinary-least-squares regression line to a set +// of number pairs. They can be used both as aggregate and analytic functions. +// Here's a link which contains description of all the regression functions: +// https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/REGR_-Linear-Regression-Functions.html#GUID-A675B68F-2A88-4843-BE2C-FCDE9C65F9A9 + +// regr_slope() formula used: +// regr_slope(y, x) = covar_pop(x, y) / var_pop(x) +// regr_intercept() formula used: +// regr_intercept(y,x) = avg(y) - regr_slope(y, x) * avg(x) +// where y and x are the dependent and independent variables respectively. +struct RegrSlopeState { + int64_t count; + double yavg; // average of y elements + double xavg; // average of x elements + double xvar; // count times the variance of x elements + double covar; // count times the covariance +}; + +void AggregateFunctions::RegrSlopeInit(FunctionContext* ctx, StringVal* dst) { + dst->is_null = false; + dst->len = sizeof(RegrSlopeState); + AllocBuffer(ctx, dst, dst->len); + if (UNLIKELY(dst->is_null)) { + DCHECK(!ctx->impl()->state()->GetQueryStatus().ok()); + return; + } + *(reinterpret_cast<RegrSlopeState*>(dst->ptr)) = {}; +} + +static inline void RegrSlopeUpdateState(double y, double x, RegrSlopeState* state) { + double deltaY = y - state->yavg; + double deltaX = x - state->xavg; + ++state->count; + // my_n = my_(n - 1) + [y_n - my_(n - 1)] / n + state->yavg += deltaY / state->count; + // mx_n = mx_(n - 1) + [x_n - mx_(n - 1)] / n + state->xavg += deltaX / state->count; + if (state->count > 1) { + // c_n = c_(n - 1) + (y_n - my_n) * (x_n - mx_(n - 1)) OR + // c_n = c_(n - 1) + (y_n - my_(n - 1)) * (x_n - mx_n) + // The apparent asymmetry in the equations is due to the fact that, + // y_n - my_n = (n - 1) * (y_n - my_(n - 1)) / n, so both update terms are equal to + // (n - 1) * (y_n - my_(n - 1)) * (x_n - mx_(n - 1)) / n + state->covar += deltaY * (x - state->xavg); + // vx_n = vx_(n 
- 1) + (x_n - mx_(n - 1)) * (x_n - mx_n) + state->xvar += deltaX * (x - state->xavg); + } +} + +static inline void RegrSlopeRemoveState(double y, double x, RegrSlopeState* state) { + if (state->count <= 1) { + *(reinterpret_cast<RegrSlopeState*>(sizeof(RegrSlopeState))) = {}; + } else { + double deltaY = y - state->yavg; + double deltaX = x - state->xavg; + --state->count; + // my_(n - 1) = my_n - (y_n - my_n) / (n - 1) + state->yavg -= deltaY / state->count; + // mx_(n - 1) = mx_n - (x_n - mx_n) / (n - 1) + state->xavg -= deltaX / state->count; + // c_(n - 1) = c_n - (y_n - my_n) * (x_n - mx_(n - 1)) + state->covar -= deltaY * (x - state->xavg); + // vx_(n - 1) = vx_n - (x_n - mx_n) * (x_n - mx_(n - 1)) + state->xvar -= deltaX * (x - state->xavg); + } +} + +void AggregateFunctions::RegrSlopeUpdate(FunctionContext* ctx, + const DoubleVal& src1, const DoubleVal& src2, StringVal* dst) { + if (src1.is_null || src2.is_null) return; + DCHECK(dst->ptr != nullptr); + DCHECK_EQ(sizeof(RegrSlopeState), dst->len); + RegrSlopeState* state = reinterpret_cast<RegrSlopeState*>(dst->ptr); + RegrSlopeUpdateState(src1.val, src2.val, state); +} + +void AggregateFunctions::RegrSlopeRemove(FunctionContext* ctx, + const DoubleVal& src1, const DoubleVal& src2, StringVal* dst) { + // Remove doesn't need to explicitly check the number of calls to Update() or Remove() + // because Finalize() returns NULL if count is 0. In other words, it's not needed to + // check if num_removes() >= num_updates() as it's accounted for in Finalize(). 
+ if (src1.is_null || src2.is_null) return; + DCHECK(dst->ptr != nullptr); + DCHECK_EQ(sizeof(RegrSlopeState), dst->len); + RegrSlopeState* state = reinterpret_cast<RegrSlopeState*>(dst->ptr); + RegrSlopeRemoveState(src1.val, src2.val, state); +} + +void AggregateFunctions::TimestampRegrSlopeUpdate(FunctionContext* ctx, + const TimestampVal& src1, const TimestampVal& src2, StringVal* dst) { + if (src1.is_null || src2.is_null) return; + RegrSlopeState* state = reinterpret_cast<RegrSlopeState*>(dst->ptr); + const TimestampValue& tm_src1 = TimestampValue::FromTimestampVal(src1); + const TimestampValue& tm_src2 = TimestampValue::FromTimestampVal(src2); + double val1, val2; + if (tm_src1.ToSubsecondUnixTime(UTCPTR, &val1) && + tm_src2.ToSubsecondUnixTime(UTCPTR, &val2)) { + RegrSlopeUpdateState(val1, val2, state); + } +} + +void AggregateFunctions::TimestampRegrSlopeRemove(FunctionContext* ctx, + const TimestampVal& src1, const TimestampVal& src2, StringVal* dst) { + // Remove doesn't need to explicitly check the number of calls to Update() or Remove() + // because Finalize() returns NULL if count is 0. In other words, it's not needed to + // check if num_removes() >= num_updates() as it's accounted for in Finalize(). 
+ if (src1.is_null || src2.is_null) return; + RegrSlopeState* state = reinterpret_cast<RegrSlopeState*>(dst->ptr); + const TimestampValue& tm_src1 = TimestampValue::FromTimestampVal(src1); + const TimestampValue& tm_src2 = TimestampValue::FromTimestampVal(src2); + double val1, val2; + if (tm_src1.ToSubsecondUnixTime(UTCPTR, &val1) && + tm_src2.ToSubsecondUnixTime(UTCPTR, &val2)) { + RegrSlopeRemoveState(val1, val2, state); + } +} + +void AggregateFunctions::RegrSlopeMerge(FunctionContext* ctx, + const StringVal& src, StringVal* dst) { + const RegrSlopeState* src_state = reinterpret_cast<RegrSlopeState*>(src.ptr); + DCHECK(dst->ptr != nullptr); + DCHECK_EQ(sizeof(RegrSlopeState), dst->len); + RegrSlopeState* dst_state = reinterpret_cast<RegrSlopeState*>(dst->ptr); + if (src.ptr != nullptr) { + int64_t nA = dst_state->count; + if (nA == 0) { + *dst_state = *src_state; + return; + } + double yavgA = dst_state->yavg; + double xavgA = dst_state->xavg; + + dst_state->count += src_state->count; + dst_state->yavg = (yavgA * nA + src_state->yavg * src_state->count) / + dst_state->count; + dst_state->xavg = (xavgA * nA + src_state->xavg * src_state->count) / + dst_state->count; + // vx_(A,B) = vx_A + vx_B + (mx_A - mx_B) * (mx_A - mx_B) * n_A * n_B / (n_A + n_B) + dst_state->xvar += + src_state->xvar + (xavgA - src_state->xavg) * (xavgA - src_state->xavg) * nA + * src_state->count / dst_state->count; + // c_(A,B) = c_A + c_B + (my_A - my_B) * (mx_A - mx_B) * n_A * n_B / (n_A + n_B) + dst_state->covar += src_state->covar + + (yavgA - src_state->yavg) * (xavgA - src_state->xavg) * ((double)(nA * + src_state->count)) / (dst_state->count); + } +} + +DoubleVal AggregateFunctions::RegrSlopeGetValue(FunctionContext* ctx, + const StringVal& src) { + const RegrSlopeState* state = reinterpret_cast<RegrSlopeState*>(src.ptr); + // Calculating Regression slope: + // xvar becomes negative in certain cases due to floating point rounding error. 
+ // Since these values are very small, they can be ignored and rounded to 0. + DCHECK(state->xvar >= FLOATING_POINT_ERROR_THRESHOLD); + if (state->count < 2 || state->xvar <= 0.0) { + return DoubleVal::null(); + } + return DoubleVal(state->covar / state->xvar); +} + +DoubleVal AggregateFunctions::RegrSlopeFinalize(FunctionContext* ctx, + const StringVal& src) { + DoubleVal r = src.is_null ? DoubleVal::null() : + RegrSlopeGetValue(ctx, src); + ctx->Free(src.ptr); + return r; +} + +DoubleVal AggregateFunctions::RegrInterceptGetValue(FunctionContext* ctx, + const StringVal& src) { + RegrSlopeState* state = reinterpret_cast<RegrSlopeState*>(src.ptr); + // Calculating Regression Intercept + // xvar becomes negative in certain cases due to floating point rounding error. + // Since these values are very small, they can be ignored and rounded to 0. + DCHECK(state->xvar >= FLOATING_POINT_ERROR_THRESHOLD); + if (state->count < 2 || state->xvar <= 0.0) { + return DoubleVal::null(); + } + double regrSlope = state->covar / state->xvar; + double regrIntercept = state->yavg - (regrSlope * state->xavg); + return DoubleVal(regrIntercept); +} + +DoubleVal AggregateFunctions::RegrInterceptFinalize(FunctionContext* ctx, + const StringVal& src) { + DoubleVal r = src.is_null ? DoubleVal::null() : + RegrInterceptGetValue(ctx, src); + ctx->Free(src.ptr); + return r; +} + // Implementation of CORR() function which takes two arguments of numeric type // and returns the Pearson's correlation coefficient between them using the Welford's // online algorithm. 
This is calculated using a stable one-pass algorithm, based on @@ -311,8 +507,12 @@ struct CorrState { void AggregateFunctions::CorrInit(FunctionContext* ctx, StringVal* dst) { dst->is_null = false; dst->len = sizeof(CorrState); - dst->ptr = ctx->Allocate(dst->len); - memset(dst->ptr, 0, dst->len); + AllocBuffer(ctx, dst, dst->len); + if (UNLIKELY(dst->is_null)) { + DCHECK(!ctx->impl()->state()->GetQueryStatus().ok()); + return; + } + *(reinterpret_cast<CorrState*>(dst->ptr)) = {}; } static inline void CorrUpdateState(double x, double y, CorrState* state) { @@ -341,7 +541,7 @@ static inline void CorrRemoveState(double x, double y, CorrState* state) { double deltaX = x - state->xavg; double deltaY = y - state->yavg; if (state->count <= 1) { - memset(state, 0, sizeof(CorrState)); + *(reinterpret_cast<CorrState*>(sizeof(CorrState))) = {}; } else { --state->count; // mx_(n - 1) = mx_n - (x_n - mx_n) / (n - 1) @@ -450,8 +650,8 @@ DoubleVal AggregateFunctions::CorrGetValue(FunctionContext* ctx, const StringVal // Calculating Pearson's correlation coefficient // xvar and yvar become negative in certain cases due to floating point rounding error. // Since these values are very small, they can be ignored and rounded to 0. - DCHECK(state->xvar >= -1E-8); - DCHECK(state->yvar >= -1E-8); + DCHECK(state->xvar >= FLOATING_POINT_ERROR_THRESHOLD); + DCHECK(state->yvar >= FLOATING_POINT_ERROR_THRESHOLD); if (state->count == 0 || state->count == 1 || state->xvar <= 0.0 || state->yvar <= 0.0) { return DoubleVal::null(); @@ -473,6 +673,60 @@ DoubleVal AggregateFunctions::CorrFinalize(FunctionContext* ctx, const StringVal return r; } +// Implementation of regr_r2(): +// CorrState is reused for implementing regr_r2. +// regr_r2() takes two arguments of numeric type and returns the coefficient of +// determination (also called R-squared or goodness of fit) for the regression. 
+// regr_r2() formula used: +// regr_r2(y, x) = NULL if var_pop(x) = 0, else +// 1 if var_pop(y) = 0 (and var_pop(x) != 0), else +// power(corr(y, x),2) if (var_pop(y) != 0 and var_pop(x) != 0) +// where y and x are the dependent and independent variables +// respectively. Note that variances can't be negative. +DoubleVal AggregateFunctions::Regr_r2GetValue(FunctionContext* ctx, + const StringVal& src) { + const CorrState* state = reinterpret_cast<CorrState*>(src.ptr); + // Calculating Regression R2: + // In this function we use 'dependent_var' and 'independent_var' instead of 'y_var' and + // 'x_var'. This is to avoid confusion, because for regr_r2() the dependent variable is + // the first parameter and the independent variable is the second parameter, but + // CorrUpdate(), which we use to produce the intermediate values, has the opposite + // order. Our aggregate function framework passes the variables in order to + // CorrUpdate(), so in CorrUpdate() 'x' corresponds to the dependent variable of + // regr_r2() and 'y' to the independent variable of regr_r2(). + double dependent_var = state->xvar; + double independent_var = state->yvar; + + // dependent_var and independent_var become negative in certain cases due to floating + // point rounding error. + // Since these values are very small, they can be ignored and rounded to 0. + DCHECK(dependent_var >= FLOATING_POINT_ERROR_THRESHOLD); + DCHECK(independent_var >= FLOATING_POINT_ERROR_THRESHOLD); + if (state->count < 2 || (independent_var / state->count) <= 0.0 || + (dependent_var / state->count) < 0.0) { + return DoubleVal::null(); + } else if ((dependent_var / state->count) == 0.0) { + return 1; + } else { + double stddev_prod_squared = dependent_var * independent_var; + // Mathematically 'stddev_prod_squared' can only be 0 if either 'dependent_var' + // or 'independent_var' is 0, which we have handled earlier. 
However, if both + // 'dependent_var' and 'independent_var' are very small, the result may become + // 0 because of floating point underflow. In this case we return NULL, i.e. treat + // it as if 'independent_var' was 0. + if (stddev_prod_squared == 0.0) return DoubleVal::null(); + return state->covar * state->covar / stddev_prod_squared; + } +} + +DoubleVal AggregateFunctions::Regr_r2Finalize(FunctionContext* ctx, + const StringVal& src) { + DoubleVal r = src.is_null ? DoubleVal::null() : + Regr_r2GetValue(ctx, src); + ctx->Free(src.ptr); + return r; +} + // Implementation of COVAR_SAMP() and COVAR_POP() which calculates sample and // population covariance between two columns of numeric types respectively using // the Welford's online algorithm. @@ -490,8 +744,12 @@ struct CovarState { void AggregateFunctions::CovarInit(FunctionContext* ctx, StringVal* dst) { dst->is_null = false; dst->len = sizeof(CovarState); - dst->ptr = ctx->Allocate(dst->len); - memset(dst->ptr, 0, dst->len); + AllocBuffer(ctx, dst, dst->len); + if (UNLIKELY(dst->is_null)) { + DCHECK(!ctx->impl()->state()->GetQueryStatus().ok()); + return; + } + *(reinterpret_cast<CovarState*>(dst->ptr)) = {}; } static inline void CovarUpdateState(double x, double y, CovarState* state) { diff --git a/be/src/exprs/aggregate-functions.h b/be/src/exprs/aggregate-functions.h index 0e8c53a50..c261e80aa 100644 --- a/be/src/exprs/aggregate-functions.h +++ b/be/src/exprs/aggregate-functions.h @@ -37,6 +37,8 @@ using impala_udf::StringVal; using impala_udf::DecimalVal; using impala_udf::DateVal; +static constexpr double FLOATING_POINT_ERROR_THRESHOLD = -1E-8; + /// Collection of builtin aggregate functions. Aggregate functions implement /// the various phases of the aggregation: Init(), Update(), Serialize(), Merge(), /// and Finalize(). 
Not all functions need to implement all of the steps and @@ -64,6 +66,22 @@ class AggregateFunctions { static StringVal StringValSerializeOrFinalize( FunctionContext* ctx, const StringVal& src); + /// Implementation of regr_slope() and regr_intercept() + static void RegrSlopeInit(FunctionContext* ctx, StringVal* dst); + static void RegrSlopeUpdate(FunctionContext* ctx, const DoubleVal& src1, + const DoubleVal& src2, StringVal* dst); + static void RegrSlopeRemove(FunctionContext* ctx, const DoubleVal& src1, + const DoubleVal& src2, StringVal* dst); + static void TimestampRegrSlopeUpdate(FunctionContext* ctx, + const TimestampVal& src1, const TimestampVal& src2, StringVal* dst); + static void TimestampRegrSlopeRemove(FunctionContext* ctx, + const TimestampVal& src1, const TimestampVal& src2, StringVal* dst); + static void RegrSlopeMerge(FunctionContext* ctx, const StringVal& src, StringVal* dst); + static DoubleVal RegrSlopeGetValue(FunctionContext* ctx, const StringVal& src); + static DoubleVal RegrSlopeFinalize(FunctionContext* ctx, const StringVal& src); + static DoubleVal RegrInterceptGetValue(FunctionContext* ctx, const StringVal& src); + static DoubleVal RegrInterceptFinalize(FunctionContext* ctx, const StringVal& src); + /// Implementation of Corr() static void CorrInit(FunctionContext* ctx, StringVal* dst); static void CorrUpdate(FunctionContext* ctx, const DoubleVal& src1, @@ -78,6 +96,10 @@ class AggregateFunctions { static DoubleVal CorrGetValue(FunctionContext* ctx, const StringVal& src); static DoubleVal CorrFinalize(FunctionContext* ctx, const StringVal& src); + /// Implementation of regr_r2() + static DoubleVal Regr_r2GetValue(FunctionContext* ctx, const StringVal& src); + static DoubleVal Regr_r2Finalize(FunctionContext* ctx, const StringVal& src); + /// Implementation of Covar_samp() and Covar_pop() static void CovarInit(FunctionContext* ctx, StringVal* dst); static void CovarUpdate(FunctionContext* ctx, const DoubleVal& src1, diff --git 
a/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java b/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java index b32e6acf3..1f1532761 100644 --- a/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java +++ b/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java @@ -1367,6 +1367,77 @@ public class BuiltinsDb extends Db { prefix + "12CorrFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", false, true, false)); + // Regr_r2() + db.addBuiltin(AggregateFunction.createBuiltin(db, "regr_r2", + Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.STRING, + prefix + "8CorrInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + "10CorrUpdateEPN10impala_udf15FunctionContextERKNS1_9DoubleValES6_PNS1_9StringValE", + prefix + "9CorrMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "15Regr_r2GetValueEPN10impala_udf15FunctionContextERKNS1_9StringValE", + prefix + "10CorrRemoveEPN10impala_udf15FunctionContextERKNS1_9DoubleValES6_PNS1_9StringValE", + prefix + "15Regr_r2FinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + false, true, false)); + + db.addBuiltin(AggregateFunction.createBuiltin(db, "regr_r2", + Lists.<Type>newArrayList(Type.TIMESTAMP, Type.TIMESTAMP), Type.DOUBLE, Type.STRING, + prefix + "8CorrInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + + "19TimestampCorrUpdateEPN10impala_udf15FunctionContextERKNS1_12TimestampValES6_PNS1_9StringValE", + prefix + "9CorrMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "15Regr_r2GetValueEPN10impala_udf15FunctionContextERKNS1_9StringValE", + prefix + + "19TimestampCorrRemoveEPN10impala_udf15FunctionContextERKNS1_12TimestampValES6_PNS1_9StringValE", + prefix + "15Regr_r2FinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + false, true, false)); + + //Regr_slope() + db.addBuiltin(AggregateFunction.createBuiltin(db, "regr_slope", + 
Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.STRING, + prefix + "13RegrSlopeInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + "15RegrSlopeUpdateEPN10impala_udf15FunctionContextERKNS1_9DoubleValES6_PNS1_9StringValE", + prefix + "14RegrSlopeMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "17RegrSlopeGetValueEPN10impala_udf15FunctionContextERKNS1_9StringValE", + prefix + "15RegrSlopeRemoveEPN10impala_udf15FunctionContextERKNS1_9DoubleValES6_PNS1_9StringValE", + prefix + "17RegrSlopeFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + false, true, false)); + + db.addBuiltin(AggregateFunction.createBuiltin(db, "regr_slope", + Lists.<Type>newArrayList(Type.TIMESTAMP, Type.TIMESTAMP), Type.DOUBLE, Type.STRING, + prefix + "13RegrSlopeInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + "24TimestampRegrSlopeUpdateEPN10impala_udf15FunctionContextERKNS1_12TimestampValES6_PNS1_9StringValE", + prefix + "14RegrSlopeMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "17RegrSlopeGetValueEPN10impala_udf15FunctionContextERKNS1_9StringValE", + prefix + "24TimestampRegrSlopeRemoveEPN10impala_udf15FunctionContextERKNS1_12TimestampValES6_PNS1_9StringValE", + prefix + "17RegrSlopeFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + false, true, false)); + + // Regr_intercept() + db.addBuiltin(AggregateFunction.createBuiltin(db, "regr_intercept", + Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.STRING, + prefix + "13RegrSlopeInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + "15RegrSlopeUpdateEPN10impala_udf15FunctionContextERKNS1_9DoubleValES6_PNS1_9StringValE", + prefix + "14RegrSlopeMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "21RegrInterceptGetValueEPN10impala_udf15FunctionContextERKNS1_9StringValE", + 
prefix + "15RegrSlopeRemoveEPN10impala_udf15FunctionContextERKNS1_9DoubleValES6_PNS1_9StringValE", + prefix + "21RegrInterceptFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + false, true, false)); + + db.addBuiltin(AggregateFunction.createBuiltin(db, "regr_intercept", + Lists.<Type>newArrayList(Type.TIMESTAMP, Type.TIMESTAMP), Type.DOUBLE, Type.STRING, + prefix + "13RegrSlopeInitEPN10impala_udf15FunctionContextEPNS1_9StringValE", + prefix + "24TimestampRegrSlopeUpdateEPN10impala_udf15FunctionContextERKNS1_12TimestampValES6_PNS1_9StringValE", + prefix + "14RegrSlopeMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "21RegrInterceptGetValueEPN10impala_udf15FunctionContextERKNS1_9StringValE", + prefix + "24TimestampRegrSlopeRemoveEPN10impala_udf15FunctionContextERKNS1_12TimestampValES6_PNS1_9StringValE", + prefix + "21RegrInterceptFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE", + false, true, false)); + // Covar_samp() db.addBuiltin(AggregateFunction.createBuiltin(db, "covar_samp", Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.STRING, diff --git a/testdata/workloads/functional-query/queries/QueryTest/aggregation.test b/testdata/workloads/functional-query/queries/QueryTest/aggregation.test index 98267d4c7..df8162d14 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/aggregation.test +++ b/testdata/workloads/functional-query/queries/QueryTest/aggregation.test @@ -2037,4 +2037,633 @@ select s_store_sk, covar_pop(s_number_employees, s_floor_space) over (partition 11,0 ---- TYPES int,double -==== \ No newline at end of file +==== +---- QUERY +# regression function examples +select regr_slope(ps_availqty, ps_supplycost), + regr_intercept(ps_availqty, ps_supplycost), regr_r2(ps_availqty, ps_supplycost) + from tpch.partsupp; +---- RESULTS +0.003223046670647307,5001.613715742804,1.035868858574101e-07 +---- TYPES +double, double, double +==== +---- QUERY +# 
Behavior of regression functions on null table +select regr_slope(d, e), regr_intercept(d, e), regr_r2(d, e) from functional.nulltable; +---- RESULTS +NULL,NULL,NULL +---- TYPES +double, double, double +==== +---- QUERY +# Behavior of regression functions on empty table +select regr_slope(f2, f2), regr_intercept(f2, f2), regr_r2(f2, f2) from functional.emptytable; +---- RESULTS +NULL,NULL,NULL +---- TYPES +double, double, double +==== +---- QUERY +# regr_slope() on different datatypes +select regr_slope(tinyint_col, tinyint_col), regr_slope(smallint_col, smallint_col), + regr_slope(int_col, int_col), regr_slope(bigint_col, bigint_col), regr_slope(float_col, float_col), + regr_slope(double_col, double_col), regr_slope(timestamp_col, timestamp_col) from functional.alltypes; +---- RESULTS +1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 +---- TYPES +double, double, double, double, double, double, double +==== +---- QUERY +# regr_intercept() on different datatypes +select regr_intercept(tinyint_col, tinyint_col), regr_intercept(smallint_col, smallint_col), + regr_intercept(int_col, int_col), regr_intercept(bigint_col, bigint_col), regr_intercept(float_col, float_col), + regr_intercept(double_col, double_col), regr_intercept(timestamp_col, timestamp_col) from functional.alltypes; +---- RESULTS +0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 +---- TYPES +double, double, double, double, double, double, double +==== +---- QUERY +# regr_r2() on different datatypes +select regr_r2(tinyint_col, tinyint_col), regr_r2(smallint_col, smallint_col), + regr_r2(int_col, int_col), regr_r2(bigint_col, bigint_col), regr_r2(float_col, float_col), + regr_r2(double_col, double_col), regr_r2(timestamp_col, timestamp_col) from functional.alltypes; +---- RESULTS +1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 +---- TYPES +double, double, double, double, double, double, double +==== +---- QUERY +# regr_slope(), regr_intercept(), regr_r2() on timestamp columns +select regr_slope(utctime, localtime), regr_intercept(utctime, 
localtime), + regr_r2(utctime, localtime) from functional.alltimezones; +---- RESULTS +0.9995840247725529, 593958.8067555428, 0.9999916154327646 +---- TYPES +double, double, double +==== +---- QUERY +# Since group by id will result in a single row, this test shows that regr_slope() returns null in case of a single row. +select id, regr_slope(int_col, int_col) from functional.alltypestiny group by id; +---- RESULTS +2,NULL +4,NULL +0,NULL +6,NULL +1,NULL +7,NULL +3,NULL +5,NULL +---- TYPES +int,double +==== +---- QUERY +# Since group by id will result in a single row, this test shows that regr_intercept() returns null in +# case of a single row. +select id, regr_intercept(int_col, int_col) from functional.alltypestiny group by id; +---- RESULTS +2,NULL +4,NULL +0,NULL +6,NULL +1,NULL +7,NULL +3,NULL +5,NULL +---- TYPES +int,double +==== +---- QUERY +# Since group by id will result in a single row, this test shows that regr_r2() returns null in case of a +# single row. +select id, regr_r2(int_col, int_col) from functional.alltypestiny group by id; +---- RESULTS +2,NULL +4,NULL +0,NULL +6,NULL +1,NULL +7,NULL +3,NULL +5,NULL +---- TYPES +int,double +==== +---- QUERY +# regr_slope(), regr_intercept(), regr_r2() on decimal datatype +select regr_slope(d3, d4), regr_intercept(d3, d4), regr_r2(d3, d4) from functional.decimal_tbl; +---- RESULTS +NULL,NULL,NULL +---- TYPES +double, double, double +==== +---- QUERY +select year, regr_slope(double_col, double_col) from functional.alltypes group by year; +---- RESULTS +2009,1.0 +2010,1.0 +---- TYPES +int,double +==== +---- QUERY +select year, regr_intercept(double_col, double_col) from functional.alltypes group by year; +---- RESULTS +2009,0.0 +2010,0.0 +---- TYPES +int,double +==== +---- QUERY +select year, regr_r2(double_col, double_col) from functional.alltypes group by year; +---- RESULTS +2009,1.0 +2010,1.0 +---- TYPES +int,double +==== +---- QUERY +select regr_slope(double_col, -double_col), regr_intercept(double_col, 
-double_col), + regr_r2(double_col, -double_col) from functional.alltypes; +---- RESULTS +-1.0, 0.0, 1.0 +---- TYPES +double, double, double +==== +---- QUERY +select regr_slope(double_col, double_col), regr_intercept(double_col, double_col), + regr_r2(double_col, double_col) from functional.alltypes; +---- RESULTS +1.0, 0.0, 1.0 +---- TYPES +double, double, double +==== +---- QUERY +select regr_slope(ss_sold_time_sk, ss_quantity), regr_intercept(ss_sold_time_sk, ss_quantity)/10000, + regr_r2(ss_sold_time_sk, ss_quantity) from tpcds.store_sales; +---- RESULTS +0.0602719636627,5.1709905412,1.87116649337e-08 +---- TYPES +double, double, double +==== +---- QUERY +select s_store_sk, regr_slope(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk) + from tpcds.store; +---- RESULTS +5,NULL +8,4.80120606296e-06 +12,-1.37354292706e-06 +1,NULL +2,-0.000255754475703 +3,-2.01036006341e-06 +4,-5.05103424244e-06 +6,-4.3565531677e-06 +7,1.21229193717e-06 +9,4.44553019714e-07 +10,3.9044206462e-06 +11,6.00483790103e-06 +---- TYPES +int, double +==== +---- QUERY +select s_store_sk, regr_intercept(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk) + from tpcds.store; +---- RESULTS +5,NULL +8, 244.41078639 +12, 296.416240104 +1,NULL +2, 1587.90537084399 +3, 251.125599973 +4, 268.395215604 +6, 264.570040249 +7, 234.323507488 +9, 244.099933952 +10, 223.133569068 +11, 210.405303328 +---- TYPES +int, double +==== +---- QUERY +select s_store_sk, regr_r2(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk) + from tpcds.store; +---- RESULTS +5,NULL +8, 1.0 +12, 0.107747035307 +1,NULL +2, 1.0 +3, 0.261613627698 +4, 0.776635209834 +6, 0.728906876949 +7, 0.00680634729367 +9, 0.000795425636424 +10, 0.0499353300472 +11, 0.1061202068 +---- TYPES +int, double +==== +---- QUERY +select id, double_col, regr_slope(double_col, int_col) over (partition by month order by id) from functional.alltypes + order by id 
limit 10; +---- RESULTS +0,0.0,NULL +1,10.1,10.1 +2,20.2,10.1 +3,30.3,10.1 +4,40.4,10.1 +5,50.5,10.1 +6,60.6,10.1 +7,70.7,10.1 +8,80.8,10.1 +9,90.9,10.1 +---- TYPES +int,double,double +==== +---- QUERY +select id, double_col, regr_intercept(double_col, int_col) over (partition by month order by id) from functional.alltypes + order by id limit 10; +---- RESULTS +0, 0.0,NULL +1, 10.1, 0.0 +2, 20.2, 0.0 +3, 30.3, 1.7763568394e-15 +4, 40.4, 0.0 +5, 50.5, 0.0 +6, 60.6, 3.5527136788e-15 +7, 70.7, -7.1054273576e-15 +8, 80.8, -7.1054273576e-15 +9, 90.9, -7.1054273576e-15 +---- TYPES +int,double,double +==== +---- QUERY +select id, double_col, regr_r2(double_col, int_col) over (partition by month order by id) from functional.alltypes + order by id limit 10; +---- RESULTS +0,0.0,NULL +1,10.1,1.0 +2,20.2,1.0 +3,30.3,1.0 +4,40.4,1.0 +5,50.5,1.0 +6,60.6,1.0 +7,70.7,1.0 +8,80.8,1.0 +9,90.9,1.0 +---- TYPES +int,double,double +==== +---- QUERY +# Regression functions when one column is filled with null +select regr_slope(null_int, rand()), regr_slope(rand(), null_int), regr_intercept(null_int, rand()), + regr_intercept(rand(), null_int), regr_r2(null_int, rand()), regr_r2(rand(), null_int) + from functional.nullrows; +---- RESULTS +NULL,NULL,NULL,NULL,NULL,NULL +---- TYPES +double, double, double, double, double, double +==== +---- QUERY +# Regression functions supporting join +select regr_slope(A.double_col, B.double_col), regr_intercept(A.double_col, B.double_col), + regr_r2(A.double_col, B.double_col) from functional.alltypes A, functional.alltypes B where A.id=B.id; +---- RESULTS +1.0, 0.0, 1.0 +---- TYPES +double, double, double +==== +---- QUERY +# Tests functioning of RegrSlopeRemoveState() +select s_store_sk, regr_slope(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between 5 preceding and 2 following) from tpcds.store; +---- RESULTS +5,-1.37354292706e-06 +8,-1.37354292706e-06 +12,-1.37354292706e-06 +1,-2.01036006341e-06 
+2,-5.05103424244e-06 +3,-4.3565531677e-06 +4,1.21229193717e-06 +6,4.44553019714e-07 +7,3.9044206462e-06 +9,6.59936024162e-06 +10,4.13595759719e-06 +11,-2.68509660433e-06 +---- TYPES +int,double +==== +---- QUERY +# Tests functioning of RegrSlopeRemoveState() for regr_intercept() +select s_store_sk, regr_intercept(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between 5 preceding and 2 following) from tpcds.store; +---- RESULTS +5, 296.416240104 +8, 296.416240104 +12, 296.416240104 +1, 251.125599973 +2, 268.395215604 +3, 264.570040249 +4, 234.323507488 +6, 244.099933952 +7, 223.133569068 +9, 205.13592892 +10, 226.988621372 +11, 290.843308372 +---- TYPES +int,double +==== +---- QUERY +# Tests functioning of CorrRemoveState() for regr_r2() +select s_store_sk, regr_r2(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between 5 preceding and 2 following) from tpcds.store; +---- RESULTS +5, 0.107747035307 +8, 0.107747035307 +12, 0.107747035307 +1, 0.261613627698 +2, 0.776635209834 +3, 0.728906876949 +4, 0.00680634729367 +6, 0.000795425636424 +7, 0.0499353300472 +9, 0.0868111803219 +10, 0.0132688514648 +11, 0.00476393800297 +---- TYPES +int,double +==== +---- QUERY +# Mathematical operations on double can lead to variance becoming negative by a very small amount (around +1e-13), +# to avoid that a check is added (state->xvar < 0.0 || state->yvar <= 0.0), without which the below test will +# result in nan for certain cases. 
+# Testcase when dependent variable becomes negative: +select s_store_sk, regr_r2(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between 1 preceding and 1 following) from tpcds.store; +---- RESULTS +5, 1.0 +8, 0.107747035307 +12, 1.0 +1, 1.0 +2, 0.261613627698 +3, 0.687656618674 +4, 0.782782611877 +6, 0.541316438509 +7, 0.0272991489992 +9, 0.942195083603 +10, 1.0 +11,NULL +---- TYPES +int,double +==== +---- QUERY +# regr_slope() when window size is 2 +select s_store_sk, regr_slope(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between 1 preceding and current row) from tpcds.store; +---- RESULTS +5,NULL +8,4.80120606296e-06 +12,-9.00681309118e-06 +1,NULL +2,-0.000255754475703 +3,0.0 +4,-1.00924694479e-05 +6,-3.48934955352e-05 +7,-0.000953195306915 +9,1.32728364256e-05 +10,1.00081893097e-05 +11,NULL +---- TYPES +int,double +==== +---- QUERY +# regr_intercept() when window size is 2 +select s_store_sk, regr_intercept(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between 1 preceding and current row) from tpcds.store; +---- RESULTS +5,NULL +8, 244.41078638954104 +12, 341.01161935181347 +1,NULL +2, 1587.9053708439897 +3, 236.0 +4, 312.2784702956196 +6, 543.9564370568922 +7, 8832.752449571763 +9, 178.14330273093714 +10, 200.98275763037407 +11,NULL +---- TYPES +int,double +==== +---- QUERY +# regr_r2() when window size is 2 +select s_store_sk, regr_r2(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between 1 preceding and current row) from tpcds.store; +---- RESULTS +5,NULL +8, 1.0 +12, 1.0 +1,NULL +2, 1.0 +3, 1.0 +4, 1.0 +6, 1.0 +7, 1.0 +9, 1.0 +10, 1.0 +11,NULL +---- TYPES +int,double +==== +---- QUERY +select s_store_sk, regr_slope(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between current row and 1 following) from tpcds.store; +---- RESULTS +5,4.80120606296e-06 
+8,-9.00681309118e-06 +12,NULL +1,-0.000255754475703 +2,0.0 +3,-1.00924694479e-05 +4,-3.48934955352e-05 +6,-0.000953195306915 +7,1.32728364256e-05 +9,1.00081893097e-05 +10,NULL +11,NULL +---- TYPES +int,double +==== +---- QUERY +select s_store_sk, regr_intercept(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between current row and 1 following) from tpcds.store; +---- RESULTS +5, 244.41078639 +8, 341.011619352 +12,NULL +1, 1587.90537084399 +2, 236.0 +3, 312.278470296 +4, 543.956437057 +6, 8832.752449571763 +7, 178.143302731 +9, 200.98275763 +10,NULL +11,NULL +---- TYPES +int,double +==== +---- QUERY +select s_store_sk, regr_r2(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between current row and 1 following) from tpcds.store; +---- RESULTS +5, 1.0 +8, 1.0 +12,NULL +1, 1.0 +2, 1.0 +3, 1.0 +4, 1.0 +6, 1.0 +7, 1.0 +9, 1.0 +10,NULL +11,NULL +---- TYPES +int,double +==== +---- QUERY +select s_store_sk, regr_slope(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between current row and unbounded following) from tpcds.store; +---- RESULTS +12,NULL +8,-9.00681309118e-06 +5,-1.37354292706e-06 +11,NULL +10,NULL +9,1.00081893097e-05 +7,1.05678876878e-05 +6,5.25458148115e-06 +4,-2.68509660433e-06 +3,4.13595759719e-06 +2,6.59936024162e-06 +1,6.00483790103e-06 +---- TYPES +int,double +==== +---- QUERY +select s_store_sk, regr_intercept(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between current row and unbounded following) from tpcds.store; +---- RESULTS +12,NULL +8, 341.011619352 +5, 296.416240104 +11,NULL +10,NULL +9, 200.98275763 +7, 197.748657023 +6, 231.216488956 +4, 290.843308372 +3, 226.988621372 +2, 205.13592892 +1, 210.405303328 +---- TYPES +int,double +==== +---- QUERY +select s_store_sk, regr_r2(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between current row and 
unbounded following) from tpcds.store; +---- RESULTS +12,NULL +8, 1.0 +5, 0.107747035307 +11,NULL +10,NULL +9, 1.0 +7, 0.932586749287 +6, 0.0314560278457 +4, 0.00476393800297 +3, 0.0132688514648 +2, 0.0868111803219 +1, 0.1061202068 +---- TYPES +int,double +==== +---- QUERY +select s_store_sk, regr_slope(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between unbounded preceding and current row) from tpcds.store; +---- RESULTS +5,NULL +8,4.80120606296e-06 +12,-1.37354292706e-06 +1,NULL +2,-0.000255754475703 +3,-2.01036006341e-06 +4,-5.05103424244e-06 +6,-4.3565531677e-06 +7,1.21229193717e-06 +9,4.44553019714e-07 +10,3.9044206462e-06 +11,6.00483790103e-06 +---- TYPES +int,double +==== +---- QUERY +select s_store_sk, regr_intercept(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between unbounded preceding and current row) from tpcds.store; +---- RESULTS +5,NULL +8, 244.41078639 +12, 296.416240104 +1,NULL +2, 1587.90537084399 +3, 251.125599973 +4, 268.395215604 +6, 264.570040249 +7, 234.323507488 +9, 244.099933952 +10, 223.133569068 +11, 210.405303328 +---- TYPES +int,double +==== +---- QUERY +select s_store_sk, regr_r2(s_number_employees, s_floor_space) over (partition by s_city order by s_store_sk + rows between unbounded preceding and current row) from tpcds.store; +---- RESULTS +5,NULL +8, 1.0 +12, 0.10774703530702502 +1,NULL +2, 1.0 +3, 0.26161362769774743 +4, 0.7766352098335481 +6, 0.7289068769493157 +7, 0.006806347293667032 +9, 7.954256364237451E-4 +10, 0.049935330047174945 +11, 0.10612020680026239 +---- TYPES +int,double +==== +---- QUERY +# Testcase when independent variable becomes negative: +select s_store_sk, regr_r2(s_floor_space, s_number_employees) over (partition by s_city order by + s_store_sk rows between 1 preceding and 1 following) from tpcds.store; +---- RESULTS +5, 1.0 +8, 0.107747035307 +12, 1.0 +1, 1.0 +2, 0.261613627698 +3, 0.687656618674 +4, 0.782782611877 +6, 
0.541316438509 +7, 0.0272991489992 +9, 0.942195083603 +10, 1.0 +11,NULL +---- TYPES +int, double +====
