Added support for 32 inputs and outputs per model.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_model.c | 374 ++++++++++++++++++++++---------
 drivers/ml/cnxk/cn10k_ml_model.h |   5 +-
 drivers/ml/cnxk/cn10k_ml_ops.c   | 125 ++++++++---
 3 files changed, 367 insertions(+), 137 deletions(-)
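A note for readers of the diff: inputs and outputs 0-7 stay in the existing input1[]/output1[] metadata sections, while indices 8-31 live in the new input2[]/output2[] sections, so every per-I/O loop in the driver now picks the section from the running index. The header version bytes are likewise folded into a single integer (version[0] * 1000 + version[1] * 100 + version[2] * 10 + version[3]), and only models reporting at least 2301 may use the wider 32 I/O limit. The snippet below is a minimal, illustrative sketch of that index split and version folding, not driver code; the helper names io_index_resolve and version_to_int are assumptions made here for demonstration.

#include <stdint.h>
#include <stdio.h>

/* Limits as defined in cn10k_ml_model.h by this patch. */
#define MRVL_ML_NUM_INPUT_OUTPUT_1 8
#define MRVL_ML_NUM_INPUT_OUTPUT_2 24
#define MRVL_ML_NUM_INPUT_OUTPUT (MRVL_ML_NUM_INPUT_OUTPUT_1 + MRVL_ML_NUM_INPUT_OUTPUT_2)

/* Version folding used by cn10k_ml_model_metadata_check(); the patch
 * compares the result against 2301 before allowing more than 8 I/O.
 */
static uint32_t
version_to_int(const uint8_t v[4])
{
	return (uint32_t)v[0] * 1000 + (uint32_t)v[1] * 100 + (uint32_t)v[2] * 10 + v[3];
}

/* Illustrative helper (not part of the driver): map a flat I/O index
 * i in [0, 31] to the metadata section that holds it. Indices 0..7
 * land in input1[]/output1[], indices 8..31 in input2[]/output2[] at
 * position i - MRVL_ML_NUM_INPUT_OUTPUT_1, mirroring the i/j handling
 * added across cn10k_ml_model.c and cn10k_ml_ops.c.
 */
static void
io_index_resolve(uint8_t i, int *use_section2, uint8_t *idx)
{
	if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
		*use_section2 = 0;
		*idx = i;
	} else {
		*use_section2 = 1;
		*idx = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
	}
}

int
main(void)
{
	const uint8_t min_ver[4] = {2, 3, 0, 1};
	uint8_t i, idx;
	int second;

	printf("32 I/O metadata requires version >= %u\n", (unsigned)version_to_int(min_ver));

	for (i = 0; i < MRVL_ML_NUM_INPUT_OUTPUT; i++) {
		io_index_resolve(i, &second, &idx);
		printf("i = %2u -> %s[%u]\n", (unsigned)i, second ? "input2" : "input1",
		       (unsigned)idx);
	}

	return 0;
}

The same split drives the metadata checks, the address/size computation, the info reporting, the model print and the quantize/dequantize paths in the hunks that follow.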
diff --git a/drivers/ml/cnxk/cn10k_ml_model.c b/drivers/ml/cnxk/cn10k_ml_model.c index a15df700aa..92c47d39ba 100644 --- a/drivers/ml/cnxk/cn10k_ml_model.c +++ b/drivers/ml/cnxk/cn10k_ml_model.c @@ -41,8 +41,9 @@ cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size) struct cn10k_ml_model_metadata *metadata; uint32_t payload_crc32c; uint32_t header_crc32c; - uint8_t version[4]; + uint32_t version; uint8_t i; + uint8_t j; metadata = (struct cn10k_ml_model_metadata *)buffer; @@ -82,10 +83,13 @@ cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size) } /* Header version */ - rte_memcpy(version, metadata->header.version, 4 * sizeof(uint8_t)); - if (version[0] * 1000 + version[1] * 100 != MRVL_ML_MODEL_VERSION_MIN) { - plt_err("Metadata version = %u.%u.%u.%u (< %u.%u.%u.%u) not supported", version[0], - version[1], version[2], version[3], (MRVL_ML_MODEL_VERSION_MIN / 1000) % 10, + version = metadata->header.version[0] * 1000 + metadata->header.version[1] * 100 + + metadata->header.version[2] * 10 + metadata->header.version[3]; + if (version < MRVL_ML_MODEL_VERSION_MIN) { + plt_err("Metadata version = %u.%u.%u.%u (< %u.%u.%u.%u) not supported", + metadata->header.version[0], metadata->header.version[1], + metadata->header.version[2], metadata->header.version[3], + (MRVL_ML_MODEL_VERSION_MIN / 1000) % 10, (MRVL_ML_MODEL_VERSION_MIN / 100) % 10, (MRVL_ML_MODEL_VERSION_MIN / 10) % 10, MRVL_ML_MODEL_VERSION_MIN % 10); return -ENOTSUP; @@ -125,60 +129,119 @@ cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size) } /* Check input count */ - if (metadata->model.num_input > MRVL_ML_NUM_INPUT_OUTPUT_1) { - plt_err("Invalid metadata, num_input = %u (> %u)", metadata->model.num_input, - MRVL_ML_NUM_INPUT_OUTPUT_1); - return -EINVAL; - } - - /* Check output count */ - if (metadata->model.num_output > MRVL_ML_NUM_INPUT_OUTPUT_1) { - plt_err("Invalid metadata, num_output = %u (> %u)", metadata->model.num_output, - MRVL_ML_NUM_INPUT_OUTPUT_1); - return -EINVAL; - } - - /* Inputs */ - for (i = 0; i < metadata->model.num_input; i++) { - if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(metadata->input1[i].input_type)) <= - 0) { - plt_err("Invalid metadata, input[%u] : input_type = %u", i, - metadata->input1[i].input_type); + if (version < 2301) { + if (metadata->model.num_input > MRVL_ML_NUM_INPUT_OUTPUT_1) { + plt_err("Invalid metadata, num_input = %u (> %u)", + metadata->model.num_input, MRVL_ML_NUM_INPUT_OUTPUT_1); return -EINVAL; } - if (rte_ml_io_type_size_get( - cn10k_ml_io_type_map(metadata->input1[i].model_input_type)) <= 0) { - plt_err("Invalid metadata, input[%u] : model_input_type = %u", i, - metadata->input1[i].model_input_type); + /* Check output count */ + if (metadata->model.num_output > MRVL_ML_NUM_INPUT_OUTPUT_1) { + plt_err("Invalid metadata, num_output = %u (> %u)", + metadata->model.num_output, MRVL_ML_NUM_INPUT_OUTPUT_1); return -EINVAL; } - - if (metadata->input1[i].relocatable != 1) { - plt_err("Model not supported, non-relocatable input: %u", i); - return -ENOTSUP; + } else { + if (metadata->model.num_input > MRVL_ML_NUM_INPUT_OUTPUT) { + plt_err("Invalid metadata, num_input = %u (> %u)", + metadata->model.num_input, MRVL_ML_NUM_INPUT_OUTPUT); + return -EINVAL; } - } - /* Outputs */ - for (i = 0; i < metadata->model.num_output; i++) { - if (rte_ml_io_type_size_get( - cn10k_ml_io_type_map(metadata->output1[i].output_type)) <= 0) { - plt_err("Invalid metadata, output[%u] : output_type = %u", i, - metadata->output1[i].output_type); + /* Check output count */ + if 
(metadata->model.num_output > MRVL_ML_NUM_INPUT_OUTPUT) { + plt_err("Invalid metadata, num_output = %u (> %u)", + metadata->model.num_output, MRVL_ML_NUM_INPUT_OUTPUT); return -EINVAL; } + } - if (rte_ml_io_type_size_get( - cn10k_ml_io_type_map(metadata->output1[i].model_output_type)) <= 0) { - plt_err("Invalid metadata, output[%u] : model_output_type = %u", i, - metadata->output1[i].model_output_type); - return -EINVAL; + /* Inputs */ + for (i = 0; i < metadata->model.num_input; i++) { + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + if (rte_ml_io_type_size_get( + cn10k_ml_io_type_map(metadata->input1[i].input_type)) <= 0) { + plt_err("Invalid metadata, input1[%u] : input_type = %u", i, + metadata->input1[i].input_type); + return -EINVAL; + } + + if (rte_ml_io_type_size_get(cn10k_ml_io_type_map( + metadata->input1[i].model_input_type)) <= 0) { + plt_err("Invalid metadata, input1[%u] : model_input_type = %u", i, + metadata->input1[i].model_input_type); + return -EINVAL; + } + + if (metadata->input1[i].relocatable != 1) { + plt_err("Model not supported, non-relocatable input1: %u", i); + return -ENOTSUP; + } + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + if (rte_ml_io_type_size_get( + cn10k_ml_io_type_map(metadata->input2[j].input_type)) <= 0) { + plt_err("Invalid metadata, input2[%u] : input_type = %u", j, + metadata->input2[j].input_type); + return -EINVAL; + } + + if (rte_ml_io_type_size_get(cn10k_ml_io_type_map( + metadata->input2[j].model_input_type)) <= 0) { + plt_err("Invalid metadata, input2[%u] : model_input_type = %u", j, + metadata->input2[j].model_input_type); + return -EINVAL; + } + + if (metadata->input2[j].relocatable != 1) { + plt_err("Model not supported, non-relocatable input2: %u", j); + return -ENOTSUP; + } } + } - if (metadata->output1[i].relocatable != 1) { - plt_err("Model not supported, non-relocatable output: %u", i); - return -ENOTSUP; + /* Outputs */ + for (i = 0; i < metadata->model.num_output; i++) { + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + if (rte_ml_io_type_size_get( + cn10k_ml_io_type_map(metadata->output1[i].output_type)) <= 0) { + plt_err("Invalid metadata, output1[%u] : output_type = %u", i, + metadata->output1[i].output_type); + return -EINVAL; + } + + if (rte_ml_io_type_size_get(cn10k_ml_io_type_map( + metadata->output1[i].model_output_type)) <= 0) { + plt_err("Invalid metadata, output1[%u] : model_output_type = %u", i, + metadata->output1[i].model_output_type); + return -EINVAL; + } + + if (metadata->output1[i].relocatable != 1) { + plt_err("Model not supported, non-relocatable output1: %u", i); + return -ENOTSUP; + } + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + if (rte_ml_io_type_size_get( + cn10k_ml_io_type_map(metadata->output2[j].output_type)) <= 0) { + plt_err("Invalid metadata, output2[%u] : output_type = %u", j, + metadata->output2[j].output_type); + return -EINVAL; + } + + if (rte_ml_io_type_size_get(cn10k_ml_io_type_map( + metadata->output2[j].model_output_type)) <= 0) { + plt_err("Invalid metadata, output2[%u] : model_output_type = %u", j, + metadata->output2[j].model_output_type); + return -EINVAL; + } + + if (metadata->output2[j].relocatable != 1) { + plt_err("Model not supported, non-relocatable output2: %u", j); + return -ENOTSUP; + } } } @@ -189,31 +252,60 @@ void cn10k_ml_model_metadata_update(struct cn10k_ml_model_metadata *metadata) { uint8_t i; + uint8_t j; for (i = 0; i < metadata->model.num_input; i++) { - metadata->input1[i].input_type = - cn10k_ml_io_type_map(metadata->input1[i].input_type); - 
metadata->input1[i].model_input_type = - cn10k_ml_io_type_map(metadata->input1[i].model_input_type); + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + metadata->input1[i].input_type = + cn10k_ml_io_type_map(metadata->input1[i].input_type); + metadata->input1[i].model_input_type = + cn10k_ml_io_type_map(metadata->input1[i].model_input_type); + + if (metadata->input1[i].shape.w == 0) + metadata->input1[i].shape.w = 1; + + if (metadata->input1[i].shape.x == 0) + metadata->input1[i].shape.x = 1; + + if (metadata->input1[i].shape.y == 0) + metadata->input1[i].shape.y = 1; - if (metadata->input1[i].shape.w == 0) - metadata->input1[i].shape.w = 1; + if (metadata->input1[i].shape.z == 0) + metadata->input1[i].shape.z = 1; + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + metadata->input2[j].input_type = + cn10k_ml_io_type_map(metadata->input2[j].input_type); + metadata->input2[j].model_input_type = + cn10k_ml_io_type_map(metadata->input2[j].model_input_type); - if (metadata->input1[i].shape.x == 0) - metadata->input1[i].shape.x = 1; + if (metadata->input2[j].shape.w == 0) + metadata->input2[j].shape.w = 1; - if (metadata->input1[i].shape.y == 0) - metadata->input1[i].shape.y = 1; + if (metadata->input2[j].shape.x == 0) + metadata->input2[j].shape.x = 1; - if (metadata->input1[i].shape.z == 0) - metadata->input1[i].shape.z = 1; + if (metadata->input2[j].shape.y == 0) + metadata->input2[j].shape.y = 1; + + if (metadata->input2[j].shape.z == 0) + metadata->input2[j].shape.z = 1; + } } for (i = 0; i < metadata->model.num_output; i++) { - metadata->output1[i].output_type = - cn10k_ml_io_type_map(metadata->output1[i].output_type); - metadata->output1[i].model_output_type = - cn10k_ml_io_type_map(metadata->output1[i].model_output_type); + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + metadata->output1[i].output_type = + cn10k_ml_io_type_map(metadata->output1[i].output_type); + metadata->output1[i].model_output_type = + cn10k_ml_io_type_map(metadata->output1[i].model_output_type); + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + metadata->output2[j].output_type = + cn10k_ml_io_type_map(metadata->output2[j].output_type); + metadata->output2[j].model_output_type = + cn10k_ml_io_type_map(metadata->output2[j].model_output_type); + } } } @@ -226,6 +318,7 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_ uint8_t *dma_addr_load; uint8_t *dma_addr_run; uint8_t i; + uint8_t j; int fpos; metadata = &model->metadata; @@ -272,37 +365,80 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_ addr->total_input_sz_d = 0; addr->total_input_sz_q = 0; for (i = 0; i < metadata->model.num_input; i++) { - addr->input[i].nb_elements = - metadata->input1[i].shape.w * metadata->input1[i].shape.x * - metadata->input1[i].shape.y * metadata->input1[i].shape.z; - addr->input[i].sz_d = addr->input[i].nb_elements * - rte_ml_io_type_size_get(metadata->input1[i].input_type); - addr->input[i].sz_q = addr->input[i].nb_elements * - rte_ml_io_type_size_get(metadata->input1[i].model_input_type); - addr->total_input_sz_d += addr->input[i].sz_d; - addr->total_input_sz_q += addr->input[i].sz_q; - - plt_ml_dbg("model_id = %u, input[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u", - model->model_id, i, metadata->input1[i].shape.w, - metadata->input1[i].shape.x, metadata->input1[i].shape.y, - metadata->input1[i].shape.z, addr->input[i].sz_d, addr->input[i].sz_q); + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + addr->input[i].nb_elements = + metadata->input1[i].shape.w * metadata->input1[i].shape.x * + 
metadata->input1[i].shape.y * metadata->input1[i].shape.z; + addr->input[i].sz_d = + addr->input[i].nb_elements * + rte_ml_io_type_size_get(metadata->input1[i].input_type); + addr->input[i].sz_q = + addr->input[i].nb_elements * + rte_ml_io_type_size_get(metadata->input1[i].model_input_type); + addr->total_input_sz_d += addr->input[i].sz_d; + addr->total_input_sz_q += addr->input[i].sz_q; + + plt_ml_dbg( + "model_id = %u, input[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u", + model->model_id, i, metadata->input1[i].shape.w, + metadata->input1[i].shape.x, metadata->input1[i].shape.y, + metadata->input1[i].shape.z, addr->input[i].sz_d, + addr->input[i].sz_q); + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + addr->input[i].nb_elements = + metadata->input2[j].shape.w * metadata->input2[j].shape.x * + metadata->input2[j].shape.y * metadata->input2[j].shape.z; + addr->input[i].sz_d = + addr->input[i].nb_elements * + rte_ml_io_type_size_get(metadata->input2[j].input_type); + addr->input[i].sz_q = + addr->input[i].nb_elements * + rte_ml_io_type_size_get(metadata->input2[j].model_input_type); + addr->total_input_sz_d += addr->input[i].sz_d; + addr->total_input_sz_q += addr->input[i].sz_q; + + plt_ml_dbg( + "model_id = %u, input2[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u", + model->model_id, j, metadata->input2[j].shape.w, + metadata->input2[j].shape.x, metadata->input2[j].shape.y, + metadata->input2[j].shape.z, addr->input[i].sz_d, + addr->input[i].sz_q); + } } /* Outputs */ addr->total_output_sz_q = 0; addr->total_output_sz_d = 0; for (i = 0; i < metadata->model.num_output; i++) { - addr->output[i].nb_elements = metadata->output1[i].size; - addr->output[i].sz_d = addr->output[i].nb_elements * - rte_ml_io_type_size_get(metadata->output1[i].output_type); - addr->output[i].sz_q = - addr->output[i].nb_elements * - rte_ml_io_type_size_get(metadata->output1[i].model_output_type); - addr->total_output_sz_q += addr->output[i].sz_q; - addr->total_output_sz_d += addr->output[i].sz_d; - - plt_ml_dbg("model_id = %u, output[%u] - sz_d = %u, sz_q = %u", model->model_id, i, - addr->output[i].sz_d, addr->output[i].sz_q); + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + addr->output[i].nb_elements = metadata->output1[i].size; + addr->output[i].sz_d = + addr->output[i].nb_elements * + rte_ml_io_type_size_get(metadata->output1[i].output_type); + addr->output[i].sz_q = + addr->output[i].nb_elements * + rte_ml_io_type_size_get(metadata->output1[i].model_output_type); + addr->total_output_sz_q += addr->output[i].sz_q; + addr->total_output_sz_d += addr->output[i].sz_d; + + plt_ml_dbg("model_id = %u, output[%u] - sz_d = %u, sz_q = %u", + model->model_id, i, addr->output[i].sz_d, addr->output[i].sz_q); + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + addr->output[i].nb_elements = metadata->output2[j].size; + addr->output[i].sz_d = + addr->output[i].nb_elements * + rte_ml_io_type_size_get(metadata->output2[j].output_type); + addr->output[i].sz_q = + addr->output[i].nb_elements * + rte_ml_io_type_size_get(metadata->output2[j].model_output_type); + addr->total_output_sz_q += addr->output[i].sz_q; + addr->total_output_sz_d += addr->output[i].sz_d; + + plt_ml_dbg("model_id = %u, output2[%u] - sz_d = %u, sz_q = %u", + model->model_id, j, addr->output[i].sz_d, addr->output[i].sz_q); + } } } @@ -366,6 +502,7 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model) struct rte_ml_io_info *output; struct rte_ml_io_info *input; uint8_t i; + uint8_t j; metadata = &model->metadata; info = 
PLT_PTR_CAST(model->info); @@ -389,26 +526,53 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model) /* Set input info */ for (i = 0; i < info->nb_inputs; i++) { - rte_memcpy(input[i].name, metadata->input1[i].input_name, MRVL_ML_INPUT_NAME_LEN); - input[i].dtype = metadata->input1[i].input_type; - input[i].qtype = metadata->input1[i].model_input_type; - input[i].shape.format = metadata->input1[i].shape.format; - input[i].shape.w = metadata->input1[i].shape.w; - input[i].shape.x = metadata->input1[i].shape.x; - input[i].shape.y = metadata->input1[i].shape.y; - input[i].shape.z = metadata->input1[i].shape.z; + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + rte_memcpy(input[i].name, metadata->input1[i].input_name, + MRVL_ML_INPUT_NAME_LEN); + input[i].dtype = metadata->input1[i].input_type; + input[i].qtype = metadata->input1[i].model_input_type; + input[i].shape.format = metadata->input1[i].shape.format; + input[i].shape.w = metadata->input1[i].shape.w; + input[i].shape.x = metadata->input1[i].shape.x; + input[i].shape.y = metadata->input1[i].shape.y; + input[i].shape.z = metadata->input1[i].shape.z; + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + rte_memcpy(input[i].name, metadata->input2[j].input_name, + MRVL_ML_INPUT_NAME_LEN); + input[i].dtype = metadata->input2[j].input_type; + input[i].qtype = metadata->input2[j].model_input_type; + input[i].shape.format = metadata->input2[j].shape.format; + input[i].shape.w = metadata->input2[j].shape.w; + input[i].shape.x = metadata->input2[j].shape.x; + input[i].shape.y = metadata->input2[j].shape.y; + input[i].shape.z = metadata->input2[j].shape.z; + } } /* Set output info */ for (i = 0; i < info->nb_outputs; i++) { - rte_memcpy(output[i].name, metadata->output1[i].output_name, - MRVL_ML_OUTPUT_NAME_LEN); - output[i].dtype = metadata->output1[i].output_type; - output[i].qtype = metadata->output1[i].model_output_type; - output[i].shape.format = RTE_ML_IO_FORMAT_1D; - output[i].shape.w = metadata->output1[i].size; - output[i].shape.x = 1; - output[i].shape.y = 1; - output[i].shape.z = 1; + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + rte_memcpy(output[i].name, metadata->output1[i].output_name, + MRVL_ML_OUTPUT_NAME_LEN); + output[i].dtype = metadata->output1[i].output_type; + output[i].qtype = metadata->output1[i].model_output_type; + output[i].shape.format = RTE_ML_IO_FORMAT_1D; + output[i].shape.w = metadata->output1[i].size; + output[i].shape.x = 1; + output[i].shape.y = 1; + output[i].shape.z = 1; + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + rte_memcpy(output[i].name, metadata->output2[j].output_name, + MRVL_ML_OUTPUT_NAME_LEN); + output[i].dtype = metadata->output2[j].output_type; + output[i].qtype = metadata->output2[j].model_output_type; + output[i].shape.format = RTE_ML_IO_FORMAT_1D; + output[i].shape.w = metadata->output2[j].size; + output[i].shape.x = 1; + output[i].shape.y = 1; + output[i].shape.z = 1; + } } } diff --git a/drivers/ml/cnxk/cn10k_ml_model.h b/drivers/ml/cnxk/cn10k_ml_model.h index bd863a8c12..5c34e4d747 100644 --- a/drivers/ml/cnxk/cn10k_ml_model.h +++ b/drivers/ml/cnxk/cn10k_ml_model.h @@ -30,6 +30,7 @@ enum cn10k_ml_model_state { #define MRVL_ML_OUTPUT_NAME_LEN 16 #define MRVL_ML_NUM_INPUT_OUTPUT_1 8 #define MRVL_ML_NUM_INPUT_OUTPUT_2 24 +#define MRVL_ML_NUM_INPUT_OUTPUT (MRVL_ML_NUM_INPUT_OUTPUT_1 + MRVL_ML_NUM_INPUT_OUTPUT_2) /* Header (256-byte) */ struct cn10k_ml_model_metadata_header { @@ -413,7 +414,7 @@ struct cn10k_ml_model_addr { /* Quantized input size */ uint32_t sz_q; - } 
input[MRVL_ML_NUM_INPUT_OUTPUT_1]; + } input[MRVL_ML_NUM_INPUT_OUTPUT]; /* Output address and size */ struct { @@ -425,7 +426,7 @@ struct cn10k_ml_model_addr { /* Quantized output size */ uint32_t sz_q; - } output[MRVL_ML_NUM_INPUT_OUTPUT_1]; + } output[MRVL_ML_NUM_INPUT_OUTPUT]; /* Total size of quantized input */ uint32_t total_input_sz_q; diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c index aecc6e74ad..1033afb1b0 100644 --- a/drivers/ml/cnxk/cn10k_ml_ops.c +++ b/drivers/ml/cnxk/cn10k_ml_ops.c @@ -269,6 +269,7 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp) struct cn10k_ml_ocm *ocm; char str[STR_LEN]; uint8_t i; + uint8_t j; mldev = dev->data->dev_private; ocm = &mldev->ocm; @@ -324,16 +325,36 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp) "model_input_type", "quantize", "format"); print_line(fp, LINE_LEN); for (i = 0; i < model->metadata.model.num_input; i++) { - fprintf(fp, "%8u ", i); - fprintf(fp, "%*s ", 16, model->metadata.input1[i].input_name); - rte_ml_io_type_to_str(model->metadata.input1[i].input_type, str, STR_LEN); - fprintf(fp, "%*s ", 12, str); - rte_ml_io_type_to_str(model->metadata.input1[i].model_input_type, str, STR_LEN); - fprintf(fp, "%*s ", 18, str); - fprintf(fp, "%*s", 12, (model->metadata.input1[i].quantize == 1 ? "Yes" : "No")); - rte_ml_io_format_to_str(model->metadata.input1[i].shape.format, str, STR_LEN); - fprintf(fp, "%*s", 16, str); - fprintf(fp, "\n"); + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + fprintf(fp, "%8u ", i); + fprintf(fp, "%*s ", 16, model->metadata.input1[i].input_name); + rte_ml_io_type_to_str(model->metadata.input1[i].input_type, str, STR_LEN); + fprintf(fp, "%*s ", 12, str); + rte_ml_io_type_to_str(model->metadata.input1[i].model_input_type, str, + STR_LEN); + fprintf(fp, "%*s ", 18, str); + fprintf(fp, "%*s", 12, + (model->metadata.input1[i].quantize == 1 ? "Yes" : "No")); + rte_ml_io_format_to_str(model->metadata.input1[i].shape.format, str, + STR_LEN); + fprintf(fp, "%*s", 16, str); + fprintf(fp, "\n"); + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + fprintf(fp, "%8u ", i); + fprintf(fp, "%*s ", 16, model->metadata.input2[j].input_name); + rte_ml_io_type_to_str(model->metadata.input2[j].input_type, str, STR_LEN); + fprintf(fp, "%*s ", 12, str); + rte_ml_io_type_to_str(model->metadata.input2[j].model_input_type, str, + STR_LEN); + fprintf(fp, "%*s ", 18, str); + fprintf(fp, "%*s", 12, + (model->metadata.input2[j].quantize == 1 ? "Yes" : "No")); + rte_ml_io_format_to_str(model->metadata.input2[j].shape.format, str, + STR_LEN); + fprintf(fp, "%*s", 16, str); + fprintf(fp, "\n"); + } } fprintf(fp, "\n"); @@ -342,14 +363,30 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp) "model_output_type", "dequantize"); print_line(fp, LINE_LEN); for (i = 0; i < model->metadata.model.num_output; i++) { - fprintf(fp, "%8u ", i); - fprintf(fp, "%*s ", 16, model->metadata.output1[i].output_name); - rte_ml_io_type_to_str(model->metadata.output1[i].output_type, str, STR_LEN); - fprintf(fp, "%*s ", 12, str); - rte_ml_io_type_to_str(model->metadata.output1[i].model_output_type, str, STR_LEN); - fprintf(fp, "%*s ", 18, str); - fprintf(fp, "%*s", 12, (model->metadata.output1[i].dequantize == 1 ? 
"Yes" : "No")); - fprintf(fp, "\n"); + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + fprintf(fp, "%8u ", i); + fprintf(fp, "%*s ", 16, model->metadata.output1[i].output_name); + rte_ml_io_type_to_str(model->metadata.output1[i].output_type, str, STR_LEN); + fprintf(fp, "%*s ", 12, str); + rte_ml_io_type_to_str(model->metadata.output1[i].model_output_type, str, + STR_LEN); + fprintf(fp, "%*s ", 18, str); + fprintf(fp, "%*s", 12, + (model->metadata.output1[i].dequantize == 1 ? "Yes" : "No")); + fprintf(fp, "\n"); + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + fprintf(fp, "%8u ", i); + fprintf(fp, "%*s ", 16, model->metadata.output2[j].output_name); + rte_ml_io_type_to_str(model->metadata.output2[j].output_type, str, STR_LEN); + fprintf(fp, "%*s ", 12, str); + rte_ml_io_type_to_str(model->metadata.output2[j].model_output_type, str, + STR_LEN); + fprintf(fp, "%*s ", 18, str); + fprintf(fp, "%*s", 12, + (model->metadata.output2[j].dequantize == 1 ? "Yes" : "No")); + fprintf(fp, "\n"); + } } fprintf(fp, "\n"); print_line(fp, LINE_LEN); @@ -1863,10 +1900,14 @@ cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batc void *qbuffer) { struct cn10k_ml_model *model; + uint8_t model_input_type; uint8_t *lcl_dbuffer; uint8_t *lcl_qbuffer; + uint8_t input_type; uint32_t batch_id; + float qscale; uint32_t i; + uint32_t j; int ret; model = dev->data->models[model_id]; @@ -1882,28 +1923,38 @@ cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batc next_batch: for (i = 0; i < model->metadata.model.num_input; i++) { - if (model->metadata.input1[i].input_type == - model->metadata.input1[i].model_input_type) { + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + input_type = model->metadata.input1[i].input_type; + model_input_type = model->metadata.input1[i].model_input_type; + qscale = model->metadata.input1[i].qscale; + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + input_type = model->metadata.input2[j].input_type; + model_input_type = model->metadata.input2[j].model_input_type; + qscale = model->metadata.input2[j].qscale; + } + + if (input_type == model_input_type) { rte_memcpy(lcl_qbuffer, lcl_dbuffer, model->addr.input[i].sz_d); } else { switch (model->metadata.input1[i].model_input_type) { case RTE_ML_IO_TYPE_INT8: - ret = rte_ml_io_float32_to_int8(model->metadata.input1[i].qscale, + ret = rte_ml_io_float32_to_int8(qscale, model->addr.input[i].nb_elements, lcl_dbuffer, lcl_qbuffer); break; case RTE_ML_IO_TYPE_UINT8: - ret = rte_ml_io_float32_to_uint8(model->metadata.input1[i].qscale, + ret = rte_ml_io_float32_to_uint8(qscale, model->addr.input[i].nb_elements, lcl_dbuffer, lcl_qbuffer); break; case RTE_ML_IO_TYPE_INT16: - ret = rte_ml_io_float32_to_int16(model->metadata.input1[i].qscale, + ret = rte_ml_io_float32_to_int16(qscale, model->addr.input[i].nb_elements, lcl_dbuffer, lcl_qbuffer); break; case RTE_ML_IO_TYPE_UINT16: - ret = rte_ml_io_float32_to_uint16(model->metadata.input1[i].qscale, + ret = rte_ml_io_float32_to_uint16(qscale, model->addr.input[i].nb_elements, lcl_dbuffer, lcl_qbuffer); break; @@ -1936,10 +1987,14 @@ cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_ba void *qbuffer, void *dbuffer) { struct cn10k_ml_model *model; + uint8_t model_output_type; uint8_t *lcl_qbuffer; uint8_t *lcl_dbuffer; + uint8_t output_type; uint32_t batch_id; + float dscale; uint32_t i; + uint32_t j; int ret; model = dev->data->models[model_id]; @@ -1955,28 +2010,38 @@ cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_ba 
next_batch: for (i = 0; i < model->metadata.model.num_output; i++) { - if (model->metadata.output1[i].output_type == - model->metadata.output1[i].model_output_type) { + if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) { + output_type = model->metadata.output1[i].output_type; + model_output_type = model->metadata.output1[i].model_output_type; + dscale = model->metadata.output1[i].dscale; + } else { + j = i - MRVL_ML_NUM_INPUT_OUTPUT_1; + output_type = model->metadata.output2[j].output_type; + model_output_type = model->metadata.output2[j].model_output_type; + dscale = model->metadata.output2[j].dscale; + } + + if (output_type == model_output_type) { rte_memcpy(lcl_dbuffer, lcl_qbuffer, model->addr.output[i].sz_q); } else { switch (model->metadata.output1[i].model_output_type) { case RTE_ML_IO_TYPE_INT8: - ret = rte_ml_io_int8_to_float32(model->metadata.output1[i].dscale, + ret = rte_ml_io_int8_to_float32(dscale, model->addr.output[i].nb_elements, lcl_qbuffer, lcl_dbuffer); break; case RTE_ML_IO_TYPE_UINT8: - ret = rte_ml_io_uint8_to_float32(model->metadata.output1[i].dscale, + ret = rte_ml_io_uint8_to_float32(dscale, model->addr.output[i].nb_elements, lcl_qbuffer, lcl_dbuffer); break; case RTE_ML_IO_TYPE_INT16: - ret = rte_ml_io_int16_to_float32(model->metadata.output1[i].dscale, + ret = rte_ml_io_int16_to_float32(dscale, model->addr.output[i].nb_elements, lcl_qbuffer, lcl_dbuffer); break; case RTE_ML_IO_TYPE_UINT16: - ret = rte_ml_io_uint16_to_float32(model->metadata.output1[i].dscale, + ret = rte_ml_io_uint16_to_float32(dscale, model->addr.output[i].nb_elements, lcl_qbuffer, lcl_dbuffer); break; -- 2.17.1