This is an automated email from the ASF dual-hosted git repository.
yiconghuang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/texera.git
The following commit(s) were added to refs/heads/main by this push:
new 6980f19376 refactor(core): unify type ops and reuse in sort/agg (#4024)
6980f19376 is described below
commit 6980f19376c016b3505663827465aa41283e608e
Author: carloea2 <[email protected]>
AuthorDate: Fri Nov 21 11:54:38 2025 -0800
refactor(core): unify type ops and reuse in sort/agg (#4024)
### What changes were proposed in this PR?
1. **Centralize and extend `AttributeType` operations**
Move and refactor the existing attribute-type helpers into
`AttributeTypeUtils`:
* `compare`, `add`, `zeroValue`, `minValue`, `maxValue`.
* Unify null-handling semantics across these operations (using a
single match-case instead of an if followed by a match); a usage
sketch follows this list.
Extend support to additional types:
* Add comparison/aggregation support for `BOOLEAN`, `STRING`, and
`BINARY`.
Change numeric coercion strategy:
* Coerce numeric values to `Number` instead of a specific primitive type
(e.g., `Double`) to reduce `ClassCastException`s when the input is not
strictly schema-validated.
* Preserve existing comparison semantics for doubles by delegating to
`java.lang.Double.compare` (including handling of ±∞ and `NaN`).
Introduce “identity” helpers:
* `zeroValue` returns an additive identity for numeric/timestamp types,
and `Array.emptyByteArray` for `BINARY` as a safe, non-throwing
identity.
* `minValue` / `maxValue`: provide lower/upper bounds for supported
numeric and timestamp types.
2. **Refactor operators to reuse `AttributeTypeUtils`**
* `AggregationOperation`: implement `SUM` / `MIN` / `MAX` using the
centralized helpers instead of custom per-operator logic.
* `StableMergeSortOpExec`: reuse the typed compare logic from
`AttributeTypeUtils`.
* `SortPartitionsOpExec`: simplify to use a one-liner comparator based
on `AttributeTypeUtils.compare` (or a thin wrapper) for clarity and
reuse.
3. **Add tests**
* workflow-core/src/test/scala/org/apache/amber/core/tuple/AttributeTypeUtilsSpec.scala
* **compare**: Verifies correct null-handling and ordering for INTEGER,
BOOLEAN, TIMESTAMP, STRING, and BINARY values.
* **add**: Ensures `null` acts as identity and confirms correct addition
for INTEGER, LONG, DOUBLE, and TIMESTAMP.
* **zeroValue**: Checks that numeric/timestamp zero identities and empty
binary array for BINARY are returned, and that unsupported types (e.g.,
STRING) throw.
* **minValue / maxValue**: Validate correct numeric and timestamp
bounds, BINARY minimum, and exceptions for unsupported types (e.g.,
BOOLEAN, STRING).
* workflow-operator/src/test/scala/org/apache/amber/operator/aggregate/AggregateOpSpec.scala
* Verifies `getAggregationAttribute` chooses the correct result type for
different functions (SUM keeps input type, COUNT → INTEGER, CONCAT →
STRING).
* Checks `getAggFunc` SUM behavior for INTEGER and DOUBLE columns,
ensuring correct totals and preserved fractional values.
* Tests COUNT, CONCAT, MIN, MAX, and AVERAGE aggregations, including
correct handling of `null` values and edge cases like “no rows”.
* Confirms `getFinal` rewrites COUNT into a SUM on the intermediate
count column and rewires attributes correctly for non-COUNT functions.
* Exercises `AggregateOpExec` end-to-end: SUM grouped by a key (city)
and combined global SUM+COUNT with no group-by keys, validating the
produced tuples.
4. **Scope / non-goals / Extras**
* No changes to external APIs.
* Main behavior changes are localized to `AttributeType` operations and
the operators that consume them.
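
To make the unified semantics concrete, here is a minimal usage sketch
against the new helpers (the signatures are the ones added in this PR;
the wrapping object and the assertions are illustrative only, not code
from the repository):
```
import org.apache.amber.core.tuple.{AttributeType, AttributeTypeUtils}

object TypeOpsSketch extends App {
  // Null semantics are uniform: null < non-null, and two nulls are equal.
  assert(AttributeTypeUtils.compare(null, 10, AttributeType.INTEGER) < 0)
  assert(AttributeTypeUtils.compare(null, null, AttributeType.INTEGER) == 0)

  // Numeric coercion goes through Number, so a boxed Long in an INTEGER
  // column compares fine instead of throwing a ClassCastException.
  assert(AttributeTypeUtils.compare(java.lang.Long.valueOf(1L), 2, AttributeType.INTEGER) < 0)

  // null acts as the additive identity for add, consistent with zeroValue.
  assert(AttributeTypeUtils.add(null, java.lang.Integer.valueOf(5), AttributeType.INTEGER) == 5)
  assert(AttributeTypeUtils.zeroValue(AttributeType.INTEGER) == 0)

  // An ascending comparator reduces to a one-liner, as in SortPartitionsOpExec.
  val lessThan: (Any, Any) => Boolean =
    (a, b) => AttributeTypeUtils.compare(a, b, AttributeType.INTEGER) < 0
  assert(lessThan(1, 2))
}
```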
---
**Any related issues, documentation, discussions?**
* Closes: #3923
**How was this PR tested?**
Workflow Image:
<img width="1684" height="859" alt="image"
src="https://github.com/user-attachments/assets/2682ebdc-0f45-40c6-b304-0cea0b76b44f"
/>
Workflow file:
[agg_test_1.json](https://github.com/user-attachments/files/23540242/agg_test_1.json)
Python benchmark:
```
import pandas as pd
df = pd.read_csv("/mnt/data/test.csv")
# Limit BEFORE sorting
df_limited = df.head(1000)
# Now sort ascending
df_sorted = df_limited.sort_values("rna_umis", ascending=True)
# Group by pass_all_filters with aggregations
agg = df_sorted.groupby("pass_all_filters")["rna_umis"].agg(
min="min", max="max", count="count", avg="mean", sum="sum"
).reset_index()
agg
```
Python Result:
<img width="928" height="188" alt="image"
src="https://github.com/user-attachments/assets/69da33cd-ada4-4b05-a3f9-ae139f8575b9"
/>
Texera Result:
pass_all_filters | min | max | count | avg | sum
-- | -- | -- | -- | -- | --
False | 0 | 80926 | 240 | 15987.68 | 3837043
True | 11893 | 102559 | 760 | 35557.93 | 27024027
For the timestamp test:
- 1970-01-01T00:00:00Z
- 2000-02-29T12:00:00Z
- 2024-12-31T23:59:59Z
1. Avg:
- New version: 909835199750
- Previous version: 909835199750
2. Sum:
- New version: 2055-03-01T05:59:59.000Z (UTC)
- Previous version: 2055-03-01T11:59:59.000Z (UTC-6; Mexico City Time)
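
For context on the Sum comparison above: TIMESTAMP addition sums raw epoch
milliseconds (this is what `AttributeTypeUtils.add` does for TIMESTAMP), so
a six-hour difference in the rendered string corresponds to rendering the
same summed instant in UTC versus UTC-6, not to a different sum. A minimal,
Texera-independent sketch using only `java.time`:
```
import java.time.{Instant, ZoneId}

object TimestampZoneSketch extends App {
  // Epoch millis of the three test timestamps (all specified in UTC).
  val epochMillis = Seq(
    Instant.parse("1970-01-01T00:00:00Z"),
    Instant.parse("2000-02-29T12:00:00Z"),
    Instant.parse("2024-12-31T23:59:59Z")
  ).map(_.toEpochMilli)

  // TIMESTAMP SUM adds the raw epoch milliseconds.
  val summed = Instant.ofEpochMilli(epochMillis.sum)

  // The same summed instant rendered in UTC and in America/Mexico_City
  // (UTC-6) prints six hours apart.
  println(summed.atZone(ZoneId.of("UTC")).toLocalDateTime)
  println(summed.atZone(ZoneId.of("America/Mexico_City")).toLocalDateTime)
}
```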
**Was this PR authored or co-authored using generative AI tooling?**
* Co-authored with ChatGPT.
---
.../amber/core/tuple/AttributeTypeUtils.scala | 131 ++++++++
.../amber/core/tuple/AttributeTypeUtilsSpec.scala | 135 ++++++++-
.../operator/aggregate/AggregationOperation.scala | 120 +-------
.../sortPartitions/SortPartitionsOpExec.scala | 23 +-
.../amber/operator/aggregate/AggregateOpSpec.scala | 333 +++++++++++++++++++++
5 files changed, 621 insertions(+), 121 deletions(-)
diff --git a/common/workflow-core/src/main/scala/org/apache/amber/core/tuple/AttributeTypeUtils.scala b/common/workflow-core/src/main/scala/org/apache/amber/core/tuple/AttributeTypeUtils.scala
index e4fdcb4611..7cbfb27179 100644
--- a/common/workflow-core/src/main/scala/org/apache/amber/core/tuple/AttributeTypeUtils.scala
+++ b/common/workflow-core/src/main/scala/org/apache/amber/core/tuple/AttributeTypeUtils.scala
@@ -387,6 +387,137 @@ object AttributeTypeUtils extends Serializable {
}
}
+ /** Three-way compare for the given attribute type.
+ * Returns < 0 if left < right, > 0 if left > right, 0 if equal.
+ * Null semantics: null < non-null (both null => 0).
+ */
+ @throws[UnsupportedOperationException]
+ def compare(left: Any, right: Any, attrType: AttributeType): Int =
+ (left, right) match {
+ case (null, null) => 0
+ case (null, _) => -1
+ case (_, null) => 1
+ case _ =>
+ attrType match {
+ case AttributeType.INTEGER =>
+ java.lang.Integer.compare(
+ left.asInstanceOf[Number].intValue(),
+ right.asInstanceOf[Number].intValue()
+ )
+ case AttributeType.LONG =>
+ java.lang.Long.compare(
+ left.asInstanceOf[Number].longValue(),
+ right.asInstanceOf[Number].longValue()
+ )
+ case AttributeType.DOUBLE =>
+ java.lang.Double.compare(
+ left.asInstanceOf[Number].doubleValue(),
+ right.asInstanceOf[Number].doubleValue()
+ ) // -Infinity < ... < -0.0 < +0.0 < ... < +Infinity < NaN
+ case AttributeType.BOOLEAN =>
+ java.lang.Boolean.compare(
+ left.asInstanceOf[Boolean],
+ right.asInstanceOf[Boolean]
+ )
+ case AttributeType.TIMESTAMP =>
+ java.lang.Long.compare(
+ left.asInstanceOf[Timestamp].getTime,
+ right.asInstanceOf[Timestamp].getTime
+ )
+ case AttributeType.STRING =>
+ left.toString.compareTo(right.toString)
+ case AttributeType.BINARY =>
+ java.util.Arrays.compareUnsigned(
+ left.asInstanceOf[Array[Byte]],
+ right.asInstanceOf[Array[Byte]]
+ )
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Unsupported attribute type for compare: $attrType"
+ )
+ }
+ }
+
+ /** Type-aware addition (null is identity). */
+ @throws[UnsupportedOperationException]
+ def add(left: Object, right: Object, attrType: AttributeType): Object =
+ (left, right) match {
+ case (null, null) => zeroValue(attrType)
+ case (null, right) => right
+ case (left, null) => left
+ case (left, right) =>
+ attrType match {
+ case AttributeType.INTEGER =>
+ java.lang.Integer.valueOf(
+ left.asInstanceOf[Number].intValue() + right.asInstanceOf[Number].intValue()
+ )
+ case AttributeType.LONG =>
+ java.lang.Long.valueOf(
+ left.asInstanceOf[Number].longValue() + right.asInstanceOf[Number].longValue()
+ )
+ case AttributeType.DOUBLE =>
+ java.lang.Double.valueOf(
+ left.asInstanceOf[Number].doubleValue() + right.asInstanceOf[Number].doubleValue()
+ )
+ case AttributeType.TIMESTAMP =>
+ new Timestamp(
+ left.asInstanceOf[Timestamp].getTime + right.asInstanceOf[Timestamp].getTime
+ )
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Unsupported attribute type for addition: $attrType"
+ )
+ }
+ }
+
+ /** Additive identity for supported numeric/timestamp types.
+ * For BINARY an empty array is returned as an identity value.
+ */
+ @throws[UnsupportedOperationException]
+ def zeroValue(attrType: AttributeType): Object =
+ attrType match {
+ case AttributeType.INTEGER => java.lang.Integer.valueOf(0)
+ case AttributeType.LONG => java.lang.Long.valueOf(0L)
+ case AttributeType.DOUBLE => java.lang.Double.valueOf(0.0d)
+ case AttributeType.TIMESTAMP => new Timestamp(0L)
+ case AttributeType.BINARY => Array.emptyByteArray
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Unsupported attribute type for zero value: $attrType"
+ )
+ }
+
+ /** Returns the maximum possible value for a given attribute type. */
+ @throws[UnsupportedOperationException]
+ def maxValue(attrType: AttributeType): Object =
+ attrType match {
+ case AttributeType.INTEGER => java.lang.Integer.valueOf(Integer.MAX_VALUE)
+ case AttributeType.LONG => java.lang.Long.valueOf(java.lang.Long.MAX_VALUE)
+ case AttributeType.DOUBLE => java.lang.Double.valueOf(java.lang.Double.MAX_VALUE)
+ case AttributeType.TIMESTAMP => new Timestamp(java.lang.Long.MAX_VALUE)
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Unsupported attribute type for max value: $attrType"
+ )
+ }
+
+ /** Returns the minimum possible value for a given attribute type. (note Double.MIN_VALUE is > 0).
+ * For BINARY under lexicographic order, the empty array is the global minimum.
+ */
+ @throws[UnsupportedOperationException]
+ def minValue(attrType: AttributeType): Object =
+ attrType match {
+ case AttributeType.INTEGER => java.lang.Integer.valueOf(Integer.MIN_VALUE)
+ case AttributeType.LONG => java.lang.Long.valueOf(java.lang.Long.MIN_VALUE)
+ case AttributeType.DOUBLE => java.lang.Double.valueOf(java.lang.Double.MIN_VALUE)
+ case AttributeType.TIMESTAMP => new Timestamp(0L)
+ case AttributeType.BINARY => Array.emptyByteArray
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Unsupported attribute type for min value: $attrType"
+ )
+ }
+
class AttributeTypeException(msg: String, cause: Throwable = null)
extends IllegalArgumentException(msg, cause) {}
}
diff --git a/common/workflow-core/src/test/scala/org/apache/amber/core/tuple/AttributeTypeUtilsSpec.scala b/common/workflow-core/src/test/scala/org/apache/amber/core/tuple/AttributeTypeUtilsSpec.scala
index 24c998b3f7..53e5f68430 100644
--- a/common/workflow-core/src/test/scala/org/apache/amber/core/tuple/AttributeTypeUtilsSpec.scala
+++ b/common/workflow-core/src/test/scala/org/apache/amber/core/tuple/AttributeTypeUtilsSpec.scala
@@ -24,11 +24,17 @@ import org.apache.amber.core.tuple.AttributeTypeUtils.{
AttributeTypeException,
inferField,
inferSchemaFromRows,
- parseField
+ parseField,
+ compare,
+ add,
+ minValue,
+ maxValue,
+ zeroValue
}
import org.scalatest.funsuite.AnyFunSuite
class AttributeTypeUtilsSpec extends AnyFunSuite {
+
// Unit Test for Infer Schema
test("type should get inferred correctly individually") {
@@ -190,4 +196,131 @@ class AttributeTypeUtilsSpec extends AnyFunSuite {
assert(parseField("anything", AttributeType.ANY) == "anything")
}
+ test("compare correctly handles null values for different attribute types") {
+ assert(compare(null, null, INTEGER) == 0)
+ assert(compare(null, 10, INTEGER) < 0)
+ assert(compare(10, null, INTEGER) > 0)
+ }
+
+ test("compare correctly orders numeric, boolean, timestamp, string and
binary values") {
+ assert(compare(1, 2, INTEGER) < 0)
+ assert(compare(2, 1, INTEGER) > 0)
+ assert(compare(5, 5, INTEGER) == 0)
+
+ assert(compare(false, true, BOOLEAN) < 0)
+ assert(compare(true, false, BOOLEAN) > 0)
+ assert(compare(true, true, BOOLEAN) == 0)
+
+ val earlierTimestamp = new java.sql.Timestamp(1000L)
+ val laterTimestamp = new java.sql.Timestamp(2000L)
+ assert(compare(earlierTimestamp, laterTimestamp, TIMESTAMP) < 0)
+ assert(compare(laterTimestamp, earlierTimestamp, TIMESTAMP) > 0)
+
+ assert(compare("apple", "banana", STRING) < 0)
+ assert(compare("banana", "apple", STRING) > 0)
+ assert(compare("same", "same", STRING) == 0)
+
+ val firstBytes = Array[Byte](0, 1, 2)
+ val secondBytes = Array[Byte](0, 2, 0)
+ assert(compare(firstBytes, secondBytes, BINARY) < 0)
+ }
+
+ test("add correctly handles null values as identity for numeric types") {
+ val integerZeroFromAdd = add(null, null, INTEGER).asInstanceOf[Int]
+ assert(integerZeroFromAdd == 0)
+
+ val rightOnlyResult =
+ add(null, java.lang.Integer.valueOf(5), INTEGER).asInstanceOf[Int]
+ assert(rightOnlyResult == 5)
+
+ val leftOnlyResult =
+ add(java.lang.Integer.valueOf(7), null, INTEGER).asInstanceOf[Int]
+ assert(leftOnlyResult == 7)
+ }
+
+ test("add correctly adds integer, long, double and timestamp values") {
+ val integerSum =
+ add(java.lang.Integer.valueOf(3), java.lang.Integer.valueOf(4), INTEGER)
+ .asInstanceOf[Int]
+ assert(integerSum == 7)
+
+ val longSum =
+ add(java.lang.Long.valueOf(10L), java.lang.Long.valueOf(5L), LONG)
+ .asInstanceOf[Long]
+ assert(longSum == 15L)
+
+ val doubleSum =
+ add(java.lang.Double.valueOf(1.5), java.lang.Double.valueOf(2.5), DOUBLE)
+ .asInstanceOf[Double]
+ assert(doubleSum == 4.0)
+
+ val firstTimestamp = new java.sql.Timestamp(1000L)
+ val secondTimestamp = new java.sql.Timestamp(2500L)
+ val timestampSum =
+ add(firstTimestamp, secondTimestamp, TIMESTAMP).asInstanceOf[java.sql.Timestamp]
+ assert(timestampSum.getTime == 3500L)
+ }
+
+ test("zeroValue returns correct numeric and timestamp identity values") {
+ val integerZero = zeroValue(INTEGER).asInstanceOf[Int]
+ val longZero = zeroValue(LONG).asInstanceOf[Long]
+ val doubleZero = zeroValue(DOUBLE).asInstanceOf[Double]
+ val timestampZero = zeroValue(TIMESTAMP).asInstanceOf[java.sql.Timestamp]
+
+ assert(integerZero == 0)
+ assert(longZero == 0L)
+ assert(doubleZero == 0.0d)
+ assert(timestampZero.getTime == 0L)
+ }
+
+ test("zeroValue returns empty binary array and fails for unsupported types")
{
+ val binaryZero = zeroValue(BINARY).asInstanceOf[Array[Byte]]
+ assert(binaryZero.isEmpty)
+
+ assertThrows[UnsupportedOperationException] {
+ zeroValue(STRING)
+ }
+ }
+
+ test("maxValue returns correct maximum numeric bounds") {
+ val integerMax = maxValue(INTEGER).asInstanceOf[Int]
+ val longMax = maxValue(LONG).asInstanceOf[Long]
+ val doubleMax = maxValue(DOUBLE).asInstanceOf[Double]
+
+ assert(integerMax == Int.MaxValue)
+ assert(longMax == Long.MaxValue)
+ assert(doubleMax == Double.MaxValue)
+ }
+
+ test("maxValue returns maximum timestamp and fails for unsupported types") {
+ val timestampMax = maxValue(TIMESTAMP).asInstanceOf[java.sql.Timestamp]
+ assert(timestampMax.getTime == Long.MaxValue)
+
+ assertThrows[UnsupportedOperationException] {
+ maxValue(BOOLEAN)
+ }
+ }
+
+ test("minValue returns correct minimum numeric bounds") {
+ val integerMin = minValue(INTEGER).asInstanceOf[Int]
+ val longMin = minValue(LONG).asInstanceOf[Long]
+ val doubleMin = minValue(DOUBLE).asInstanceOf[Double]
+
+ assert(integerMin == Int.MinValue)
+ assert(longMin == Long.MinValue)
+ assert(doubleMin == java.lang.Double.MIN_VALUE)
+ }
+
+ test("minValue returns timestamp epoch and empty binary array, and fails for
unsupported types") {
+ val timestampMin = minValue(TIMESTAMP).asInstanceOf[java.sql.Timestamp]
+ val binaryMin = minValue(BINARY).asInstanceOf[Array[Byte]]
+
+ assert(timestampMin.getTime == 0L)
+
+ assert(binaryMin.isEmpty)
+
+ assertThrows[UnsupportedOperationException] {
+ minValue(STRING)
+ }
+ }
}
diff --git a/common/workflow-operator/src/main/scala/org/apache/amber/operator/aggregate/AggregationOperation.scala b/common/workflow-operator/src/main/scala/org/apache/amber/operator/aggregate/AggregationOperation.scala
index 8818d831e1..931163e9ed 100644
--- a/common/workflow-operator/src/main/scala/org/apache/amber/operator/aggregate/AggregationOperation.scala
+++ b/common/workflow-operator/src/main/scala/org/apache/amber/operator/aggregate/AggregationOperation.scala
@@ -21,11 +21,9 @@ package org.apache.amber.operator.aggregate
import com.fasterxml.jackson.annotation.{JsonIgnore, JsonProperty, JsonPropertyDescription}
import com.kjetland.jackson.jsonSchema.annotations.{JsonSchemaInject, JsonSchemaTitle}
-import org.apache.amber.core.tuple.AttributeTypeUtils.parseTimestamp
-import org.apache.amber.core.tuple.{Attribute, AttributeType, Tuple}
+import org.apache.amber.core.tuple.{Attribute, AttributeType, AttributeTypeUtils, Tuple}
import org.apache.amber.operator.metadata.annotations.AutofillAttributeName
-import java.sql.Timestamp
import javax.validation.constraints.NotNull
case class AveragePartialObj(sum: Double, count: Double) extends Serializable {}
@@ -130,12 +128,12 @@ class AggregationOperation {
)
}
new DistributedAggregation[Object](
- () => zero(attributeType),
+ () => AttributeTypeUtils.zeroValue(attributeType),
(partial, tuple) => {
val value = tuple.getField[Object](attribute)
- add(partial, value, attributeType)
+ AttributeTypeUtils.add(partial, value, attributeType)
},
- (partial1, partial2) => add(partial1, partial2, attributeType),
+ (partial1, partial2) => AttributeTypeUtils.add(partial1, partial2, attributeType),
partial => partial
)
}
@@ -190,15 +188,16 @@ class AggregationOperation {
)
}
new DistributedAggregation[Object](
- () => maxValue(attributeType),
+ () => AttributeTypeUtils.maxValue(attributeType),
(partial, tuple) => {
val value = tuple.getField[Object](attribute)
- val comp = compare(value, partial, attributeType)
+ val comp = AttributeTypeUtils.compare(value, partial, attributeType)
if (value != null && comp < 0) value else partial
},
(partial1, partial2) =>
- if (compare(partial1, partial2, attributeType) < 0) partial1 else
partial2,
- partial => if (partial == maxValue(attributeType)) null else partial
+ if (AttributeTypeUtils.compare(partial1, partial2, attributeType) < 0)
partial1
+ else partial2,
+ partial => if (partial == AttributeTypeUtils.maxValue(attributeType))
null else partial
)
}
@@ -214,15 +213,16 @@ class AggregationOperation {
)
}
new DistributedAggregation[Object](
- () => minValue(attributeType),
+ () => AttributeTypeUtils.minValue(attributeType),
(partial, tuple) => {
val value = tuple.getField[Object](attribute)
- val comp = compare(value, partial, attributeType)
+ val comp = AttributeTypeUtils.compare(value, partial, attributeType)
if (value != null && comp > 0) value else partial
},
(partial1, partial2) =>
- if (compare(partial1, partial2, attributeType) > 0) partial1 else
partial2,
- partial => if (partial == maxValue(attributeType)) null else partial
+ if (AttributeTypeUtils.compare(partial1, partial2, attributeType) > 0)
partial1
+ else partial2,
+ partial => if (partial == AttributeTypeUtils.maxValue(attributeType))
null else partial
)
}
@@ -232,7 +232,7 @@ class AggregationOperation {
return None
if (tuple.getSchema.getAttribute(attribute).getType == AttributeType.TIMESTAMP)
- Option(parseTimestamp(value.toString).getTime.toDouble)
+ Option(AttributeTypeUtils.parseTimestamp(value.toString).getTime.toDouble)
else Option(value.toString.toDouble)
}
@@ -254,94 +254,4 @@ class AggregationOperation {
}
)
}
-
- // return a.compare(b),
- // < 0 if a < b,
- // > 0 if a > b,
- // 0 if a = b
- private def compare(a: Object, b: Object, attributeType: AttributeType): Int = {
- if (a == null && b == null) {
- return 0
- } else if (a == null) {
- return -1
- } else if (b == null) {
- return 1
- }
- attributeType match {
- case AttributeType.INTEGER => a.asInstanceOf[Integer].compareTo(b.asInstanceOf[Integer])
- case AttributeType.DOUBLE =>
- a.asInstanceOf[java.lang.Double].compareTo(b.asInstanceOf[java.lang.Double])
- case AttributeType.LONG =>
- a.asInstanceOf[java.lang.Long].compareTo(b.asInstanceOf[java.lang.Long])
- case AttributeType.TIMESTAMP =>
- a.asInstanceOf[Timestamp].getTime.compareTo(b.asInstanceOf[Timestamp].getTime)
- case _ =>
- throw new UnsupportedOperationException(
- "Unsupported attribute type for comparison: " + attributeType
- )
- }
- }
-
- private def add(a: Object, b: Object, attributeType: AttributeType): Object = {
- if (a == null && b == null) {
- return zero(attributeType)
- } else if (a == null) {
- return b
- } else if (b == null) {
- return a
- }
- attributeType match {
- case AttributeType.INTEGER =>
- Integer.valueOf(a.asInstanceOf[Integer] + b.asInstanceOf[Integer])
- case AttributeType.DOUBLE =>
- java.lang.Double.valueOf(
- a.asInstanceOf[java.lang.Double] + b.asInstanceOf[java.lang.Double]
- )
- case AttributeType.LONG =>
- java.lang.Long.valueOf(a.asInstanceOf[java.lang.Long] + b.asInstanceOf[java.lang.Long])
- case AttributeType.TIMESTAMP =>
- new Timestamp(a.asInstanceOf[Timestamp].getTime + b.asInstanceOf[Timestamp].getTime)
- case _ =>
- throw new UnsupportedOperationException(
- "Unsupported attribute type for addition: " + attributeType
- )
- }
- }
-
- private def zero(attributeType: AttributeType): Object =
- attributeType match {
- case AttributeType.INTEGER => java.lang.Integer.valueOf(0)
- case AttributeType.DOUBLE => java.lang.Double.valueOf(0)
- case AttributeType.LONG => java.lang.Long.valueOf(0)
- case AttributeType.TIMESTAMP => new Timestamp(0)
- case _ =>
- throw new UnsupportedOperationException(
- "Unsupported attribute type for zero value: " + attributeType
- )
- }
-
- private def maxValue(attributeType: AttributeType): Object =
- attributeType match {
- case AttributeType.INTEGER => Integer.MAX_VALUE.asInstanceOf[Object]
- case AttributeType.DOUBLE => java.lang.Double.MAX_VALUE.asInstanceOf[Object]
- case AttributeType.LONG => java.lang.Long.MAX_VALUE.asInstanceOf[Object]
- case AttributeType.TIMESTAMP => new Timestamp(java.lang.Long.MAX_VALUE)
- case _ =>
- throw new UnsupportedOperationException(
- "Unsupported attribute type for max value: " + attributeType
- )
- }
-
- private def minValue(attributeType: AttributeType): Object =
- attributeType match {
- case AttributeType.INTEGER => Integer.MIN_VALUE.asInstanceOf[Object]
- case AttributeType.DOUBLE => java.lang.Double.MIN_VALUE.asInstanceOf[Object]
- case AttributeType.LONG => java.lang.Long.MIN_VALUE.asInstanceOf[Object]
- case AttributeType.TIMESTAMP => new Timestamp(0)
- case _ =>
- throw new UnsupportedOperationException(
- "Unsupported attribute type for min value: " + attributeType
- )
- }
-
}
diff --git a/common/workflow-operator/src/main/scala/org/apache/amber/operator/sortPartitions/SortPartitionsOpExec.scala b/common/workflow-operator/src/main/scala/org/apache/amber/operator/sortPartitions/SortPartitionsOpExec.scala
index ac6a9da59c..5748a41da6 100644
--- a/common/workflow-operator/src/main/scala/org/apache/amber/operator/sortPartitions/SortPartitionsOpExec.scala
+++ b/common/workflow-operator/src/main/scala/org/apache/amber/operator/sortPartitions/SortPartitionsOpExec.scala
@@ -20,7 +20,7 @@
package org.apache.amber.operator.sortPartitions
import org.apache.amber.core.executor.OperatorExecutor
-import org.apache.amber.core.tuple.{AttributeType, Tuple, TupleLike}
+import org.apache.amber.core.tuple.{AttributeTypeUtils, Tuple, TupleLike}
import org.apache.amber.util.JSONUtils.objectMapper
import scala.collection.mutable.ArrayBuffer
@@ -47,18 +47,11 @@ class SortPartitionsOpExec(descString: String) extends OperatorExecutor {
override def onFinish(port: Int): Iterator[TupleLike] = sortTuples()
- private def compareTuples(t1: Tuple, t2: Tuple): Boolean = {
- val attributeType = t1.getSchema.getAttribute(desc.sortAttributeName).getType
- val attributeIndex = t1.getSchema.getIndex(desc.sortAttributeName)
- attributeType match {
- case AttributeType.LONG =>
- t1.getField[Long](attributeIndex) < t2.getField[Long](attributeIndex)
- case AttributeType.INTEGER =>
- t1.getField[Int](attributeIndex) < t2.getField[Int](attributeIndex)
- case AttributeType.DOUBLE =>
- t1.getField[Double](attributeIndex) <
t2.getField[Double](attributeIndex)
- case _ =>
- true // unsupported type
- }
- }
+ private def compareTuples(tuple1: Tuple, tuple2: Tuple): Boolean =
+ AttributeTypeUtils.compare(
+ tuple1.getField[Any](tuple1.getSchema.getIndex(desc.sortAttributeName)),
+ tuple2.getField[Any](tuple2.getSchema.getIndex(desc.sortAttributeName)),
+ tuple1.getSchema.getAttribute(desc.sortAttributeName).getType
+ ) < 0
+
}
diff --git a/common/workflow-operator/src/test/scala/org/apache/amber/operator/aggregate/AggregateOpSpec.scala b/common/workflow-operator/src/test/scala/org/apache/amber/operator/aggregate/AggregateOpSpec.scala
new file mode 100644
index 0000000000..9eb405d817
--- /dev/null
+++ b/common/workflow-operator/src/test/scala/org/apache/amber/operator/aggregate/AggregateOpSpec.scala
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.amber.operator.aggregate
+
+import org.apache.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple}
+import org.apache.amber.util.JSONUtils.objectMapper
+import org.scalatest.funsuite.AnyFunSuite
+
+class AggregateOpSpec extends AnyFunSuite {
+
+ /** Helpers */
+
+ private def makeAggregationOp(
+ fn: AggregationFunction,
+ attributeName: String,
+ resultName: String
+ ): AggregationOperation = {
+ val operation = new AggregationOperation()
+ operation.aggFunction = fn
+ operation.attribute = attributeName
+ operation.resultAttribute = resultName
+ operation
+ }
+
+ private def makeSchema(fields: (String, AttributeType)*): Schema =
+ Schema(fields.map { case (n, t) => new Attribute(n, t) }.toList)
+
+ private def makeTuple(schema: Schema, values: Any*): Tuple =
+ Tuple(schema, values.toArray)
+
+ test("getAggregationAttribute keeps original type for SUM") {
+ val operation = makeAggregationOp(AggregationFunction.SUM, "amount",
"total_amount")
+ val attr = operation.getAggregationAttribute(AttributeType.DOUBLE)
+
+ assert(attr.getName == "total_amount")
+ assert(attr.getType == AttributeType.DOUBLE)
+ }
+
+ test("getAggregationAttribute maps COUNT result to INTEGER regardless of
input type") {
+ val operation = makeAggregationOp(AggregationFunction.COUNT, "quantity",
"row_count")
+ val attr = operation.getAggregationAttribute(AttributeType.LONG)
+
+ assert(attr.getName == "row_count")
+ assert(attr.getType == AttributeType.INTEGER)
+ }
+
+ test("getAggregationAttribute maps CONCAT result type to STRING") {
+ val operation = makeAggregationOp(AggregationFunction.CONCAT, "tag",
"all_tags")
+ val attr = operation.getAggregationAttribute(AttributeType.INTEGER)
+
+ assert(attr.getName == "all_tags")
+ assert(attr.getType == AttributeType.STRING)
+ }
+
+ // ---------------------------------------------------------------------------
+ // Basic DistributedAggregation behaviour via AggregationOperation.getAggFunc
+ // ---------------------------------------------------------------------------
+
+ test("SUM aggregation over INTEGER column adds values correctly") {
+ val schema = makeSchema("amount" -> AttributeType.INTEGER)
+ val tuple1 = makeTuple(schema, 5)
+ val tuple2 = makeTuple(schema, 7)
+ val tuple3 = makeTuple(schema, 3)
+
+ val operation = makeAggregationOp(AggregationFunction.SUM, "amount",
"total_amount")
+ val agg = operation.getAggFunc(AttributeType.INTEGER)
+
+ var partial = agg.init()
+ partial = agg.iterate(partial, tuple1)
+ partial = agg.iterate(partial, tuple2)
+ partial = agg.iterate(partial, tuple3)
+
+ val result = agg.finalAgg(partial).asInstanceOf[Number].intValue()
+ assert(result == 15)
+ }
+
+ test("SUM aggregation over DOUBLE column keeps fractional part") {
+ val schema = makeSchema("score" -> AttributeType.DOUBLE)
+ val tuple1 = makeTuple(schema, 1.25)
+ val tuple2 = makeTuple(schema, 2.75)
+
+ val operation = makeAggregationOp(AggregationFunction.SUM, "score",
"total_score")
+ val agg = operation.getAggFunc(AttributeType.DOUBLE)
+
+ var partial = agg.init()
+ partial = agg.iterate(partial, tuple1)
+ partial = agg.iterate(partial, tuple2)
+
+ val result = agg.finalAgg(partial).asInstanceOf[java.lang.Double].doubleValue()
+ assert(math.abs(result - 4.0) < 1e-6)
+ }
+
+ test("COUNT aggregation with attribute == null counts all rows") {
+ val schema = makeSchema("points" -> AttributeType.INTEGER)
+ val tuple1 = makeTuple(schema, 10)
+ val tuple2 = makeTuple(schema, null)
+ val tuple3 = makeTuple(schema, 20)
+
+ val operation = makeAggregationOp(AggregationFunction.COUNT, null, "row_count")
+ val agg = operation.getAggFunc(AttributeType.INTEGER)
+
+ var partial = agg.init()
+ partial = agg.iterate(partial, tuple1)
+ partial = agg.iterate(partial, tuple2)
+ partial = agg.iterate(partial, tuple3)
+
+ val result = agg.finalAgg(partial).asInstanceOf[Number].intValue()
+ assert(result == 3)
+ }
+
+ test("COUNT aggregation with attribute set only counts non-null values") {
+ val schema = makeSchema("points" -> AttributeType.INTEGER)
+ val tuple1 = makeTuple(schema, 10)
+ val tuple2 = makeTuple(schema, null)
+ val tuple3 = makeTuple(schema, 5)
+
+ val operation = makeAggregationOp(AggregationFunction.COUNT, "points",
"non_null_points")
+ val agg = operation.getAggFunc(AttributeType.INTEGER)
+
+ var partial = agg.init()
+ partial = agg.iterate(partial, tuple1)
+ partial = agg.iterate(partial, tuple2)
+ partial = agg.iterate(partial, tuple3)
+
+ val result = agg.finalAgg(partial).asInstanceOf[Number].intValue()
+ assert(result == 2)
+ }
+
+ test("CONCAT aggregation concatenates string representations with commas") {
+ val schema = makeSchema("tag" -> AttributeType.STRING)
+ val tuple1 = makeTuple(schema, "red")
+ val tuple2 = makeTuple(schema, null)
+ val tuple3 = makeTuple(schema, "blue")
+
+ val operation = makeAggregationOp(AggregationFunction.CONCAT, "tag",
"all_tags")
+ val agg = operation.getAggFunc(AttributeType.STRING)
+
+ var partial = agg.init()
+ partial = agg.iterate(partial, tuple1)
+ partial = agg.iterate(partial, tuple2)
+ partial = agg.iterate(partial, tuple3)
+
+ val result = agg.finalAgg(partial).asInstanceOf[String]
+ assert(result == "red,,blue")
+ }
+
+ test("MIN aggregation finds smallest INTEGER and returns null when given no
values") {
+ val schema = makeSchema("temperature" -> AttributeType.INTEGER)
+ val tuple1 = makeTuple(schema, 10)
+ val tuple2 = makeTuple(schema, -2)
+ val tuple3 = makeTuple(schema, 5)
+
+ val operation = makeAggregationOp(AggregationFunction.MIN, "temperature",
"min_temp")
+ val agg = operation.getAggFunc(AttributeType.INTEGER)
+
+ // Empty case: never iterate, just finalize init
+ val emptyPartial = agg.init()
+ val emptyResult = agg.finalAgg(emptyPartial)
+ assert(emptyResult == null)
+
+ // Non-empty case
+ var partial = agg.init()
+ partial = agg.iterate(partial, tuple1)
+ partial = agg.iterate(partial, tuple2)
+ partial = agg.iterate(partial, tuple3)
+
+ val result = agg.finalAgg(partial).asInstanceOf[Number].intValue()
+ assert(result == -2)
+ }
+
+ test("MAX aggregation finds largest LONG value") {
+ val schema = makeSchema("latency" -> AttributeType.LONG)
+ val tuple1 = makeTuple(schema, 100L)
+ val tuple2 = makeTuple(schema, 50L)
+ val tuple3 = makeTuple(schema, 250L)
+
+ val operation = makeAggregationOp(AggregationFunction.MAX, "latency",
"max_latency")
+ val agg = operation.getAggFunc(AttributeType.LONG)
+
+ var partial = agg.init()
+ partial = agg.iterate(partial, tuple1)
+ partial = agg.iterate(partial, tuple2)
+ partial = agg.iterate(partial, tuple3)
+
+ val result = agg.finalAgg(partial).asInstanceOf[java.lang.Long].longValue()
+ assert(result == 250L)
+ }
+
+ test("AVERAGE aggregation ignores nulls and returns null when all values are
null") {
+ val schema = makeSchema("price" -> AttributeType.DOUBLE)
+ val tuple1 = makeTuple(schema, 10.0)
+ val tuple2 = makeTuple(schema, null)
+ val tuple3 = makeTuple(schema, 20.0)
+
+ val operation = makeAggregationOp(AggregationFunction.AVERAGE, "price",
"avg_price")
+ val agg = operation.getAggFunc(AttributeType.DOUBLE)
+
+ // Mixed null and non-null
+ var partial = agg.init()
+ partial = agg.iterate(partial, tuple1)
+ partial = agg.iterate(partial, tuple2)
+ partial = agg.iterate(partial, tuple3)
+
+ val avg = agg.finalAgg(partial).asInstanceOf[java.lang.Double].doubleValue()
+ assert(math.abs(avg - 15.0) < 1e-6)
+
+ // All nulls
+ val allNull = makeTuple(schema, null)
+ var partialAllNull = agg.init()
+ partialAllNull = agg.iterate(partialAllNull, allNull)
+ val allNullResult = agg.finalAgg(partialAllNull)
+ assert(allNullResult == null)
+ }
+
+ // ---------------------------------------------------------------------------
+ // getFinal behaviour
+ // ---------------------------------------------------------------------------
+
+ test("getFinal rewrites COUNT into SUM over the intermediate result
attribute") {
+ val operation = makeAggregationOp(AggregationFunction.COUNT, "price",
"price_count")
+ val finalOp = operation.getFinal
+
+ assert(finalOp.aggFunction == AggregationFunction.SUM)
+ assert(finalOp.attribute == "price_count")
+ assert(finalOp.resultAttribute == "price_count")
+ }
+
+ test("getFinal keeps non-COUNT aggregation function and rewires attribute to
resultAttribute") {
+ val operation = makeAggregationOp(AggregationFunction.SUM, "amount",
"total_amount")
+ val finalOp = operation.getFinal
+
+ assert(finalOp.aggFunction == AggregationFunction.SUM)
+ assert(finalOp.attribute == "total_amount")
+ assert(finalOp.resultAttribute == "total_amount")
+ }
+
+ // ---------------------------------------------------------------------------
+ // AggregateOpExec: integration-style tests with groupBy
+ // ---------------------------------------------------------------------------
+
+ test("AggregateOpExec groups by a single key and computes SUM per group") {
+ // schema: city (group key), sales
+ val schema = makeSchema(
+ "city" -> AttributeType.STRING,
+ "sales" -> AttributeType.INTEGER
+ )
+
+ val tuple1 = makeTuple(schema, "NY", 10)
+ val tuple2 = makeTuple(schema, "SF", 20)
+ val tuple3 = makeTuple(schema, "NY", 5)
+
+ val desc = new AggregateOpDesc()
+ val sumAgg = makeAggregationOp(AggregationFunction.SUM, "sales",
"total_sales")
+ desc.aggregations = List(sumAgg)
+ desc.groupByKeys = List("city")
+
+ val descJson = objectMapper.writeValueAsString(desc)
+
+ val exec = new AggregateOpExec(descJson)
+ exec.open()
+ exec.processTuple(tuple1, 0)
+ exec.processTuple(tuple2, 0)
+ exec.processTuple(tuple3, 0)
+
+ val results = exec.onFinish(0).toList
+
+ // Expect two output rows: (NY, 15) and (SF, 20)
+ val resultMap = results.map { tupleLike =>
+ val fields = tupleLike.getFields
+ val city = fields(0).asInstanceOf[String]
+ val total = fields(1).asInstanceOf[Number].intValue()
+ city -> total
+ }.toMap
+
+ assert(resultMap.size == 2)
+ assert(resultMap("NY") == 15)
+ assert(resultMap("SF") == 20)
+ }
+
+ test("AggregateOpExec performs global SUM and COUNT when there are no
groupBy keys") {
+ // schema: region (ignored for aggregation), revenue
+ val schema = makeSchema(
+ "region" -> AttributeType.STRING,
+ "revenue" -> AttributeType.INTEGER
+ )
+
+ val tuple1 = makeTuple(schema, "west", 100)
+ val tuple2 = makeTuple(schema, "east", 200)
+ val tuple3 = makeTuple(schema, "west", 50)
+
+ val desc = new AggregateOpDesc()
+ val sumAgg = makeAggregationOp(AggregationFunction.SUM, "revenue",
"total_revenue")
+ val countAgg = makeAggregationOp(AggregationFunction.COUNT, "revenue",
"row_count")
+ desc.aggregations = List(sumAgg, countAgg)
+ desc.groupByKeys = List() // global aggregation
+
+ val descJson = objectMapper.writeValueAsString(desc)
+
+ val exec = new AggregateOpExec(descJson)
+ exec.open()
+ exec.processTuple(tuple1, 0)
+ exec.processTuple(tuple2, 0)
+ exec.processTuple(tuple3, 0)
+
+ val results = exec.onFinish(0).toList
+ assert(results.size == 1)
+
+ val fields = results.head.getFields
+ // No group keys, so fields(0) is SUM(revenue), fields(1) is COUNT(revenue)
+ val totalRevenue = fields(0).asInstanceOf[Number].intValue()
+ val rowCount = fields(1).asInstanceOf[Number].intValue()
+
+ assert(totalRevenue == 350)
+ assert(rowCount == 3)
+ }
+}