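Update rte_ml_io_info to support shapes with an arbitrary
number of dimensions. Drop rte_ml_io_shape, rte_ml_io_format
and the internal helper rte_ml_io_format_to_str, and replace
the separate qtype/dtype fields with a single type field.
Introduce new fields nb_elements and size in rte_ml_io_info.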

Update the cnxk ML driver and app/test-mldev to support the changes.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 app/test-mldev/test_inference_common.c | 97 +++++---------------------
 drivers/ml/cnxk/cn10k_ml_model.c       | 78 +++++++++++++--------
 drivers/ml/cnxk/cn10k_ml_model.h       | 12 ++++
 drivers/ml/cnxk/cn10k_ml_ops.c         | 11 +--
 lib/mldev/mldev_utils.c                | 30 --------
 lib/mldev/mldev_utils.h                | 16 -----
 lib/mldev/rte_mldev.h                  | 59 ++++------------
 lib/mldev/version.map                  |  1 -
 8 files changed, 94 insertions(+), 210 deletions(-)

diff --git a/app/test-mldev/test_inference_common.c b/app/test-mldev/test_inference_common.c
index 05b221401b..b40519b5e3 100644
--- a/app/test-mldev/test_inference_common.c
+++ b/app/test-mldev/test_inference_common.c
@@ -3,6 +3,7 @@
  */
 
 #include <errno.h>
+#include <math.h>
 #include <stdio.h>
 #include <unistd.h>
 
@@ -18,11 +19,6 @@
 #include "ml_common.h"
 #include "test_inference_common.h"
 
-#define ML_TEST_READ_TYPE(buffer, type) (*((type *)buffer))
-
-#define ML_TEST_CHECK_OUTPUT(output, reference, tolerance) \
-       (((float)output - (float)reference) <= (((float)reference * tolerance) / 100.0))
-
 #define ML_OPEN_WRITE_GET_ERR(name, buffer, size, err) \
        do { \
                FILE *fp = fopen(name, "w+"); \
@@ -763,9 +759,9 @@ ml_inference_validation(struct ml_test *test, struct ml_request *req)
 {
        struct test_inference *t = ml_test_priv((struct ml_test *)test);
        struct ml_model *model;
-       uint32_t nb_elements;
-       uint8_t *reference;
-       uint8_t *output;
+       float *reference;
+       float *output;
+       float deviation;
        bool match;
        uint32_t i;
        uint32_t j;
@@ -777,89 +773,31 @@ ml_inference_validation(struct ml_test *test, struct ml_request *req)
                match = (rte_hash_crc(model->output, model->out_dsize, 0) ==
                         rte_hash_crc(model->reference, model->out_dsize, 0));
        } else {
-               output = model->output;
-               reference = model->reference;
+               output = (float *)model->output;
+               reference = (float *)model->reference;
 
                i = 0;
 next_output:
-               nb_elements =
-                       model->info.output_info[i].shape.w * model->info.output_info[i].shape.x *
-                       model->info.output_info[i].shape.y * model->info.output_info[i].shape.z;
                j = 0;
 next_element:
                match = false;
-               switch (model->info.output_info[i].dtype) {
-               case RTE_ML_IO_TYPE_INT8:
-                       if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, int8_t),
-                                                ML_TEST_READ_TYPE(reference, int8_t),
-                                                t->cmn.opt->tolerance))
-                               match = true;
-
-                       output += sizeof(int8_t);
-                       reference += sizeof(int8_t);
-                       break;
-               case RTE_ML_IO_TYPE_UINT8:
-                       if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, uint8_t),
-                                                ML_TEST_READ_TYPE(reference, uint8_t),
-                                                t->cmn.opt->tolerance))
-                               match = true;
-
-                       output += sizeof(float);
-                       reference += sizeof(float);
-                       break;
-               case RTE_ML_IO_TYPE_INT16:
-                       if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, int16_t),
-                                                ML_TEST_READ_TYPE(reference, int16_t),
-                                                t->cmn.opt->tolerance))
-                               match = true;
-
-                       output += sizeof(int16_t);
-                       reference += sizeof(int16_t);
-                       break;
-               case RTE_ML_IO_TYPE_UINT16:
-                       if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, uint16_t),
-                                                ML_TEST_READ_TYPE(reference, uint16_t),
-                                                t->cmn.opt->tolerance))
-                               match = true;
-
-                       output += sizeof(uint16_t);
-                       reference += sizeof(uint16_t);
-                       break;
-               case RTE_ML_IO_TYPE_INT32:
-                       if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, int32_t),
-                                                ML_TEST_READ_TYPE(reference, int32_t),
-                                                t->cmn.opt->tolerance))
-                               match = true;
-
-                       output += sizeof(int32_t);
-                       reference += sizeof(int32_t);
-                       break;
-               case RTE_ML_IO_TYPE_UINT32:
-                       if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, uint32_t),
-                                                ML_TEST_READ_TYPE(reference, uint32_t),
-                                                t->cmn.opt->tolerance))
-                               match = true;
-
-                       output += sizeof(uint32_t);
-                       reference += sizeof(uint32_t);
-                       break;
-               case RTE_ML_IO_TYPE_FP32:
-                       if (ML_TEST_CHECK_OUTPUT(ML_TEST_READ_TYPE(output, float),
-                                                ML_TEST_READ_TYPE(reference, float),
-                                                t->cmn.opt->tolerance))
-                               match = true;
-
-                       output += sizeof(float);
-                       reference += sizeof(float);
-                       break;
-               default: /* other types, fp8, fp16, bfloat16 */
+               /* Percent deviation; use a unit denominator when the reference is zero. */
+               deviation = 100 * fabs(*output - *reference) /
+                       (*reference == 0 ? 1.0 : fabs(*reference));
+               if (deviation <= t->cmn.opt->tolerance)
                        match = true;
-               }
+               else
+                       ml_err("id = %u, element = %u, output = %f, reference = %f, deviation = %f %%\n",
+                              i, j, *output, *reference, deviation);
+
+               output++;
+               reference++;
 
                if (!match)
                        goto done;
+
                j++;
-               if (j < nb_elements)
+               if (j < model->info.output_info[i].nb_elements)
                        goto next_element;
 
                i++;
diff --git a/drivers/ml/cnxk/cn10k_ml_model.c b/drivers/ml/cnxk/cn10k_ml_model.c
index 92c47d39ba..26df8d9ff9 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.c
+++ b/drivers/ml/cnxk/cn10k_ml_model.c
@@ -366,6 +366,12 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
        addr->total_input_sz_q = 0;
        for (i = 0; i < metadata->model.num_input; i++) {
                if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       addr->input[i].nb_dims = 4;
+                       addr->input[i].shape[0] = metadata->input1[i].shape.w;
+                       addr->input[i].shape[1] = metadata->input1[i].shape.x;
+                       addr->input[i].shape[2] = metadata->input1[i].shape.y;
+                       addr->input[i].shape[3] = metadata->input1[i].shape.z;
+
                        addr->input[i].nb_elements =
                                metadata->input1[i].shape.w * metadata->input1[i].shape.x *
                                metadata->input1[i].shape.y * metadata->input1[i].shape.z;
@@ -386,6 +392,13 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
                                addr->input[i].sz_q);
                } else {
                        j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+
+                       addr->input[i].nb_dims = 4;
+                       addr->input[i].shape[0] = metadata->input2[j].shape.w;
+                       addr->input[i].shape[1] = metadata->input2[j].shape.x;
+                       addr->input[i].shape[2] = metadata->input2[j].shape.y;
+                       addr->input[i].shape[3] = metadata->input2[j].shape.z;
+
                        addr->input[i].nb_elements =
                                metadata->input2[j].shape.w * metadata->input2[j].shape.x *
                                metadata->input2[j].shape.y * metadata->input2[j].shape.z;
@@ -412,6 +425,8 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
        addr->total_output_sz_d = 0;
        for (i = 0; i < metadata->model.num_output; i++) {
                if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
+                       addr->output[i].nb_dims = 1;
+                       addr->output[i].shape[0] = metadata->output1[i].size;
                        addr->output[i].nb_elements = metadata->output1[i].size;
                        addr->output[i].sz_d =
                                addr->output[i].nb_elements *
@@ -426,6 +441,9 @@ cn10k_ml_model_addr_update(struct cn10k_ml_model *model, uint8_t *buffer, uint8_
                                   model->model_id, i, addr->output[i].sz_d, addr->output[i].sz_q);
                } else {
                        j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+
+                       addr->output[i].nb_dims = 1;
+                       addr->output[i].shape[0] = metadata->output2[j].size;
                        addr->output[i].nb_elements = metadata->output2[j].size;
                        addr->output[i].sz_d =
                                addr->output[i].nb_elements *
@@ -498,6 +516,7 @@ void
 cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
 {
        struct cn10k_ml_model_metadata *metadata;
+       struct cn10k_ml_model_addr *addr;
        struct rte_ml_model_info *info;
        struct rte_ml_io_info *output;
        struct rte_ml_io_info *input;
@@ -508,6 +527,7 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
        info = PLT_PTR_CAST(model->info);
        input = PLT_PTR_ADD(info, sizeof(struct rte_ml_model_info));
        output = PLT_PTR_ADD(input, metadata->model.num_input * sizeof(struct rte_ml_io_info));
+       addr = &model->addr;
 
        /* Set model info */
        memset(info, 0, sizeof(struct rte_ml_model_info));
@@ -529,24 +549,25 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
                if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
                        rte_memcpy(input[i].name, metadata->input1[i].input_name,
                                   MRVL_ML_INPUT_NAME_LEN);
-                       input[i].dtype = metadata->input1[i].input_type;
-                       input[i].qtype = metadata->input1[i].model_input_type;
-                       input[i].shape.format = metadata->input1[i].shape.format;
-                       input[i].shape.w = metadata->input1[i].shape.w;
-                       input[i].shape.x = metadata->input1[i].shape.x;
-                       input[i].shape.y = metadata->input1[i].shape.y;
-                       input[i].shape.z = metadata->input1[i].shape.z;
+                       input[i].nb_dims = addr->input[i].nb_dims;
+                       input[i].shape = addr->input[i].shape;
+                       input[i].type = metadata->input1[i].model_input_type;
+                       input[i].nb_elements = addr->input[i].nb_elements;
+                       input[i].size =
+                               (uint64_t)addr->input[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->input1[i].model_input_type);
                } else {
                        j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+
                        rte_memcpy(input[i].name, metadata->input2[j].input_name,
                                   MRVL_ML_INPUT_NAME_LEN);
-                       input[i].dtype = metadata->input2[j].input_type;
-                       input[i].qtype = metadata->input2[j].model_input_type;
-                       input[i].shape.format = metadata->input2[j].shape.format;
-                       input[i].shape.w = metadata->input2[j].shape.w;
-                       input[i].shape.x = metadata->input2[j].shape.x;
-                       input[i].shape.y = metadata->input2[j].shape.y;
-                       input[i].shape.z = metadata->input2[j].shape.z;
+                       input[i].nb_dims = addr->input[i].nb_dims;
+                       input[i].shape = addr->input[i].shape;
+                       input[i].type = metadata->input2[j].model_input_type;
+                       input[i].nb_elements = addr->input[i].nb_elements;
+                       input[i].size =
+                               (uint64_t)addr->input[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->input2[j].model_input_type);
                }
        }
 
@@ -555,24 +576,25 @@ cn10k_ml_model_info_set(struct rte_ml_dev *dev, struct cn10k_ml_model *model)
                if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
                        rte_memcpy(output[i].name, metadata->output1[i].output_name,
                                   MRVL_ML_OUTPUT_NAME_LEN);
-                       output[i].dtype = metadata->output1[i].output_type;
-                       output[i].qtype = metadata->output1[i].model_output_type;
-                       output[i].shape.format = RTE_ML_IO_FORMAT_1D;
-                       output[i].shape.w = metadata->output1[i].size;
-                       output[i].shape.x = 1;
-                       output[i].shape.y = 1;
-                       output[i].shape.z = 1;
+                       output[i].nb_dims = addr->output[i].nb_dims;
+                       output[i].shape = addr->output[i].shape;
+                       output[i].type = metadata->output1[i].model_output_type;
+                       output[i].nb_elements = addr->output[i].nb_elements;
+                       output[i].size =
+                               (uint64_t)addr->output[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->output1[i].model_output_type);
                } else {
                        j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+
                        rte_memcpy(output[i].name, metadata->output2[j].output_name,
                                   MRVL_ML_OUTPUT_NAME_LEN);
-                       output[i].dtype = metadata->output2[j].output_type;
-                       output[i].qtype = metadata->output2[j].model_output_type;
-                       output[i].shape.format = RTE_ML_IO_FORMAT_1D;
-                       output[i].shape.w = metadata->output2[j].size;
-                       output[i].shape.x = 1;
-                       output[i].shape.y = 1;
-                       output[i].shape.z = 1;
+                       output[i].nb_dims = addr->output[i].nb_dims;
+                       output[i].shape = addr->output[i].shape;
+                       output[i].type = metadata->output2[j].model_output_type;
+                       output[i].nb_elements = addr->output[i].nb_elements;
+                       output[i].size =
+                               (uint64_t)addr->output[i].nb_elements *
+                               rte_ml_io_type_size_get(metadata->output2[j].model_output_type);
                }
        }
 }
diff --git a/drivers/ml/cnxk/cn10k_ml_model.h b/drivers/ml/cnxk/cn10k_ml_model.h
index 1f689363fc..4cc0744891 100644
--- a/drivers/ml/cnxk/cn10k_ml_model.h
+++ b/drivers/ml/cnxk/cn10k_ml_model.h
@@ -409,6 +409,12 @@ struct cn10k_ml_model_addr {
 
        /* Input address and size */
        struct {
+               /* Number of dimensions in shape */
+               uint32_t nb_dims;
+
+               /* Shape of input, first nb_dims entries are valid */
+               uint32_t shape[4];
+
                /* Number of elements */
                uint32_t nb_elements;
 
@@ -421,6 +427,12 @@ struct cn10k_ml_model_addr {
 
        /* Output address and size */
        struct {
+               /* Number of dimensions in shape */
+               uint32_t nb_dims;
+
+               /* Shape of output, first nb_dims entries are valid */
+               uint32_t shape[4];
+
                /* Number of elements */
                uint32_t nb_elements;
 
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 656467d891..e3faab81ba 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -321,8 +321,8 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
        fprintf(fp, "\n");
 
        print_line(fp, LINE_LEN);
-       fprintf(fp, "%8s  %16s  %12s  %18s  %12s  %14s\n", "input", "input_name", "input_type",
-               "model_input_type", "quantize", "format");
+       fprintf(fp, "%8s  %16s  %12s  %18s  %12s\n", "input", "input_name", "input_type",
+               "model_input_type", "quantize");
        print_line(fp, LINE_LEN);
        for (i = 0; i < model->metadata.model.num_input; i++) {
                if (i < MRVL_ML_NUM_INPUT_OUTPUT_1) {
@@ -335,12 +335,10 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
                        fprintf(fp, "%*s  ", 18, str);
                        fprintf(fp, "%*s", 12,
                                (model->metadata.input1[i].quantize == 1 ? "Yes" : "No"));
-                       rte_ml_io_format_to_str(model->metadata.input1[i].shape.format, str,
-                                               STR_LEN);
-                       fprintf(fp, "%*s", 16, str);
                        fprintf(fp, "\n");
                } else {
                        j = i - MRVL_ML_NUM_INPUT_OUTPUT_1;
+
                        fprintf(fp, "%8u  ", i);
                        fprintf(fp, "%*s  ", 16, model->metadata.input2[j].input_name);
                        rte_ml_io_type_to_str(model->metadata.input2[j].input_type, str, STR_LEN);
@@ -350,9 +348,6 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
                        fprintf(fp, "%*s  ", 18, str);
                        fprintf(fp, "%*s", 12,
                                (model->metadata.input2[j].quantize == 1 ? "Yes" : "No"));
-                       rte_ml_io_format_to_str(model->metadata.input2[j].shape.format, str,
-                                               STR_LEN);
-                       fprintf(fp, "%*s", 16, str);
                        fprintf(fp, "\n");
                }
        }
diff --git a/lib/mldev/mldev_utils.c b/lib/mldev/mldev_utils.c
index d2442b123b..ccd2c39ca8 100644
--- a/lib/mldev/mldev_utils.c
+++ b/lib/mldev/mldev_utils.c
@@ -86,33 +86,3 @@ rte_ml_io_type_to_str(enum rte_ml_io_type type, char *str, int len)
                rte_strlcpy(str, "invalid", len);
        }
 }
-
-void
-rte_ml_io_format_to_str(enum rte_ml_io_format format, char *str, int len)
-{
-       switch (format) {
-       case RTE_ML_IO_FORMAT_NCHW:
-               rte_strlcpy(str, "NCHW", len);
-               break;
-       case RTE_ML_IO_FORMAT_NHWC:
-               rte_strlcpy(str, "NHWC", len);
-               break;
-       case RTE_ML_IO_FORMAT_CHWN:
-               rte_strlcpy(str, "CHWN", len);
-               break;
-       case RTE_ML_IO_FORMAT_3D:
-               rte_strlcpy(str, "3D", len);
-               break;
-       case RTE_ML_IO_FORMAT_2D:
-               rte_strlcpy(str, "Matrix", len);
-               break;
-       case RTE_ML_IO_FORMAT_1D:
-               rte_strlcpy(str, "Vector", len);
-               break;
-       case RTE_ML_IO_FORMAT_SCALAR:
-               rte_strlcpy(str, "Scalar", len);
-               break;
-       default:
-               rte_strlcpy(str, "invalid", len);
-       }
-}
diff --git a/lib/mldev/mldev_utils.h b/lib/mldev/mldev_utils.h
index 5bc8020453..220afb42f0 100644
--- a/lib/mldev/mldev_utils.h
+++ b/lib/mldev/mldev_utils.h
@@ -52,22 +52,6 @@ __rte_internal
 void
 rte_ml_io_type_to_str(enum rte_ml_io_type type, char *str, int len);
 
-/**
- * @internal
- *
- * Get the name of an ML IO format.
- *
- * @param[in] type
- *     Enumeration of ML IO format.
- * @param[in] str
- *     Address of character array.
- * @param[in] len
- *     Length of character array.
- */
-__rte_internal
-void
-rte_ml_io_format_to_str(enum rte_ml_io_format format, char *str, int len);
-
 /**
  * @internal
  *
diff --git a/lib/mldev/rte_mldev.h b/lib/mldev/rte_mldev.h
index fc3525c1ab..6204df0930 100644
--- a/lib/mldev/rte_mldev.h
+++ b/lib/mldev/rte_mldev.h
@@ -863,47 +863,6 @@ enum rte_ml_io_type {
        /**< 16-bit brain floating point number. */
 };
 
-/**
- * Input and output format. This is used to represent the encoding type of multi-dimensional
- * used by ML models.
- */
-enum rte_ml_io_format {
-       RTE_ML_IO_FORMAT_NCHW = 1,
-       /**< Batch size (N) x channels (C) x height (H) x width (W) */
-       RTE_ML_IO_FORMAT_NHWC,
-       /**< Batch size (N) x height (H) x width (W) x channels (C) */
-       RTE_ML_IO_FORMAT_CHWN,
-       /**< Channels (C) x height (H) x width (W) x batch size (N) */
-       RTE_ML_IO_FORMAT_3D,
-       /**< Format to represent a 3 dimensional data */
-       RTE_ML_IO_FORMAT_2D,
-       /**< Format to represent matrix data */
-       RTE_ML_IO_FORMAT_1D,
-       /**< Format to represent vector data */
-       RTE_ML_IO_FORMAT_SCALAR,
-       /**< Format to represent scalar data */
-};
-
-/**
- * Input and output shape. This structure represents the encoding format and dimensions
- * of the tensor or vector.
- *
- * The data can be a 4D / 3D tensor, matrix, vector or a scalar. Number of dimensions used
- * for the data would depend on the format. Unused dimensions to be set to 1.
- */
-struct rte_ml_io_shape {
-       enum rte_ml_io_format format;
-       /**< Format of the data */
-       uint32_t w;
-       /**< First dimension */
-       uint32_t x;
-       /**< Second dimension */
-       uint32_t y;
-       /**< Third dimension */
-       uint32_t z;
-       /**< Fourth dimension */
-};
-
 /** Input and output data information structure
  *
  * Specifies the type and shape of input and output data.
@@ -911,12 +870,18 @@ struct rte_ml_io_shape {
 struct rte_ml_io_info {
        char name[RTE_ML_STR_MAX];
        /**< Name of data */
-       struct rte_ml_io_shape shape;
-       /**< Shape of data */
-       enum rte_ml_io_type qtype;
-       /**< Type of quantized data */
-       enum rte_ml_io_type dtype;
-       /**< Type of de-quantized data */
+       uint32_t nb_dims;
+       /**< Number of dimensions in shape */
+       uint32_t *shape;
+       /**< Shape of the tensor: array of nb_dims dimension sizes */
+       enum rte_ml_io_type type;
+       /**< Type of data
+        * @see enum rte_ml_io_type
+        */
+       uint64_t nb_elements;
+       /**< Number of elements in tensor */
+       uint64_t size;
+       /**< Size of tensor in bytes */
 };
 
 /** Model information structure */
diff --git a/lib/mldev/version.map b/lib/mldev/version.map
index 0706b565be..40ff27f4b9 100644
--- a/lib/mldev/version.map
+++ b/lib/mldev/version.map
@@ -51,7 +51,6 @@ INTERNAL {
 
        rte_ml_io_type_size_get;
        rte_ml_io_type_to_str;
-       rte_ml_io_format_to_str;
        rte_ml_io_float32_to_int8;
        rte_ml_io_int8_to_float32;
        rte_ml_io_float32_to_uint8;
-- 
2.41.0

Reply via email to