Github user twalthr commented on a diff in the pull request: https://github.com/apache/flink/pull/5327#discussion_r182427888 --- Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/join/NonWindowInnerJoin.scala --- @@ -54,238 +44,51 @@ class NonWindowInnerJoin( genJoinFuncName: String, genJoinFuncCode: String, queryConfig: StreamQueryConfig) - extends CoProcessFunction[CRow, CRow, CRow] - with Compiler[FlatJoinFunction[Row, Row, Row]] - with Logging { - - // check if input types implement proper equals/hashCode - validateEqualsHashCode("join", leftType) - validateEqualsHashCode("join", rightType) - - // state to hold left stream element - private var leftState: MapState[Row, JTuple2[Int, Long]] = _ - // state to hold right stream element - private var rightState: MapState[Row, JTuple2[Int, Long]] = _ - private var cRowWrapper: CRowWrappingMultiOutputCollector = _ - - private val minRetentionTime: Long = queryConfig.getMinIdleStateRetentionTime - private val maxRetentionTime: Long = queryConfig.getMaxIdleStateRetentionTime - private val stateCleaningEnabled: Boolean = minRetentionTime > 1 - - // state to record last timer of left stream, 0 means no timer - private var leftTimer: ValueState[Long] = _ - // state to record last timer of right stream, 0 means no timer - private var rightTimer: ValueState[Long] = _ - - // other condition function - private var joinFunction: FlatJoinFunction[Row, Row, Row] = _ + extends NonWindowJoin( + leftType, + rightType, + resultType, + genJoinFuncName, + genJoinFuncCode, + queryConfig) { override def open(parameters: Configuration): Unit = { - LOG.debug(s"Compiling JoinFunction: $genJoinFuncName \n\n " + - s"Code:\n$genJoinFuncCode") - val clazz = compile( - getRuntimeContext.getUserCodeClassLoader, - genJoinFuncName, - genJoinFuncCode) - LOG.debug("Instantiating JoinFunction.") - joinFunction = clazz.newInstance() - - // initialize left and right state, the first element of tuple2 indicates how many rows 
of - // this row, while the second element represents the expired time of this row. - val tupleTypeInfo = new TupleTypeInfo[JTuple2[Int, Long]](Types.INT, Types.LONG) - val leftStateDescriptor = new MapStateDescriptor[Row, JTuple2[Int, Long]]( - "left", leftType, tupleTypeInfo) - val rightStateDescriptor = new MapStateDescriptor[Row, JTuple2[Int, Long]]( - "right", rightType, tupleTypeInfo) - leftState = getRuntimeContext.getMapState(leftStateDescriptor) - rightState = getRuntimeContext.getMapState(rightStateDescriptor) - - // initialize timer state - val valueStateDescriptor1 = new ValueStateDescriptor[Long]("timervaluestate1", classOf[Long]) - leftTimer = getRuntimeContext.getState(valueStateDescriptor1) - val valueStateDescriptor2 = new ValueStateDescriptor[Long]("timervaluestate2", classOf[Long]) - rightTimer = getRuntimeContext.getState(valueStateDescriptor2) - - cRowWrapper = new CRowWrappingMultiOutputCollector() - } - - /** - * Process left stream records - * - * @param valueC The input value. - * @param ctx The ctx to register timer or get current time - * @param out The collector for returning result values. - * - */ - override def processElement1( - valueC: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]): Unit = { - - processElement(valueC, ctx, out, leftTimer, leftState, rightState, isLeft = true) - } - - /** - * Process right stream records - * - * @param valueC The input value. - * @param ctx The ctx to register timer or get current time - * @param out The collector for returning result values. - * - */ - override def processElement2( - valueC: CRow, - ctx: CoProcessFunction[CRow, CRow, CRow]#Context, - out: Collector[CRow]): Unit = { - - processElement(valueC, ctx, out, rightTimer, rightState, leftState, isLeft = false) - } - - - /** - * Called when a processing timer trigger. - * Expire left/right records which are expired in left and right state. - * - * @param timestamp The timestamp of the firing timer. 
- * @param ctx The ctx to register timer or get current time - * @param out The collector for returning result values. - */ - override def onTimer( - timestamp: Long, - ctx: CoProcessFunction[CRow, CRow, CRow]#OnTimerContext, - out: Collector[CRow]): Unit = { - - if (stateCleaningEnabled && leftTimer.value == timestamp) { - expireOutTimeRow( - timestamp, - leftState, - leftTimer, - ctx - ) - } - - if (stateCleaningEnabled && rightTimer.value == timestamp) { - expireOutTimeRow( - timestamp, - rightState, - rightTimer, - ctx - ) - } - } - - def getNewExpiredTime( - curProcessTime: Long, - oldExpiredTime: Long): Long = { - - if (stateCleaningEnabled && curProcessTime + minRetentionTime > oldExpiredTime) { - curProcessTime + maxRetentionTime - } else { - oldExpiredTime - } + super.open(parameters) + LOG.debug("Instantiating NonWindowInnerJoin.") } /** * Puts or Retract an element from the input stream into state and search the other state to * output records meet the condition. Records will be expired in state if state retention time * has been specified. */ - def processElement( + override def processElement( value: CRow, ctx: CoProcessFunction[CRow, CRow, CRow]#Context, out: Collector[CRow], timerState: ValueState[Long], - currentSideState: MapState[Row, JTuple2[Int, Long]], - otherSideState: MapState[Row, JTuple2[Int, Long]], + currentSideState: MapState[Row, JTuple2[Long, Long]], + otherSideState: MapState[Row, JTuple2[Long, Long]], isLeft: Boolean): Unit = { val inputRow = value.row + val (curProcessTime, _) = updateCurrentSide(value, ctx, timerState, currentSideState) --- End diff -- We should not use Scala syntactic sugar such as tuple destructuring in runtime classes. This might create a tuple object for every processed element.
---