HappenLee commented on code in PR #16841: URL: https://github.com/apache/doris/pull/16841#discussion_r1108534879
########## fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java: ########## @@ -206,25 +220,299 @@ protected Object[] allocateInputObjects(long row, int argClassOffset) throws Udf case STRING: { long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row)); - long numBytes = row == 0 ? offset : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1))); - long base = - row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) : - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + offset - numBytes; + long numBytes = row == 0 ? offset + : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1))); + long base = row == 0 + ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + : UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + offset - numBytes; byte[] bytes = new byte[(int) numBytes]; UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); inputObjects[i] = new String(bytes, StandardCharsets.UTF_8); break; } + case ARRAY_TYPE: { + Type type = argTypes[i].getItemType(); + inputObjects[i] = arrayTypeInputData(type, i, row); + break; + } default: throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]); } } return inputObjects; } + public ArrayList<?> arrayTypeInputData(Type type, int argIdx, long row) + throws UdfRuntimeException { + long offsetStart = (row == 0) ? 
0 + : Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputOffsetsPtrs, argIdx)) + 8L * (row - 1))); + long offsetEnd = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputOffsetsPtrs, argIdx)) + 8L * row)); + + switch (type.getPrimitiveType()) { + case BOOLEAN: { + ArrayList<Boolean> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + boolean value = UdfUtils.UNSAFE.getBoolean(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + offsetRow); + data.add(value); + } + } + return data; + } + case TINYINT: { + ArrayList<Byte> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + byte value = UdfUtils.UNSAFE.getByte(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + offsetRow); + data.add(value); + } + } + return data; + } + case SMALLINT: { + ArrayList<Short> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + short value = UdfUtils.UNSAFE.getShort(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + 2L * offsetRow); + data.add(value); + } + } + return data; + } + case INT: { + ArrayList<Integer> data = new ArrayList<>(); + for (long 
offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + int value = UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + 4L * offsetRow); + data.add(value); + } + } + return data; + } + case BIGINT: { + ArrayList<Long> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + long value = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + 8L * offsetRow); + data.add(value); + } + } + return data; + } + case FLOAT: { + ArrayList<Float> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + float value = UdfUtils.UNSAFE.getFloat(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + 4L * offsetRow); + data.add(value); + } + } + return data; + } + case DOUBLE: { + ArrayList<Double> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + double value = UdfUtils.UNSAFE.getDouble(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + 8L * offsetRow); + data.add(value); + } + } + return data; + } + case DATE: { + 
ArrayList<LocalDate> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + long value = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + 8L * offsetRow); + // TODO: now argClass[argIdx + argClassOffset] is java.util.ArrayList, can't get + // nested class type + // LocalDate obj = UdfUtils.convertDateToJavaDate(value, argClass[argIdx + + // argClassOffset]); + LocalDate obj = (LocalDate) UdfUtils.convertDateToJavaDate(value, LocalDate.class); + data.add(obj); + } + } + return data; + } + case DATETIME: { + ArrayList<LocalDateTime> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + long value = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + 8L * offsetRow); + // Object obj = UdfUtils.convertDateTimeToJavaDateTime(value, argClass[argIdx + + // argClassOffset]); + LocalDateTime obj = (LocalDateTime) UdfUtils.convertDateTimeToJavaDateTime(value, + LocalDateTime.class); + data.add(obj); + } + } + return data; + } + case DATEV2: { + ArrayList<LocalDate> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + int value = UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + 4L * offsetRow); + // 
Object obj = UdfUtils.convertDateV2ToJavaDate(value, argClass[argIdx + + // argClassOffset]); + LocalDate obj = (LocalDate) UdfUtils.convertDateV2ToJavaDate(value, LocalDate.class); + data.add(obj); + } + } + return data; + } + case DATETIMEV2: { + ArrayList<LocalDateTime> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + long value = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + 8L * offsetRow); + LocalDateTime obj = (LocalDateTime) UdfUtils.convertDateTimeV2ToJavaDateTime(value, + LocalDateTime.class); + data.add(obj); + } + } + return data; + } + case LARGEINT: { + ArrayList<BigInteger> data = new ArrayList<>(); + for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { + if ((UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)) + offsetRow) == 1)) { + data.add(null); + } else { + long value = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)) + + 16L * offsetRow); + byte[] bytes = new byte[16]; Review Comment: the `new byte[16]` allocation should be moved out of the for loop -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org