Added support for 32 inputs and 32 outputs per model. For models with
metadata version 2301 or later, up to MRVL_ML_NUM_INPUT_OUTPUT (8 + 24 = 32)
inputs and outputs are accepted; entries beyond the first eight are taken
from the extended input2/output2 metadata sections. Older models remain
limited to eight inputs and outputs.
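
The split indexing used throughout this patch is summarized by the minimal
standalone sketch below. NUM_IO_1, NUM_IO_2 and NUM_IO are stand-ins for the
MRVL_ML_NUM_INPUT_OUTPUT_1/_2 and MRVL_ML_NUM_INPUT_OUTPUT macros; the sketch
is illustrative only and not part of the patch.

  /* Entry i selects section 1 (input1[]/output1[]) when i < NUM_IO_1,
   * otherwise section 2 (input2[]/output2[]) at index i - NUM_IO_1.
   */
  #include <stdio.h>

  #define NUM_IO_1 8                     /* entries held in input1[]/output1[] */
  #define NUM_IO_2 24                    /* entries held in input2[]/output2[] */
  #define NUM_IO   (NUM_IO_1 + NUM_IO_2) /* 32 inputs or outputs per model */

  int
  main(void)
  {
          unsigned int i;

          for (i = 0; i < NUM_IO; i++) {
                  if (i < NUM_IO_1)
                          printf("io %2u -> section 1, index %u\n", i, i);
                  else
                          printf("io %2u -> section 2, index %u\n", i, i - NUM_IO_1);
          }

          return 0;
  }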

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_model.c | 374 ++++++++++++++++++++++---------
 drivers/ml/cnxk/cn10k_ml_model.h |   5 +-
 drivers/ml/cnxk/cn10k_ml_ops.c   | 125 ++++++++---
 3 files changed, 367 insertions(+), 137 deletions(-)

diff --git a/drivers/ml/cnxk/cn10k_ml_model.c b/drivers/ml/cnxk/cn10k_ml_model.c
index a15df700aa..92c47d39ba 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.c
+++ b/drivers/ml/cnxk/cn10k_ml_model.c
@@ -41,8 +41,9 @@ cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size)
        struct cn10k_ml_model_metadata *metadata;
        uint32_t payload_crc32c;
        uint32_t header_crc32c;
-       uint8_t version[4];
+       uint32_t version;
        uint8_t i;
+       uint8_t j;
 
        metadata = (struct cn10k_ml_model_metadata *)buffer;
 
@@ -82,10 +83,13 @@ cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size)
        }
 
        /* Header version */
-       rte_memcpy(version, metadata->header.version, 4 * sizeof(uint8_t));
-       if (version[0] * 1000 + version[1] * 100 != MRVL_ML_MODEL_VERSION_MIN) {
-               plt_err("Metadata version = %u.%u.%u.%u (< %u.%u.%u.%u) not supported", version[0],
-                       version[1], version[2], version[3], (MRVL_ML_MODEL_VERSION_MIN / 1000) % 10,
+       version = metadata->header.version[0] * 1000 + metadata->header.version[1] * 100 +
+                 metadata->header.version[2] * 10 + metadata->header.version[3];
+       if (version < MRVL_ML_MODEL_VERSION_MIN) {
+               plt_err("Metadata version = %u.%u.%u.%u (< %u.%u.%u.%u) not supported",
+                       metadata->header.version[0], metadata->header.version[1],
+                       metadata->header.version[2], metadata->header.version[3],
+                       (MRVL_ML_MODEL_VERSION_MIN / 1000) % 10,
                        (MRVL_ML_MODEL_VERSION_MIN / 100) % 10,
                        (MRVL_ML_MODEL_VERSION_MIN / 10) % 10, MRVL_ML_MODEL_VERSION_MIN % 10);
                return -ENOTSUP;
@@ -125,60 +129,119 @@ cn10k_ml_model_metadata_check(uint8_t *buffer, uint64_t size)
        }
 
        /* Check input count */
-       if (metadata->model.num_input > MRVL_ML_NUM_INPUT_OUTPUT_1) {
-               plt_err("Invalid metadata, num_input  = %u (> %u)", metadata->model.num_input,
-                       MRVL_ML_NUM_INPUT_OUTPUT_1);
-               return -EINVAL;
-       }
-
-       /* Check output count */
-       if (metadata->model.num_output > MRVL_ML_NUM_INPUT_OUTPUT_1) {
-               plt_err("Invalid metadata, num_output  = %u (> %u)", metadata->model.num_output,
-                       MRVL_ML_NUM_INPUT_OUTPUT_1);
-               return -EINVAL;
-       }
-
-       /* Inputs */
-       for (i = 0; i < metadata->model.num_input; i++) {
-               if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(metadata->input1[i].input_type)) <=
-                   0) {
-                       plt_err("Invalid metadata, input[%u] : input_type = %u", i,
-                               metadata->input1[i].input_type);
+       if (version < 2301) {
+               if (metadata->model.num_input > MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       plt_err("Invalid metadata, num_input  = %u (> %u)",
+                               metadata->model.num_input, MRVL_ML_NUM_INPUT_OUTPUT_1);
                        return -EINVAL;
                }
 
-               if (rte_ml_io_type_size_get(
-                           cn10k_ml_io_type_map(metadata->input1[i].model_input_type)) <= 0) {
-                       plt_err("Invalid metadata, input[%u] : model_input_type = %u", i,
-                               metadata->input1[i].model_input_type);
+               /* Check output count */
+               if (metadata->model.num_output > MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       plt_err("Invalid metadata, num_output  = %u (> %u)",
+                               metadata->model.num_output, MRVL_ML_NUM_INPUT_OUTPUT_1);
                        return -EINVAL;
                }
-
-               if (metadata->input1[i].relocatable != 1) {
-                       plt_err("Model not supported, non-relocatable input: %u", i);
-                       return -ENOTSUP;
+       } else {
+               if (metadata->model.num_input > MRVL_ML_NUM_INPUT_OUTPUT) {
+                       plt_err("Invalid metadata, num_input  = %u (> %u)",
+                               metadata->model.num_input, MRVL_ML_NUM_INPUT_OUTPUT);
+                       return -EINVAL;
                }
-       }
 
-       /* Outputs */
-       for (i = 0; i < metadata->model.num_output; i++) {
-               if (rte_ml_io_type_size_get(
-                           cn10k_ml_io_type_map(metadata->output1[i].output_type)) <= 0) {
-                       plt_err("Invalid metadata, output[%u] : output_type = %u", i,
-                               metadata->output1[i].output_type);
+               /* Check output count */
+               if (metadata->model.num_output > MRVL_ML_NUM_INPUT_OUTPUT) {
+                       plt_err("Invalid metadata, num_output  = %u (> %u)",
+                               metadata->model.num_output, MRVL_ML_NUM_INPUT_OUTPUT);
                        return -EINVAL;
                }
+       }
 
-               if (rte_ml_io_type_size_get(
-                           cn10k_ml_io_type_map(metadata->output1[i].model_output_type)) <= 0) {
-                       plt_err("Invalid metadata, output[%u] : model_output_type = %u", i,
-                               metadata->output1[i].model_output_type);
-                       return -EINVAL;
+       /* Inputs */
+       for (i = 0; i < metadata->model.num_input; i++) {
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       if (rte_ml_io_type_size_get(
+                                   cn10k_ml_io_type_map(metadata->input1[i].input_type)) <= 0) {
+                               plt_err("Invalid metadata, input1[%u] : input_type = %u", i,
+                                       metadata->input1[i].input_type);
+                               return -EINVAL;
+                       }
+
+                       if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(
+                                   metadata->input1[i].model_input_type)) <= 0) {
+                               plt_err("Invalid metadata, input1[%u] : model_input_type = %u", i,
+                                       metadata->input1[i].model_input_type);
+                               return -EINVAL;
+                       }
+
+                       if (metadata->input1[i].relocatable != 1) {
+                               plt_err("Model not supported, non-relocatable input1: %u", i);
+                               return -ENOTSUP;
+                       }
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       if (rte_ml_io_type_size_get(
+                                   cn10k_ml_io_type_map(metadata->input2[j].input_type)) <= 0) {
+                               plt_err("Invalid metadata, input2[%u] : input_type = %u", j,
+                                       metadata->input2[j].input_type);
+                               return -EINVAL;
+                       }
+
+                       if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(
+                                   metadata->input2[j].model_input_type)) <= 0) {
+                               plt_err("Invalid metadata, input2[%u] : model_input_type = %u", j,
+                                       metadata->input2[j].model_input_type);
+                               return -EINVAL;
+                       }
+
+                       if (metadata->input2[j].relocatable != 1) {
+                               plt_err("Model not supported, non-relocatable input2: %u", j);
+                               return -ENOTSUP;
+                       }
                }
+       }
 
-               if (metadata->output1[i].relocatable != 1) {
-                       plt_err("Model not supported, non-relocatable output: %u", i);
-                       return -ENOTSUP;
+       /* Outputs */
+       for (i = 0; i < metadata->model.num_output; i++) {
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       if (rte_ml_io_type_size_get(
+                                   cn10k_ml_io_type_map(metadata->output1[i].output_type)) <= 0) {
+                               plt_err("Invalid metadata, output1[%u] : output_type = %u", i,
+                                       metadata->output1[i].output_type);
+                               return -EINVAL;
+                       }
+
+                       if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(
+                                   metadata->output1[i].model_output_type)) <= 0) {
+                               plt_err("Invalid metadata, output1[%u] : model_output_type = %u", i,
+                                       metadata->output1[i].model_output_type);
+                               return -EINVAL;
+                       }
+
+                       if (metadata->output1[i].relocatable != 1) {
+                               plt_err("Model not supported, non-relocatable output1: %u", i);
+                               return -ENOTSUP;
+                       }
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       if (rte_ml_io_type_size_get(
+                                   cn10k_ml_io_type_map(metadata->output2[j].output_type)) <= 0) {
+                               plt_err("Invalid metadata, output2[%u] : output_type = %u", j,
+                                       metadata->output2[j].output_type);
+                               return -EINVAL;
+                       }
+
+                       if (rte_ml_io_type_size_get(cn10k_ml_io_type_map(
+                                   metadata->output2[j].model_output_type)) <= 0) {
+                               plt_err("Invalid metadata, output2[%u] : model_output_type = %u", j,
+                                       metadata->output2[j].model_output_type);
+                               return -EINVAL;
+                       }
+
+                       if (metadata->output2[j].relocatable != 1) {
+                               plt_err("Model not supported, non-relocatable output2: %u", j);
+                               return -ENOTSUP;
+                       }
+                               return -ENOTSUP;
+                       }
                }
        }
 
@@ -189,31 +252,60 @@ void
 cn10k_ml_model_metadata_update(struct cn10k_ml_model_metadata *metadata)
 {
        uint8_t i;
+       uint8_t j;
 
        for (i = 0; i < metadata->model.num_input; i++) {
-               metadata->input1[i].input_type =
-                       cn10k_ml_io_type_map(metadata->input1[i].input_type);
-               metadata->input1[i].model_input_type =
-                       cn10k_ml_io_type_map(metadata->input1[i].model_input_type);
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       metadata->input1[i].input_type =
+                               cn10k_ml_io_type_map(metadata->input1[i].input_type);
+                       metadata->input1[i].model_input_type =
+                               cn10k_ml_io_type_map(metadata->input1[i].model_input_type);
+
+                       if (metadata->input1[i].shape.w == 0)
+                               metadata->input1[i].shape.w = 1;
+
+                       if (metadata->input1[i].shape.x == 0)
+                               metadata->input1[i].shape.x = 1;
+
+                       if (metadata->input1[i].shape.y == 0)
+                               metadata->input1[i].shape.y = 1;
 
-               if (metadata->input1[i].shape.w == 0)
-                       metadata->input1[i].shape.w = 1;
+                       if (metadata->input1[i].shape.z == 0)
+                               metadata->input1[i].shape.z = 1;
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       metadata->input2[j].input_type =
+                               cn10k_ml_io_type_map(metadata->input2[j].input_type);
+                       metadata->input2[j].model_input_type =
+                               cn10k_ml_io_type_map(metadata->input2[j].model_input_type);
 
-               if (metadata->input1[i].shape.x == 0)
-                       metadata->input1[i].shape.x = 1;
+                       if (metadata->input2[j].shape.w == 0)
+                               metadata->input2[j].shape.w = 1;
 
-               if (metadata->input1[i].shape.y == 0)
-                       metadata->input1[i].shape.y = 1;
+                       if (metadata->input2[j].shape.x == 0)
+                               metadata->input2[j].shape.x = 1;
 
-               if (metadata->input1[i].shape.z == 0)
-                       metadata->input1[i].shape.z = 1;
+                       if (metadata->input2[j].shape.y == 0)
+                               metadata->input2[j].shape.y = 1;
+
+                       if (metadata->input2[j].shape.z == 0)
+                               metadata->input2[j].shape.z = 1;
+               }
        }
 
        for (i = 0; i < metadata->model.num_output; i++) {
-               metadata->output1[i].output_type =
-                       cn10k_ml_io_type_map(metadata->output1[i].output_type);
-               metadata->output1[i].model_output_type =
-                       cn10k_ml_io_type_map(metadata->output1[i].model_output_type);
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       metadata->output1[i].output_type =
+                               cn10k_ml_io_type_map(metadata->output1[i].output_type);
+                       metadata->output1[i].model_output_type =
+                               cn10k_ml_io_type_map(metadata->output1[i].model_output_type);
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       metadata->output2[j].output_type =
+                               cn10k_ml_io_type_map(metadata->output2[j].output_type);
+                       metadata->output2[j].model_output_type =
+                               cn10k_ml_io_type_map(metadata->output2[j].model_output_type);
+               }
        }
 }
 
@@ -226,6 +318,7 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
        uint8_t *dma_addr_load;
        uint8_t *dma_addr_run;
        uint8_t i;
+       uint8_t j;
        int fpos;
 
        metadata = &model->metadata;
@@ -272,37 +365,80 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
        addr->total_input_sz_d = 0;
        addr->total_input_sz_q = 0;
        for (i = 0; i < metadata->model.num_input; i++) {
-               addr->input[i].nb_elements =
-               addr->input[i].nb_elements =
-                       metadata->input1[i].shape.w * metadata->input1[i].shape.x *
-                       metadata->input1[i].shape.y * metadata->input1[i].shape.z;
-               addr->input[i].sz_d = addr->input[i].nb_elements *
-                                     rte_ml_io_type_size_get(metadata->input1[i].input_type);
-               addr->input[i].sz_q = addr->input[i].nb_elements *
-                                     rte_ml_io_type_size_get(metadata->input1[i].model_input_type);
-               addr->total_input_sz_d += addr->input[i].sz_d;
-               addr->total_input_sz_q += addr->input[i].sz_q;
-
-               plt_ml_dbg("model_id = %u, input[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u",
-                          model->model_id, i, metadata->input1[i].shape.w,
-                          metadata->input1[i].shape.x, metadata->input1[i].shape.y,
-                          metadata->input1[i].shape.z, addr->input[i].sz_d, addr->input[i].sz_q);
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       addr->input[i].nb_elements =
+                               metadata->input1[i].shape.w * metadata->input1[i].shape.x *
+                               metadata->input1[i].shape.y * metadata->input1[i].shape.z;
+                       addr->input[i].sz_d =
+                               addr->input[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->input1[i].input_type);
+                       addr->input[i].sz_q =
+                               addr->input[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->input1[i].model_input_type);
+                       addr->total_input_sz_d += addr->input[i].sz_d;
+                       addr->total_input_sz_q += addr->input[i].sz_q;
+
+                       plt_ml_dbg(
+                               "model_id = %u, input[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u",
+                               model->model_id, i, metadata->input1[i].shape.w,
+                               metadata->input1[i].shape.x, metadata->input1[i].shape.y,
+                               metadata->input1[i].shape.z, addr->input[i].sz_d,
+                               addr->input[i].sz_q);
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       addr->input[i].nb_elements =
+                               metadata->input2[j].shape.w * metadata->input2[j].shape.x *
+                               metadata->input2[j].shape.y * metadata->input2[j].shape.z;
+                       addr->input[i].sz_d =
+                               addr->input[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->input2[j].input_type);
+                       addr->input[i].sz_q =
+                               addr->input[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->input2[j].model_input_type);
+                       addr->total_input_sz_d += addr->input[i].sz_d;
+                       addr->total_input_sz_q += addr->input[i].sz_q;
+
+                       plt_ml_dbg(
+                               "model_id = %u, input2[%u] - w:%u x:%u y:%u z:%u, sz_d = %u sz_q = %u",
+                               model->model_id, j, metadata->input2[j].shape.w,
+                               metadata->input2[j].shape.x, metadata->input2[j].shape.y,
+                               metadata->input2[j].shape.z, addr->input[i].sz_d,
+                               addr->input[i].sz_q);
+               }
        }
 
        /* Outputs */
        addr->total_output_sz_q = 0;
        addr->total_output_sz_d = 0;
        for (i = 0; i < metadata->model.num_output; i++) {
-               addr->output[i].nb_elements = metadata->output1[i].size;
-               addr->output[i].sz_d = addr->output[i].nb_elements *
-                                      rte_ml_io_type_size_get(metadata->output1[i].output_type);
-               addr->output[i].sz_q =
-                       addr->output[i].nb_elements *
-                       rte_ml_io_type_size_get(metadata->output1[i].model_output_type);
-               addr->total_output_sz_q += addr->output[i].sz_q;
-               addr->total_output_sz_d += addr->output[i].sz_d;
-
-               plt_ml_dbg("model_id = %u, output[%u] - sz_d = %u, sz_q = %u", model->model_id, i,
-                          addr->output[i].sz_d, addr->output[i].sz_q);
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       addr->output[i].nb_elements = metadata->output1[i].size;
+                       addr->output[i].sz_d =
+                               addr->output[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->output1[i].output_type);
+                       addr->output[i].sz_q =
+                               addr->output[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->output1[i].model_output_type);
+                       addr->total_output_sz_q += addr->output[i].sz_q;
+                       addr->total_output_sz_d += addr->output[i].sz_d;
+
+                       plt_ml_dbg("model_id = %u, output[%u] - sz_d = %u, sz_q = %u",
+                                  model->model_id, i, addr->output[i].sz_d, addr->output[i].sz_q);
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       addr->output[i].nb_elements = metadata->output2[j].size;
+                       addr->output[i].sz_d =
+                               addr->output[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->output2[j].output_type);
+                       addr->output[i].sz_q =
+                               addr->output[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->output2[j].model_output_type);
+                       addr->total_output_sz_q += addr->output[i].sz_q;
+                       addr->total_output_sz_d += addr->output[i].sz_d;
+
+                       plt_ml_dbg("model_id = %u, output2[%u] - sz_d = %u, sz_q = %u",
+                                  model->model_id, j, addr->output[i].sz_d, addr->output[i].sz_q);
+               }
        }
 }
 
@@ -366,6 +502,7 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
        struct rte_ml_io_info *output;
        struct rte_ml_io_info *input;
        uint8_t i;
+       uint8_t j;
 
        metadata = &model->metadata;
        info = PLT_PTR_CAST(model->info);
@@ -389,26 +526,53 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
 
        /* Set input info */
        for (i = 0; i < info->nb_inputs; i++) {
-               rte_memcpy(input[i].name, metadata->input1[i].input_name, MRVL_ML_INPUT_NAME_LEN);
-               input[i].dtype = metadata->input1[i].input_type;
-               input[i].qtype = metadata->input1[i].model_input_type;
-               input[i].shape.format = metadata->input1[i].shape.format;
-               input[i].shape.w = metadata->input1[i].shape.w;
-               input[i].shape.x = metadata->input1[i].shape.x;
-               input[i].shape.y = metadata->input1[i].shape.y;
-               input[i].shape.z = metadata->input1[i].shape.z;
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       rte_memcpy(input[i].name, metadata->input1[i].input_name,
+                                  MRVL_ML_INPUT_NAME_LEN);
+                       input[i].dtype = metadata->input1[i].input_type;
+                       input[i].qtype = metadata->input1[i].model_input_type;
+                       input[i].shape.format = metadata->input1[i].shape.format;
+                       input[i].shape.w = metadata->input1[i].shape.w;
+                       input[i].shape.x = metadata->input1[i].shape.x;
+                       input[i].shape.y = metadata->input1[i].shape.y;
+                       input[i].shape.z = metadata->input1[i].shape.z;
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       rte_memcpy(input[i].name, metadata->input2[j].input_name,
+                                  MRVL_ML_INPUT_NAME_LEN);
+                       input[i].dtype = metadata->input2[j].input_type;
+                       input[i].qtype = metadata->input2[j].model_input_type;
+                       input[i].shape.format = metadata->input2[j].shape.format;
+                       input[i].shape.w = metadata->input2[j].shape.w;
+                       input[i].shape.x = metadata->input2[j].shape.x;
+                       input[i].shape.y = metadata->input2[j].shape.y;
+                       input[i].shape.z = metadata->input2[j].shape.z;
+               }
        }
 
        /* Set output info */
        for (i = 0; i < info->nb_outputs; i++) {
-               rte_memcpy(output[i].name, metadata->output1[i].output_name,
-                          MRVL_ML_OUTPUT_NAME_LEN);
-               output[i].dtype = metadata->output1[i].output_type;
-               output[i].qtype = metadata->output1[i].model_output_type;
-               output[i].shape.format = RTE_ML_IO_FORMAT_1D;
-               output[i].shape.w = metadata->output1[i].size;
-               output[i].shape.x = 1;
-               output[i].shape.y = 1;
-               output[i].shape.z = 1;
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       rte_memcpy(output[i].name, metadata->output1[i].output_name,
+                                  MRVL_ML_OUTPUT_NAME_LEN);
+                       output[i].dtype = metadata->output1[i].output_type;
+                       output[i].qtype = metadata->output1[i].model_output_type;
+                       output[i].shape.format = RTE_ML_IO_FORMAT_1D;
+                       output[i].shape.w = metadata->output1[i].size;
+                       output[i].shape.x = 1;
+                       output[i].shape.y = 1;
+                       output[i].shape.z = 1;
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       rte_memcpy(output[i].name, metadata->output2[j].output_name,
+                                  MRVL_ML_OUTPUT_NAME_LEN);
+                       output[i].dtype = metadata->output2[j].output_type;
+                       output[i].qtype = metadata->output2[j].model_output_type;
+                       output[i].shape.format = RTE_ML_IO_FORMAT_1D;
+                       output[i].shape.w = metadata->output2[j].size;
+                       output[i].shape.x = 1;
+                       output[i].shape.y = 1;
+                       output[i].shape.z = 1;
+               }
        }
 }
diff --git a/drivers/ml/cnxk/cn10k_ml_model.h b/drivers/ml/cnxk/cn10k_ml_model.h
index bd863a8c12..5c34e4d747 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.h
+++ b/drivers/ml/cnxk/cn10k_ml_model.h
@@ -30,6 +30,7 @@ enum cn10k_ml_model_state {
 #define MRVL_ML_OUTPUT_NAME_LEN           16
 #define MRVL_ML_NUM_INPUT_OUTPUT_1 8
 #define MRVL_ML_NUM_INPUT_OUTPUT_2 24
+#define MRVL_ML_NUM_INPUT_OUTPUT   (MRVL_ML_NUM_INPUT_OUTPUT_1 + MRVL_ML_NUM_INPUT_OUTPUT_2)
 
 /* Header (256-byte) */
 struct cn10k_ml_model_metadata_header {
@@ -413,7 +414,7 @@ struct cn10k_ml_model_addr {
 
                /* Quantized input size */
                uint32_t sz_q;
-       } input[MRVL_ML_NUM_INPUT_OUTPUT_1];
+       } input[MRVL_ML_NUM_INPUT_OUTPUT];
 
        /* Output address and size */
        struct {
@@ -425,7 +426,7 @@ struct cn10k_ml_model_addr {
 
                /* Quantized output size */
                uint32_t sz_q;
-       } output[MRVL_ML_NUM_INPUT_OUTPUT_1];
+       } output[MRVL_ML_NUM_INPUT_OUTPUT];
 
        /* Total size of quantized input */
        uint32_t total_input_sz_q;
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index aecc6e74ad..1033afb1b0 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -269,6 +269,7 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
        struct cn10k_ml_ocm *ocm;
        char str[STR_LEN];
        uint8_t i;
+       uint8_t j;
 
        mldev = dev->data->dev_private;
        ocm = &mldev->ocm;
@@ -324,16 +325,36 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
                "model_input_type", "quantize", "format");
        print_line(fp, LINE_LEN);
        for (i = 0; i < model->metadata.model.num_input; i++) {
-               fprintf(fp, "%8u  ", i);
-               fprintf(fp, "%*s  ", 16, model->metadata.input1[i].input_name);
-               rte_ml_io_type_to_str(model->metadata.input1[i].input_type, str, STR_LEN);
-               fprintf(fp, "%*s  ", 12, str);
-               rte_ml_io_type_to_str(model->metadata.input1[i].model_input_type, str, STR_LEN);
-               fprintf(fp, "%*s  ", 18, str);
-               fprintf(fp, "%*s", 12, (model->metadata.input1[i].quantize == 1 ? "Yes" : "No"));
-               rte_ml_io_format_to_str(model->metadata.input1[i].shape.format, str, STR_LEN);
-               fprintf(fp, "%*s", 16, str);
-               fprintf(fp, "\n");
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       fprintf(fp, "%8u  ", i);
+                       fprintf(fp, "%8u  ", i);
+                       fprintf(fp, "%*s  ", 16, model->metadata.input1[i].input_name);
+                       rte_ml_io_type_to_str(model->metadata.input1[i].input_type, str, STR_LEN);
+                       fprintf(fp, "%*s  ", 12, str);
+                       rte_ml_io_type_to_str(model->metadata.input1[i].model_input_type, str,
+                                             STR_LEN);
+                       fprintf(fp, "%*s  ", 18, str);
+                       fprintf(fp, "%*s", 12,
+                               (model->metadata.input1[i].quantize == 1 ? "Yes" : "No"));
+                       rte_ml_io_format_to_str(model->metadata.input1[i].shape.format, str,
+                                               STR_LEN);
+                       fprintf(fp, "%*s", 16, str);
+                       fprintf(fp, "\n");
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       fprintf(fp, "%8u  ", i);
+                       fprintf(fp, "%*s  ", 16, model->metadata.input2[j].input_name);
+                       rte_ml_io_type_to_str(model->metadata.input2[j].input_type, str, STR_LEN);
+                       fprintf(fp, "%*s  ", 12, str);
+                       rte_ml_io_type_to_str(model->metadata.input2[j].model_input_type, str,
+                                             STR_LEN);
+                       fprintf(fp, "%*s  ", 18, str);
+                       fprintf(fp, "%*s", 12,
+                               (model->metadata.input2[j].quantize == 1 ? "Yes" : "No"));
+                       rte_ml_io_format_to_str(model->metadata.input2[j].shape.format, str,
+                                               STR_LEN);
+                       fprintf(fp, "%*s", 16, str);
+                       fprintf(fp, "\n");
+               }
        }
        fprintf(fp, "\n");
 
@@ -342,14 +363,30 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
                "model_output_type", "dequantize");
        print_line(fp, LINE_LEN);
        for (i = 0; i < model->metadata.model.num_output; i++) {
-               fprintf(fp, "%8u  ", i);
-               fprintf(fp, "%*s  ", 16, model->metadata.output1[i].output_name);
-               rte_ml_io_type_to_str(model->metadata.output1[i].output_type, str, STR_LEN);
-               fprintf(fp, "%*s  ", 12, str);
-               rte_ml_io_type_to_str(model->metadata.output1[i].model_output_type, str, STR_LEN);
-               fprintf(fp, "%*s  ", 18, str);
-               fprintf(fp, "%*s", 12, (model->metadata.output1[i].dequantize == 1 ? "Yes" : "No"));
-               fprintf(fp, "\n");
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       fprintf(fp, "%8u  ", i);
+                       fprintf(fp, "%*s  ", 16, model->metadata.output1[i].output_name);
+                       rte_ml_io_type_to_str(model->metadata.output1[i].output_type, str, STR_LEN);
+                       fprintf(fp, "%*s  ", 12, str);
+                       rte_ml_io_type_to_str(model->metadata.output1[i].model_output_type, str,
+                                             STR_LEN);
+                       fprintf(fp, "%*s  ", 18, str);
+                       fprintf(fp, "%*s", 12,
+                               (model->metadata.output1[i].dequantize == 1 ? "Yes" : "No"));
+                       fprintf(fp, "\n");
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       fprintf(fp, "%8u  ", i);
+                       fprintf(fp, "%*s  ", 16, model->metadata.output2[j].output_name);
+                       rte_ml_io_type_to_str(model->metadata.output2[j].output_type, str, STR_LEN);
+                       fprintf(fp, "%*s  ", 12, str);
+                       rte_ml_io_type_to_str(model->metadata.output2[j].model_output_type, str,
+                                             STR_LEN);
+                       fprintf(fp, "%*s  ", 18, str);
+                       fprintf(fp, "%*s", 12,
+                               (model->metadata.output2[j].dequantize == 1 ? "Yes" : "No"));
+                       fprintf(fp, "\n");
+               }
        }
        fprintf(fp, "\n");
        print_line(fp, LINE_LEN);
@@ -1863,10 +1900,14 @@ cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batc
                     void *qbuffer)
 {
        struct cn10k_ml_model *model;
+       uint8_t model_input_type;
        uint8_t *lcl_dbuffer;
        uint8_t *lcl_qbuffer;
+       uint8_t input_type;
        uint32_t batch_id;
+       float qscale;
        uint32_t i;
+       uint32_t j;
        int ret;
 
        model = dev->data->models[model_id];
@@ -1882,28 +1923,38 @@ cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batc
 
 next_batch:
        for (i = 0; i < model->metadata.model.num_input; i++) {
-               if (model->metadata.input1[i].input_type ==
-                   model->metadata.input1[i].model_input_type) {
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       input_type = model->metadata.input1[i].input_type;
+                       model_input_type = model->metadata.input1[i].model_input_type;
+                       qscale = model->metadata.input1[i].qscale;
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       input_type = model->metadata.input2[j].input_type;
+                       model_input_type = model->metadata.input2[j].model_input_type;
+                       qscale = model->metadata.input2[j].qscale;
+               }
+
+               if (input_type == model_input_type) {
                        rte_memcpy(lcl_qbuffer, lcl_dbuffer, model->addr.input[i].sz_d);
                } else {
                        switch (model->metadata.input1[i].model_input_type) {
                        case RTE_ML_IO_TYPE_INT8:
-                               ret = rte_ml_io_float32_to_int8(model->metadata.input1[i].qscale,
+                               ret = rte_ml_io_float32_to_int8(qscale,
                                                                model->addr.input[i].nb_elements,
                                                                lcl_dbuffer, lcl_qbuffer);
                                break;
                        case RTE_ML_IO_TYPE_UINT8:
-                               ret = rte_ml_io_float32_to_uint8(model->metadata.input1[i].qscale,
+                               ret = rte_ml_io_float32_to_uint8(qscale,
                                                                 model->addr.input[i].nb_elements,
                                                                 lcl_dbuffer, lcl_qbuffer);
                                break;
                        case RTE_ML_IO_TYPE_INT16:
-                               ret = rte_ml_io_float32_to_int16(model->metadata.input1[i].qscale,
+                               ret = rte_ml_io_float32_to_int16(qscale,
                                                                 model->addr.input[i].nb_elements,
                                                                 lcl_dbuffer, lcl_qbuffer);
                                break;
                        case RTE_ML_IO_TYPE_UINT16:
-                               ret = rte_ml_io_float32_to_uint16(model->metadata.input1[i].qscale,
+                               ret = rte_ml_io_float32_to_uint16(qscale,
                                                                  model->addr.input[i].nb_elements,
                                                                  lcl_dbuffer, lcl_qbuffer);
                                break;
@@ -1936,10 +1987,14 @@ cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_ba
                       void *qbuffer, void *dbuffer)
 {
        struct cn10k_ml_model *model;
+       uint8_t model_output_type;
        uint8_t *lcl_qbuffer;
        uint8_t *lcl_dbuffer;
+       uint8_t output_type;
        uint32_t batch_id;
+       float dscale;
        uint32_t i;
+       uint32_t j;
        int ret;
 
        model = dev->data->models[model_id];
@@ -1955,28 +2010,38 @@ cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_ba
 
 next_batch:
        for (i = 0; i < model->metadata.model.num_output; i++) {
-               if (model->metadata.output1[i].output_type ==
-                   model->metadata.output1[i].model_output_type) {
+               if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       output_type = model->metadata.output1[i].output_type;
+                       model_output_type = model->metadata.output1[i].model_output_type;
+                       dscale = model->metadata.output1[i].dscale;
+               } else {
+                       j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+                       output_type = model->metadata.output2[j].output_type;
+                       model_output_type = model->metadata.output2[j].model_output_type;
+                       dscale = model->metadata.output2[j].dscale;
+               }
+
+               if (output_type == model_output_type) {
                        rte_memcpy(lcl_dbuffer, lcl_qbuffer, model->addr.output[i].sz_q);
                } else {
                        switch (model->metadata.output1[i].model_output_type) {
                        case RTE_ML_IO_TYPE_INT8:
-                               ret = rte_ml_io_int8_to_float32(model->metadata.output1[i].dscale,
+                               ret = rte_ml_io_int8_to_float32(dscale,
                                                                model->addr.output[i].nb_elements,
                                                                lcl_qbuffer, lcl_dbuffer);
                                break;
                        case RTE_ML_IO_TYPE_UINT8:
-                               ret = rte_ml_io_uint8_to_float32(model->metadata.output1[i].dscale,
+                               ret = rte_ml_io_uint8_to_float32(dscale,
                                                                 model->addr.output[i].nb_elements,
                                                                 lcl_qbuffer, lcl_dbuffer);
                                break;
                        case RTE_ML_IO_TYPE_INT16:
-                               ret = rte_ml_io_int16_to_float32(model->metadata.output1[i].dscale,
+                               ret = rte_ml_io_int16_to_float32(dscale,
                                                                 model->addr.output[i].nb_elements,
                                                                 lcl_qbuffer, lcl_dbuffer);
                                break;
                        case RTE_ML_IO_TYPE_UINT16:
-                               ret = rte_ml_io_uint16_to_float32(model->metadata.output1[i].dscale,
+                               ret = rte_ml_io_uint16_to_float32(dscale,
                                                                  model->addr.output[i].nb_elements,
                                                                  lcl_qbuffer, lcl_dbuffer);
                                break;
-- 
2.17.1
