This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 01c001e2ac [refactor](javaudf) simplify UdfExecutor and UdafExecutor (#16050) 01c001e2ac is described below commit 01c001e2ac06513d8049d519341e05b712460c7a Author: Gabriel <gabrielleeb...@gmail.com> AuthorDate: Sat Jan 21 08:07:28 2023 +0800 [refactor](javaudf) simplify UdfExecutor and UdafExecutor (#16050) * [refactor](javaudf) simplify UdfExecutor and UdafExecutor * update * update --- .../udf/{UdfExecutor.java => BaseExecutor.java} | 502 +++++++-------------- .../java/org/apache/doris/udf/UdafExecutor.java | 379 ++-------------- .../java/org/apache/doris/udf/UdfExecutor.java | 384 ++-------------- 3 files changed, 227 insertions(+), 1038 deletions(-) diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java similarity index 60% copy from fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java copy to fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java index 62deef5cda..55ff08f700 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java @@ -18,80 +18,67 @@ package org.apache.doris.udf; import org.apache.doris.catalog.Type; -import org.apache.doris.common.Pair; import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; import org.apache.doris.udf.UdfUtils.JavaUdfDataType; -import com.google.common.base.Joiner; import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; import org.apache.log4j.Logger; import org.apache.thrift.TDeserializer; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; import java.io.IOException; -import java.lang.reflect.Constructor; -import java.lang.reflect.Method; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; -import java.net.MalformedURLException; import java.net.URLClassLoader; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; import java.util.Arrays; -public class UdfExecutor { - private static final Logger LOG = Logger.getLogger(UdfExecutor.class); +public abstract class BaseExecutor { + private static final Logger LOG = Logger.getLogger(BaseExecutor.class); // By convention, the function in the class must be called evaluate() public static final String UDF_FUNCTION_NAME = "evaluate"; + public static final String UDAF_CREATE_FUNCTION = "create"; + public static final String UDAF_DESTROY_FUNCTION = "destroy"; + public static final String UDAF_ADD_FUNCTION = "add"; + public static final String UDAF_SERIALIZE_FUNCTION = "serialize"; + public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize"; + public static final String UDAF_MERGE_FUNCTION = "merge"; + public static final String UDAF_RESULT_FUNCTION = "getValue"; // Object to deserialize ctor params from BE. - private static final TBinaryProtocol.Factory PROTOCOL_FACTORY = + protected static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new TBinaryProtocol.Factory(); - private Object udf; + protected Object udf; // setup by init() and cleared by close() - private Method method; - // setup by init() and cleared by close() - private URLClassLoader classLoader; + protected URLClassLoader classLoader; // Return and argument types of the function inferred from the udf method signature. // The JavaUdfDataType enum maps it to corresponding primitive type. - private JavaUdfDataType[] argTypes; - private JavaUdfDataType retType; + protected JavaUdfDataType[] argTypes; + protected JavaUdfDataType retType; // Input buffer from the backend. This is valid for the duration of an evaluate() call. // These buffers are allocated in the BE. - private final long inputBufferPtrs; - private final long inputNullsPtrs; - private final long inputOffsetsPtrs; + protected final long inputBufferPtrs; + protected final long inputNullsPtrs; + protected final long inputOffsetsPtrs; // Output buffer to return non-string values. These buffers are allocated in the BE. - private final long outputBufferPtr; - private final long outputNullPtr; - private final long outputOffsetsPtr; - private final long outputIntermediateStatePtr; - - // Pre-constructed input objects for the UDF. This minimizes object creation overhead - // as these objects are reused across calls to evaluate(). - private Object[] inputObjects; - // inputArgs_[i] is either inputObjects[i] or null - private Object[] inputArgs; - - private long outputOffset; - private long rowIdx; - - private final long batchSizePtr; - private Class[] argClass; + protected final long outputBufferPtr; + protected final long outputNullPtr; + protected final long outputOffsetsPtr; + protected final long outputIntermediateStatePtr; + protected Class[] argClass; /** * Create a UdfExecutor, using parameters from a serialized thrift object. Used by * the backend. */ - public UdfExecutor(byte[] thriftParams) throws Exception { + public BaseExecutor(byte[] thriftParams) throws Exception { TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams(); TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY); try { @@ -99,14 +86,6 @@ public class UdfExecutor { } catch (TException e) { throw new InternalException(e.getMessage()); } - String className = request.fn.scalar_fn.symbol; - String jarFile = request.location; - Type retType = UdfUtils.fromThrift(request.fn.ret_type, 0).first; - Type[] parameterTypes = new Type[request.fn.arg_types.size()]; - for (int i = 0; i < request.fn.arg_types.size(); ++i) { - parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i)); - } - batchSizePtr = request.batch_size_ptr; inputBufferPtrs = request.input_buffer_ptrs; inputNullsPtrs = request.input_nulls_ptrs; inputOffsetsPtrs = request.input_offsets_ptrs; @@ -116,18 +95,139 @@ public class UdfExecutor { outputOffsetsPtr = request.output_offsets_ptr; outputIntermediateStatePtr = request.output_intermediate_state_ptr; - outputOffset = 0L; - rowIdx = 0L; + Type[] parameterTypes = new Type[request.fn.arg_types.size()]; + for (int i = 0; i < request.fn.arg_types.size(); ++i) { + parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i)); + } + String jarFile = request.location; + Type funcRetType = UdfUtils.fromThrift(request.fn.ret_type, 0).first; - init(jarFile, className, retType, parameterTypes); + init(request, jarFile, funcRetType, parameterTypes); } - @Override - protected void finalize() throws Throwable { - close(); - super.finalize(); + protected abstract void init(TJavaUdfExecutorCtorParams request, String jarPath, + Type funcRetType, Type... parameterTypes) throws UdfRuntimeException; + + protected Object[] allocateInputObjects(long row, int argClassOffset) throws UdfRuntimeException { + Object[] inputObjects = new Object[argTypes.length]; + + for (int i = 0; i < argTypes.length; ++i) { + if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) != -1 + && (UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row) == 1)) { + inputObjects[i] = null; + continue; + } + switch (argTypes[i]) { + case BOOLEAN: + inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); + break; + case TINYINT: + inputObjects[i] = UdfUtils.UNSAFE.getByte(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); + break; + case SMALLINT: + inputObjects[i] = UdfUtils.UNSAFE.getShort(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row); + break; + case INT: + inputObjects[i] = UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row); + break; + case BIGINT: + inputObjects[i] = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row); + break; + case FLOAT: + inputObjects[i] = UdfUtils.UNSAFE.getFloat(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row); + break; + case DOUBLE: + inputObjects[i] = UdfUtils.UNSAFE.getDouble(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row); + break; + case DATE: { + long data = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row); + inputObjects[i] = UdfUtils.convertDateToJavaDate(data, argClass[i + argClassOffset]); + break; + } + case DATETIME: { + long data = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row); + inputObjects[i] = UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i + argClassOffset]); + break; + } + case DATEV2: { + int data = UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row); + inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data, argClass[i + argClassOffset]); + break; + } + case DATETIMEV2: { + long data = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row); + inputObjects[i] = UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i + argClassOffset]); + break; + } + case LARGEINT: { + long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row; + byte[] bytes = new byte[argTypes[i].getLen()]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); + + inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes)); + break; + } + case DECIMALV2: + case DECIMAL32: + case DECIMAL64: + case DECIMAL128: { + long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + argTypes[i].getLen() * row; + byte[] bytes = new byte[argTypes[i].getLen()]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); + + BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes)); + inputObjects[i] = new BigDecimal(value, argTypes[i].getScale()); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row)); + long numBytes = row == 0 ? offset : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1))); + long base = + row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) : + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + offset - numBytes; + byte[] bytes = new byte[(int) numBytes]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); + inputObjects[i] = new String(bytes, StandardCharsets.UTF_8); + break; + } + default: + throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]); + } + } + return inputObjects; } + protected abstract long getCurrentOutputOffset(long row); + /** * Close the class loader we may have created. */ @@ -142,91 +242,11 @@ public class UdfExecutor { } // We are now un-usable (because the class loader has been // closed), so null out method_ and classLoader_. - method = null; classLoader = null; } - /** - * evaluate function called by the backend. The inputs to the UDF have - * been serialized to 'input' - */ - public void evaluate() throws UdfRuntimeException { - int batchSize = UdfUtils.UNSAFE.getInt(null, batchSizePtr); - try { - if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.VARCHAR) - || retType.equals(JavaUdfDataType.CHAR)) { - // If this udf return variable-size type (e.g.) String, we have to allocate output - // buffer multiple times until buffer size is enough to store output column. So we - // always begin with the last evaluated row instead of beginning of this batch. - rowIdx = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr + 8); - if (rowIdx == 0) { - outputOffset = 0L; - } - } else { - rowIdx = 0; - } - for (; rowIdx < batchSize; rowIdx++) { - allocateInputObjects(rowIdx); - for (int i = 0; i < argTypes.length; ++i) { - // Currently, -1 indicates this column is not nullable. So input argument is - // null iff inputNullsPtrs_ != -1 and nullCol[row_idx] != 0. - if (UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) == -1 - || UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + rowIdx) == 0) { - inputArgs[i] = inputObjects[i]; - } else { - inputArgs[i] = null; - } - } - // `storeUdfResult` is called to store udf result to output column. If true - // is returned, current value is stored successfully. Otherwise, current result is - // not processed successfully (e.g. current output buffer is not large enough) so - // we break this loop directly. - if (!storeUdfResult(evaluate(inputArgs), rowIdx)) { - UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx); - return; - } - } - } catch (Exception e) { - if (retType.equals(JavaUdfDataType.STRING)) { - UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, batchSize); - } - throw new UdfRuntimeException("UDF::evaluate() ran into a problem.", e); - } - if (retType.equals(JavaUdfDataType.STRING)) { - UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx); - } - } - - /** - * Evaluates the UDF with 'args' as the input to the UDF. - */ - private Object evaluate(Object... args) throws UdfRuntimeException { - try { - return method.invoke(udf, args); - } catch (Exception e) { - throw new UdfRuntimeException("UDF failed to evaluate", e); - } - } - - public Method getMethod() { - return method; - } - // Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_ - private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException { - if (obj == null) { - if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) { - throw new UdfRuntimeException("UDF failed to store null data to not null column"); - } - UdfUtils.UNSAFE.putByte(null, UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 1); - if (retType.equals(JavaUdfDataType.STRING)) { - UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) - + 4L * row, Integer.parseUnsignedInt(String.valueOf(outputOffset))); - } - return true; - } + protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException { if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) { UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 0); } @@ -268,22 +288,22 @@ public class UdfExecutor { return true; } case DATE: { - long time = UdfUtils.convertToDate(obj, method.getReturnType()); + long time = UdfUtils.convertToDate(obj, retClass); UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); return true; } case DATETIME: { - long time = UdfUtils.convertToDateTime(obj, method.getReturnType()); + long time = UdfUtils.convertToDateTime(obj, retClass); UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); return true; } case DATEV2: { - int time = UdfUtils.convertToDateV2(obj, method.getReturnType()); + int time = UdfUtils.convertToDateV2(obj, retClass); UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); return true; } case DATETIMEV2: { - long time = UdfUtils.convertToDateTimeV2(obj, method.getReturnType()); + long time = UdfUtils.convertToDateTimeV2(obj, retClass); UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); return true; } @@ -349,14 +369,16 @@ public class UdfExecutor { case STRING: { long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr); byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8); - if (outputOffset + bytes.length > bufferSize) { + long offset = getCurrentOutputOffset(row); + if (offset + bytes.length > bufferSize) { return false; } - outputOffset += bytes.length; + offset += bytes.length; UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row, - Integer.parseUnsignedInt(String.valueOf(outputOffset))); + Integer.parseUnsignedInt(String.valueOf(offset))); UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + outputOffset - bytes.length, bytes.length); + UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + offset - bytes.length, bytes.length); + updateOutputOffset(offset); return true; } default: @@ -364,203 +386,5 @@ public class UdfExecutor { } } - // Preallocate the input objects that will be passed to the underlying UDF. - // These objects are allocated once and reused across calls to evaluate() - private void allocateInputObjects(long row) throws UdfRuntimeException { - inputObjects = new Object[argTypes.length]; - inputArgs = new Object[argTypes.length]; - - for (int i = 0; i < argTypes.length; ++i) { - switch (argTypes[i]) { - case BOOLEAN: - inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); - break; - case TINYINT: - inputObjects[i] = UdfUtils.UNSAFE.getByte(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); - break; - case SMALLINT: - inputObjects[i] = UdfUtils.UNSAFE.getShort(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case INT: - inputObjects[i] = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case BIGINT: - inputObjects[i] = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case FLOAT: - inputObjects[i] = UdfUtils.UNSAFE.getFloat(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case DOUBLE: - inputObjects[i] = UdfUtils.UNSAFE.getDouble(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case DATE: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateToJavaDate(data, argClass[i]); - break; - } - case DATETIME: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i]); - break; - } - case DATEV2: { - int data = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data, argClass[i]); - break; - } - case DATETIMEV2: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i]); - break; - } - case LARGEINT: { - long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row; - byte[] bytes = new byte[argTypes[i].getLen()]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); - - inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes)); - break; - } - case DECIMALV2: - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row; - byte[] bytes = new byte[argTypes[i].getLen()]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); - - BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes)); - inputObjects[i] = new BigDecimal(value, argTypes[i].getScale()); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row)); - long numBytes = row == 0 ? offset : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1))); - long base = - row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) : - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + offset - numBytes; - byte[] bytes = new byte[(int) numBytes]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); - inputObjects[i] = new String(bytes, StandardCharsets.UTF_8); - break; - } - default: - throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]); - } - } - } - - private void init(String jarPath, String udfPath, Type funcRetType, Type... parameterTypes) - throws UdfRuntimeException { - ArrayList<String> signatures = Lists.newArrayList(); - try { - LOG.debug("Loading UDF '" + udfPath + "' from " + jarPath); - ClassLoader loader; - if (jarPath != null) { - // Save for cleanup. - ClassLoader parent = getClass().getClassLoader(); - classLoader = UdfUtils.getClassLoader(jarPath, parent); - loader = classLoader; - } else { - // for test - loader = ClassLoader.getSystemClassLoader(); - } - Class<?> c = Class.forName(udfPath, true, loader); - Constructor<?> ctor = c.getConstructor(); - udf = ctor.newInstance(); - Method[] methods = c.getMethods(); - for (Method m : methods) { - // By convention, the udf must contain the function "evaluate" - if (!m.getName().equals(UDF_FUNCTION_NAME)) { - continue; - } - signatures.add(m.toGenericString()); - argClass = m.getParameterTypes(); - - // Try to match the arguments - if (argClass.length != parameterTypes.length) { - continue; - } - method = m; - Pair<Boolean, JavaUdfDataType> returnType; - if (argClass.length == 0 && parameterTypes.length == 0) { - // Special case where the UDF doesn't take any input args - returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType()); - if (!returnType.first) { - continue; - } else { - retType = returnType.second; - } - argTypes = new JavaUdfDataType[0]; - LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath); - return; - } - returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType()); - if (!returnType.first) { - continue; - } else { - retType = returnType.second; - } - Pair<Boolean, JavaUdfDataType[]> inputType = UdfUtils.setArgTypes(parameterTypes, argClass, false); - if (!inputType.first) { - continue; - } else { - argTypes = inputType.second; - } - LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath); - return; - } - - StringBuilder sb = new StringBuilder(); - sb.append("Unable to find evaluate function with the correct signature: ") - .append(udfPath + ".evaluate(") - .append(Joiner.on(", ").join(parameterTypes)) - .append(")\n") - .append("UDF contains: \n ") - .append(Joiner.on("\n ").join(signatures)); - throw new UdfRuntimeException(sb.toString()); - } catch (MalformedURLException e) { - throw new UdfRuntimeException("Unable to load jar.", e); - } catch (SecurityException e) { - throw new UdfRuntimeException("Unable to load function.", e); - } catch (ClassNotFoundException e) { - throw new UdfRuntimeException("Unable to find class.", e); - } catch (NoSuchMethodException e) { - throw new UdfRuntimeException( - "Unable to find constructor with no arguments.", e); - } catch (IllegalArgumentException e) { - throw new UdfRuntimeException( - "Unable to call UDF constructor with no arguments.", e); - } catch (Exception e) { - throw new UdfRuntimeException("Unable to call create UDF instance.", e); - } - } + protected void updateOutputOffset(long offset) {} } diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java index 0e6028b06e..4f88fa967e 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java @@ -23,12 +23,8 @@ import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; import org.apache.doris.udf.UdfUtils.JavaUdfDataType; import com.google.common.base.Joiner; -import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import org.apache.log4j.Logger; -import org.apache.thrift.TDeserializer; -import org.apache.thrift.TException; -import org.apache.thrift.protocol.TBinaryProtocol; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -36,100 +32,36 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.lang.reflect.Constructor; import java.lang.reflect.Method; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.math.RoundingMode; import java.net.MalformedURLException; -import java.net.URLClassLoader; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; /** * udaf executor. */ -public class UdafExecutor { - public static final String UDAF_CREATE_FUNCTION = "create"; - public static final String UDAF_DESTROY_FUNCTION = "destroy"; - public static final String UDAF_ADD_FUNCTION = "add"; - public static final String UDAF_SERIALIZE_FUNCTION = "serialize"; - public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize"; - public static final String UDAF_MERGE_FUNCTION = "merge"; - public static final String UDAF_RESULT_FUNCTION = "getValue"; +public class UdafExecutor extends BaseExecutor { + private static final Logger LOG = Logger.getLogger(UdafExecutor.class); - private static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new TBinaryProtocol.Factory(); - private final long inputBufferPtrs; - private final long inputNullsPtrs; - private final long inputOffsetsPtrs; - private final long inputPlacesPtr; - private final long outputBufferPtr; - private final long outputNullPtr; - private final long outputOffsetsPtr; - private final long outputIntermediateStatePtr; - private Object udaf; + + private long inputPlacesPtr; private HashMap<String, Method> allMethods; private HashMap<Long, Object> stateObjMap; - private URLClassLoader classLoader; - private JavaUdfDataType[] argTypes; - private JavaUdfDataType retType; - private Class[] argClass; private Class retClass; /** * Constructor to create an object. */ public UdafExecutor(byte[] thriftParams) throws Exception { - TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams(); - TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY); - try { - deserializer.deserialize(request, thriftParams); - } catch (TException e) { - throw new InternalException(e.getMessage()); - } - Type[] parameterTypes = new Type[request.fn.arg_types.size()]; - for (int i = 0; i < request.fn.arg_types.size(); ++i) { - parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i)); - } - inputBufferPtrs = request.input_buffer_ptrs; - inputNullsPtrs = request.input_nulls_ptrs; - inputOffsetsPtrs = request.input_offsets_ptrs; - inputPlacesPtr = request.input_places_ptr; - - outputBufferPtr = request.output_buffer_ptr; - outputNullPtr = request.output_null_ptr; - outputOffsetsPtr = request.output_offsets_ptr; - outputIntermediateStatePtr = request.output_intermediate_state_ptr; - allMethods = new HashMap<>(); - stateObjMap = new HashMap<>(); - String className = request.fn.aggregate_fn.symbol; - String jarFile = request.location; - Type funcRetType = UdfUtils.fromThrift(request.fn.ret_type, 0).first; - init(jarFile, className, funcRetType, parameterTypes); + super(thriftParams); } /** * close and invoke destroy function. */ + @Override public void close() { - if (classLoader != null) { - try { - classLoader.close(); - } catch (Exception e) { - // Log and ignore. - LOG.debug("Error closing the URLClassloader.", e); - } - } - // We are now un-usable (because the class loader has been - // closed), so null out allMethods and classLoader. allMethods = null; - classLoader = null; - } - - @Override - protected void finalize() throws Throwable { - close(); - super.finalize(); + super.close(); } /** @@ -144,11 +76,11 @@ public class UdafExecutor { stateObjMap.putIfAbsent(curPlace, createAggState()); inputArgs[0] = stateObjMap.get(curPlace); do { - Object[] inputObjects = allocateInputObjects(idx); + Object[] inputObjects = allocateInputObjects(idx, 1); for (int i = 0; i < argTypes.length; ++i) { inputArgs[i + 1] = inputObjects[i]; } - allMethods.get(UDAF_ADD_FUNCTION).invoke(udaf, inputArgs); + allMethods.get(UDAF_ADD_FUNCTION).invoke(udf, inputArgs); idx++; } while (isSinglePlace && idx < rowEnd); } while (idx < rowEnd); @@ -162,7 +94,7 @@ public class UdafExecutor { */ public Object createAggState() throws UdfRuntimeException { try { - return allMethods.get(UDAF_CREATE_FUNCTION).invoke(udaf, null); + return allMethods.get(UDAF_CREATE_FUNCTION).invoke(udf, null); } catch (Exception e) { throw new UdfRuntimeException("UDAF failed to create: ", e); } @@ -174,7 +106,7 @@ public class UdafExecutor { public void destroy() throws UdfRuntimeException { try { for (Object obj : stateObjMap.values()) { - allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udaf, obj); + allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udf, obj); } stateObjMap.clear(); } catch (Exception e) { @@ -191,7 +123,7 @@ public class UdafExecutor { ByteArrayOutputStream baos = new ByteArrayOutputStream(); args[0] = stateObjMap.get((Long) place); args[1] = new DataOutputStream(baos); - allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udaf, args); + allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udf, args); return baos.toByteArray(); } catch (Exception e) { throw new UdfRuntimeException("UDAF failed to serialize: ", e); @@ -208,12 +140,12 @@ public class UdafExecutor { ByteArrayInputStream bins = new ByteArrayInputStream(data); args[0] = createAggState(); args[1] = new DataInputStream(bins); - allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udaf, args); + allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udf, args); args[1] = args[0]; Long curPlace = place; stateObjMap.putIfAbsent(curPlace, createAggState()); args[0] = stateObjMap.get(curPlace); - allMethods.get(UDAF_MERGE_FUNCTION).invoke(udaf, args); + allMethods.get(UDAF_MERGE_FUNCTION).invoke(udf, args); } catch (Exception e) { throw new UdfRuntimeException("UDAF failed to merge: ", e); } @@ -224,14 +156,15 @@ public class UdafExecutor { */ public boolean getValue(long row, long place) throws UdfRuntimeException { try { - return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udaf, stateObjMap.get((Long) place)), - row); + return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place)), + row, retClass); } catch (Exception e) { throw new UdfRuntimeException("UDAF failed to result", e); } } - private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException { + @Override + protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException { if (obj == null) { // If result is null, return true directly when row == 0 as we have already inserted default value. if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) { @@ -239,267 +172,23 @@ public class UdafExecutor { } return true; } - if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 0); - } - switch (retType) { - case BOOLEAN: { - boolean val = (boolean) obj; - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - val ? (byte) 1 : 0); - return true; - } - case TINYINT: { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (byte) obj); - return true; - } - case SMALLINT: { - UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (short) obj); - return true; - } - case INT: { - UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (int) obj); - return true; - } - case BIGINT: { - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (long) obj); - return true; - } - case FLOAT: { - UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (float) obj); - return true; - } - case DOUBLE: { - UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (double) obj); - return true; - } - case DATE: { - long time = UdfUtils.convertToDate(obj, retClass); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; - } - case DATETIME: { - long time = UdfUtils.convertToDateTime(obj, retClass); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; - } - case DATEV2: { - long time = UdfUtils.convertToDateV2(obj, retClass); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; - } - case DATETIMEV2: { - long time = UdfUtils.convertToDateTimeV2(obj, retClass); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; - } - case LARGEINT: { - BigInteger data = (BigInteger) obj; - byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - - //here value is 16 bytes, so if result data greater than the maximum of 16 bytes - //it will return a wrong num to backend; - byte[] value = new byte[16]; - //check data is negative - if (data.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; - } - case DECIMALV2: { - Preconditions.checkArgument(((BigDecimal) obj).scale() == 9, "Scale of DECIMALV2 must be 9"); - BigInteger data = ((BigDecimal) obj).unscaledValue(); - byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - //TODO: here is maybe overflow also, and may find a better way to handle - byte[] value = new byte[16]; - if (data.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; - } - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - BigDecimal retValue = ((BigDecimal) obj).setScale(retType.getScale(), RoundingMode.HALF_EVEN); - BigInteger data = retValue.unscaledValue(); - byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - //TODO: here is maybe overflow also, and may find a better way to handle - byte[] value = new byte[retType.getLen()]; - if (data.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; - } - case CHAR: - case VARCHAR: - case STRING: { - long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr); - byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8); - long offset = Integer.toUnsignedLong( - UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1))); - if (offset + bytes.length > bufferSize) { - return false; - } - offset += bytes.length; - UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row, - Integer.parseUnsignedInt(String.valueOf(offset))); - UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + offset - bytes.length, bytes.length); - return true; - } - default: - throw new UdfRuntimeException("Unsupported return type: " + retType); - } + return super.storeUdfResult(obj, row, retClass); } - private Object[] allocateInputObjects(long row) throws UdfRuntimeException { - Object[] inputObjects = new Object[argTypes.length]; - - for (int i = 0; i < argTypes.length; ++i) { - // skip the input column of current row is null - if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) != -1 - && (UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row) == 1)) { - inputObjects[i] = null; - continue; - } - switch (argTypes[i]) { - case BOOLEAN: - inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); - break; - case TINYINT: - inputObjects[i] = UdfUtils.UNSAFE.getByte(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); - break; - case SMALLINT: - inputObjects[i] = UdfUtils.UNSAFE.getShort(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case INT: - inputObjects[i] = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case BIGINT: - inputObjects[i] = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case FLOAT: - inputObjects[i] = UdfUtils.UNSAFE.getFloat(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case DOUBLE: - inputObjects[i] = UdfUtils.UNSAFE.getDouble(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case DATE: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateToJavaDate(data, argClass[i + 1]); - break; - } - case DATETIME: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i + 1]); - break; - } - case DATEV2: { - int data = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data, argClass[i + 1]); - break; - } - case DATETIMEV2: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i + 1]); - break; - } - case LARGEINT: { - long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row; - byte[] bytes = new byte[argTypes[i].getLen()]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); - - inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes)); - break; - } - case DECIMALV2: - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row; - byte[] bytes = new byte[argTypes[i].getLen()]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); - - BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes)); - inputObjects[i] = new BigDecimal(value, argTypes[i].getScale()); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) - + 4L * row)); - long numBytes = row == 0 ? offset : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - - 1))); - long base = row == 0 ? UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - : UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + offset - - numBytes; - byte[] bytes = new byte[(int) numBytes]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); - inputObjects[i] = new String(bytes, StandardCharsets.UTF_8); - break; - } - default: - throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]); - } - } - return inputObjects; + @Override + protected long getCurrentOutputOffset(long row) { + return Integer.toUnsignedLong( + UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1))); } - private void init(String jarPath, String udfPath, Type funcRetType, Type... parameterTypes) - throws UdfRuntimeException { + @Override + protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type funcRetType, + Type... parameterTypes) throws UdfRuntimeException { + String className = request.fn.aggregate_fn.symbol; + inputPlacesPtr = request.input_places_ptr; + allMethods = new HashMap<>(); + stateObjMap = new HashMap<>(); + ArrayList<String> signatures = Lists.newArrayList(); try { ClassLoader loader; @@ -511,9 +200,9 @@ public class UdafExecutor { // for test loader = ClassLoader.getSystemClassLoader(); } - Class<?> c = Class.forName(udfPath, true, loader); + Class<?> c = Class.forName(className, true, loader); Constructor<?> ctor = c.getConstructor(); - udaf = ctor.newInstance(); + udf = ctor.newInstance(); Method[] methods = c.getDeclaredMethods(); int idx = 0; for (idx = 0; idx < methods.length; ++idx) { @@ -569,7 +258,7 @@ public class UdafExecutor { return; } StringBuilder sb = new StringBuilder(); - sb.append("Unable to find evaluate function with the correct signature: ").append(udfPath + ".evaluate(") + sb.append("Unable to find evaluate function with the correct signature: ").append(className + ".evaluate(") .append(Joiner.on(", ").join(parameterTypes)).append(")\n").append("UDF contains: \n ") .append(Joiner.on("\n ").join(signatures)); throw new UdfRuntimeException(sb.toString()); diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 62deef5cda..5f043f64a8 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -23,127 +23,45 @@ import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; import org.apache.doris.udf.UdfUtils.JavaUdfDataType; import com.google.common.base.Joiner; -import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import org.apache.log4j.Logger; -import org.apache.thrift.TDeserializer; -import org.apache.thrift.TException; -import org.apache.thrift.protocol.TBinaryProtocol; -import java.io.IOException; import java.lang.reflect.Constructor; import java.lang.reflect.Method; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.math.RoundingMode; import java.net.MalformedURLException; -import java.net.URLClassLoader; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; -import java.util.Arrays; -public class UdfExecutor { +public class UdfExecutor extends BaseExecutor { private static final Logger LOG = Logger.getLogger(UdfExecutor.class); - - // By convention, the function in the class must be called evaluate() - public static final String UDF_FUNCTION_NAME = "evaluate"; - - // Object to deserialize ctor params from BE. - private static final TBinaryProtocol.Factory PROTOCOL_FACTORY = - new TBinaryProtocol.Factory(); - - private Object udf; // setup by init() and cleared by close() private Method method; - // setup by init() and cleared by close() - private URLClassLoader classLoader; - - // Return and argument types of the function inferred from the udf method signature. - // The JavaUdfDataType enum maps it to corresponding primitive type. - private JavaUdfDataType[] argTypes; - private JavaUdfDataType retType; - - // Input buffer from the backend. This is valid for the duration of an evaluate() call. - // These buffers are allocated in the BE. - private final long inputBufferPtrs; - private final long inputNullsPtrs; - private final long inputOffsetsPtrs; - - // Output buffer to return non-string values. These buffers are allocated in the BE. - private final long outputBufferPtr; - private final long outputNullPtr; - private final long outputOffsetsPtr; - private final long outputIntermediateStatePtr; // Pre-constructed input objects for the UDF. This minimizes object creation overhead // as these objects are reused across calls to evaluate(). private Object[] inputObjects; - // inputArgs_[i] is either inputObjects[i] or null - private Object[] inputArgs; private long outputOffset; private long rowIdx; - private final long batchSizePtr; - private Class[] argClass; + private long batchSizePtr; /** * Create a UdfExecutor, using parameters from a serialized thrift object. Used by * the backend. */ public UdfExecutor(byte[] thriftParams) throws Exception { - TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams(); - TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY); - try { - deserializer.deserialize(request, thriftParams); - } catch (TException e) { - throw new InternalException(e.getMessage()); - } - String className = request.fn.scalar_fn.symbol; - String jarFile = request.location; - Type retType = UdfUtils.fromThrift(request.fn.ret_type, 0).first; - Type[] parameterTypes = new Type[request.fn.arg_types.size()]; - for (int i = 0; i < request.fn.arg_types.size(); ++i) { - parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i)); - } - batchSizePtr = request.batch_size_ptr; - inputBufferPtrs = request.input_buffer_ptrs; - inputNullsPtrs = request.input_nulls_ptrs; - inputOffsetsPtrs = request.input_offsets_ptrs; - - outputBufferPtr = request.output_buffer_ptr; - outputNullPtr = request.output_null_ptr; - outputOffsetsPtr = request.output_offsets_ptr; - outputIntermediateStatePtr = request.output_intermediate_state_ptr; - - outputOffset = 0L; - rowIdx = 0L; - - init(jarFile, className, retType, parameterTypes); - } - - @Override - protected void finalize() throws Throwable { - close(); - super.finalize(); + super(thriftParams); } /** * Close the class loader we may have created. */ + @Override public void close() { - if (classLoader != null) { - try { - classLoader.close(); - } catch (IOException e) { - // Log and ignore. - LOG.debug("Error closing the URLClassloader.", e); - } - } // We are now un-usable (because the class loader has been // closed), so null out method_ and classLoader_. method = null; - classLoader = null; + super.close(); } /** @@ -166,24 +84,12 @@ public class UdfExecutor { rowIdx = 0; } for (; rowIdx < batchSize; rowIdx++) { - allocateInputObjects(rowIdx); - for (int i = 0; i < argTypes.length; ++i) { - // Currently, -1 indicates this column is not nullable. So input argument is - // null iff inputNullsPtrs_ != -1 and nullCol[row_idx] != 0. - if (UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) == -1 - || UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + rowIdx) == 0) { - inputArgs[i] = inputObjects[i]; - } else { - inputArgs[i] = null; - } - } + inputObjects = allocateInputObjects(rowIdx, 0); // `storeUdfResult` is called to store udf result to output column. If true // is returned, current value is stored successfully. Otherwise, current result is // not processed successfully (e.g. current output buffer is not large enough) so // we break this loop directly. - if (!storeUdfResult(evaluate(inputArgs), rowIdx)) { + if (!storeUdfResult(evaluate(inputObjects), rowIdx, method.getReturnType())) { UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx); return; } @@ -215,7 +121,8 @@ public class UdfExecutor { } // Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_ - private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException { + @Override + protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException { if (obj == null) { if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) { throw new UdfRuntimeException("UDF failed to store null data to not null column"); @@ -227,262 +134,31 @@ public class UdfExecutor { } return true; } - if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 0); - } - switch (retType) { - case BOOLEAN: { - boolean val = (boolean) obj; - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - val ? (byte) 1 : 0); - return true; - } - case TINYINT: { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (byte) obj); - return true; - } - case SMALLINT: { - UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (short) obj); - return true; - } - case INT: { - UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (int) obj); - return true; - } - case BIGINT: { - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (long) obj); - return true; - } - case FLOAT: { - UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (float) obj); - return true; - } - case DOUBLE: { - UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - (double) obj); - return true; - } - case DATE: { - long time = UdfUtils.convertToDate(obj, method.getReturnType()); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; - } - case DATETIME: { - long time = UdfUtils.convertToDateTime(obj, method.getReturnType()); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; - } - case DATEV2: { - int time = UdfUtils.convertToDateV2(obj, method.getReturnType()); - UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; - } - case DATETIMEV2: { - long time = UdfUtils.convertToDateTimeV2(obj, method.getReturnType()); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; - } - case LARGEINT: { - BigInteger data = (BigInteger) obj; - byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - - //here value is 16 bytes, so if result data greater than the maximum of 16 bytes - //it will return a wrong num to backend; - byte[] value = new byte[16]; - //check data is negative - if (data.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; - } - case DECIMALV2: { - Preconditions.checkArgument(((BigDecimal) obj).scale() == 9, "Scale of DECIMALV2 must be 9"); - BigInteger data = ((BigDecimal) obj).unscaledValue(); - byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - //TODO: here is maybe overflow also, and may find a better way to handle - byte[] value = new byte[16]; - if (data.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; - } - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - BigDecimal retValue = ((BigDecimal) obj).setScale(retType.getScale(), RoundingMode.HALF_EVEN); - BigInteger data = retValue.unscaledValue(); - byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - //TODO: here is maybe overflow also, and may find a better way to handle - byte[] value = new byte[retType.getLen()]; - if (data.signum() == -1) { - Arrays.fill(value, (byte) -1); - } + return super.storeUdfResult(obj, row, retClass); + } - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } + @Override + protected long getCurrentOutputOffset(long row) { + return outputOffset; + } - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; - } - case CHAR: - case VARCHAR: - case STRING: { - long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr); - byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8); - if (outputOffset + bytes.length > bufferSize) { - return false; - } - outputOffset += bytes.length; - UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row, - Integer.parseUnsignedInt(String.valueOf(outputOffset))); - UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + outputOffset - bytes.length, bytes.length); - return true; - } - default: - throw new UdfRuntimeException("Unsupported return type: " + retType); - } + @Override + protected void updateOutputOffset(long offset) { + outputOffset = offset; } // Preallocate the input objects that will be passed to the underlying UDF. // These objects are allocated once and reused across calls to evaluate() - private void allocateInputObjects(long row) throws UdfRuntimeException { - inputObjects = new Object[argTypes.length]; - inputArgs = new Object[argTypes.length]; - - for (int i = 0; i < argTypes.length; ++i) { - switch (argTypes[i]) { - case BOOLEAN: - inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); - break; - case TINYINT: - inputObjects[i] = UdfUtils.UNSAFE.getByte(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); - break; - case SMALLINT: - inputObjects[i] = UdfUtils.UNSAFE.getShort(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case INT: - inputObjects[i] = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case BIGINT: - inputObjects[i] = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case FLOAT: - inputObjects[i] = UdfUtils.UNSAFE.getFloat(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case DOUBLE: - inputObjects[i] = UdfUtils.UNSAFE.getDouble(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case DATE: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateToJavaDate(data, argClass[i]); - break; - } - case DATETIME: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i]); - break; - } - case DATEV2: { - int data = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data, argClass[i]); - break; - } - case DATETIMEV2: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i]); - break; - } - case LARGEINT: { - long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row; - byte[] bytes = new byte[argTypes[i].getLen()]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); - - inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes)); - break; - } - case DECIMALV2: - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row; - byte[] bytes = new byte[argTypes[i].getLen()]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); - - BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes)); - inputObjects[i] = new BigDecimal(value, argTypes[i].getScale()); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row)); - long numBytes = row == 0 ? offset : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1))); - long base = - row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) : - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + offset - numBytes; - byte[] bytes = new byte[(int) numBytes]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); - inputObjects[i] = new String(bytes, StandardCharsets.UTF_8); - break; - } - default: - throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]); - } - } - } - - private void init(String jarPath, String udfPath, Type funcRetType, Type... parameterTypes) - throws UdfRuntimeException { + @Override + protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type funcRetType, + Type... parameterTypes) throws UdfRuntimeException { + String className = request.fn.scalar_fn.symbol; + batchSizePtr = request.batch_size_ptr; + outputOffset = 0L; + rowIdx = 0L; ArrayList<String> signatures = Lists.newArrayList(); try { - LOG.debug("Loading UDF '" + udfPath + "' from " + jarPath); + LOG.debug("Loading UDF '" + className + "' from " + jarPath); ClassLoader loader; if (jarPath != null) { // Save for cleanup. @@ -493,7 +169,7 @@ public class UdfExecutor { // for test loader = ClassLoader.getSystemClassLoader(); } - Class<?> c = Class.forName(udfPath, true, loader); + Class<?> c = Class.forName(className, true, loader); Constructor<?> ctor = c.getConstructor(); udf = ctor.newInstance(); Method[] methods = c.getMethods(); @@ -520,7 +196,7 @@ public class UdfExecutor { retType = returnType.second; } argTypes = new JavaUdfDataType[0]; - LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath); + LOG.debug("Loaded UDF '" + className + "' from " + jarPath); return; } returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType()); @@ -535,13 +211,13 @@ public class UdfExecutor { } else { argTypes = inputType.second; } - LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath); + LOG.debug("Loaded UDF '" + className + "' from " + jarPath); return; } StringBuilder sb = new StringBuilder(); sb.append("Unable to find evaluate function with the correct signature: ") - .append(udfPath + ".evaluate(") + .append(className + ".evaluate(") .append(Joiner.on(", ").join(parameterTypes)) .append(")\n") .append("UDF contains: \n ") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org