Implement driver functions to quantize and dequantize input
and output data, with support for multiple batches.
Quantization and dequantization use the type conversion
functions defined in the ML common code.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 151 +++++++++++++++++++++++++++++++++
 1 file changed, 151 insertions(+)

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index b5c89bee40..231c9b340b 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -5,6 +5,8 @@
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+#include <mldev_utils.h>
+
 #include "cn10k_ml_dev.h"
 #include "cn10k_ml_model.h"
 #include "cn10k_ml_ops.h"
@@ -983,6 +985,153 @@ cn10k_ml_io_output_size_get(struct rte_ml_dev *dev, 
uint16_t model_id, uint32_t
        return 0;
 }
 
+/*
+ * Quantize input tensor @i of @model for one batch group.
+ *
+ * If the user-visible type (input_type) already equals the type consumed by
+ * the model (model_input_type), sz_d bytes are copied from @dbuffer to
+ * @qbuffer unchanged. Otherwise @dbuffer must hold FP32 data: every
+ * conversion routine used here reads float32, so any other user type would
+ * be misinterpreted and is rejected. INT8/UINT8/INT16/UINT16 conversions use
+ * the tensor's qscale; FP16 conversion is unscaled.
+ *
+ * Returns a negative value on failure (-ENOTSUP for an unsupported type
+ * combination, or the error reported by the conversion routine); any
+ * non-negative value means success.
+ */
+static int
+cn10k_ml_io_quantize_input(struct cn10k_ml_model *model, uint32_t i, uint8_t *dbuffer,
+			   uint8_t *qbuffer)
+{
+	if (model->metadata.input[i].input_type == model->metadata.input[i].model_input_type) {
+		rte_memcpy(qbuffer, dbuffer, model->addr.input[i].sz_d);
+		return 0;
+	}
+
+	if (model->metadata.input[i].input_type != RTE_ML_IO_TYPE_FP32) {
+		plt_err("Unsupported input_type[%u] : %u", i,
+			model->metadata.input[i].input_type);
+		return -ENOTSUP;
+	}
+
+	switch (model->metadata.input[i].model_input_type) {
+	case RTE_ML_IO_TYPE_INT8:
+		return rte_ml_io_float32_to_int8(model->metadata.input[i].qscale,
+						 model->addr.input[i].nb_elements,
+						 dbuffer, qbuffer);
+	case RTE_ML_IO_TYPE_UINT8:
+		return rte_ml_io_float32_to_uint8(model->metadata.input[i].qscale,
+						  model->addr.input[i].nb_elements,
+						  dbuffer, qbuffer);
+	case RTE_ML_IO_TYPE_INT16:
+		return rte_ml_io_float32_to_int16(model->metadata.input[i].qscale,
+						  model->addr.input[i].nb_elements,
+						  dbuffer, qbuffer);
+	case RTE_ML_IO_TYPE_UINT16:
+		return rte_ml_io_float32_to_uint16(model->metadata.input[i].qscale,
+						   model->addr.input[i].nb_elements,
+						   dbuffer, qbuffer);
+	case RTE_ML_IO_TYPE_FP16:
+		return rte_ml_io_float32_to_float16(model->addr.input[i].nb_elements,
+						    dbuffer, qbuffer);
+	default:
+		plt_err("Unsupported model_input_type[%u] : %u", i,
+			model->metadata.input[i].model_input_type);
+		return -ENOTSUP;
+	}
+}
+
+/*
+ * io_quantize driver op: convert user input data in @dbuffer into the
+ * model's input representation in @qbuffer.
+ *
+ * Both buffers are walked input by input, advancing by each input's sz_d
+ * (dequantized) and sz_q (quantized) size. The full set of inputs is
+ * repeated ceil(nb_batches / model->batch_size) times. (Presumably sz_d and
+ * sz_q cover one group of batch_size batches; confirm against the model
+ * address setup.) nb_batches == 0 converts nothing.
+ *
+ * Returns 0 on success, -EINVAL for an unknown model_id or a model with a
+ * zero batch size, or the negative value from the per-input conversion.
+ */
+static int
+cn10k_ml_io_quantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches, void *dbuffer,
+		     void *qbuffer)
+{
+	struct cn10k_ml_model *model;
+	uint8_t *lcl_dbuffer;
+	uint8_t *lcl_qbuffer;
+	uint32_t nb_groups;
+	uint32_t batch_id;
+	uint32_t i;
+	int ret;
+
+	/* Bounds-check before indexing the models array. */
+	if (model_id >= dev->data->nb_models || dev->data->models[model_id] == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+	model = dev->data->models[model_id];
+
+	/* Avoid a division by zero in PLT_DIV_CEIL below. */
+	if (model->batch_size == 0) {
+		plt_err("Invalid batch_size = 0 for model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	/* A trailing partial group still needs a full pass over the inputs. */
+	nb_groups = PLT_DIV_CEIL(nb_batches, model->batch_size);
+	lcl_dbuffer = dbuffer;
+	lcl_qbuffer = qbuffer;
+
+	for (batch_id = 0; batch_id < nb_groups; batch_id++) {
+		for (i = 0; i < model->metadata.model.num_input; i++) {
+			ret = cn10k_ml_io_quantize_input(model, i, lcl_dbuffer, lcl_qbuffer);
+			if (ret < 0)
+				return ret;
+
+			lcl_dbuffer += model->addr.input[i].sz_d;
+			lcl_qbuffer += model->addr.input[i].sz_q;
+		}
+	}
+
+	return 0;
+}
+
+/*
+ * Dequantize output tensor @i of @model for one batch group.
+ *
+ * If the user-visible type (output_type) already equals the type produced by
+ * the model (model_output_type), sz_q bytes are copied from @qbuffer to
+ * @dbuffer unchanged. Otherwise the data is converted to FP32 with the ML
+ * common routine matching model_output_type. Every conversion routine used
+ * here writes float32, so any other user type is rejected.
+ * INT8/UINT8/INT16/UINT16 conversions use the tensor's dscale; FP16
+ * conversion is unscaled.
+ *
+ * Returns a negative value on failure (-ENOTSUP for an unsupported type
+ * combination, or the error reported by the conversion routine); any
+ * non-negative value means success.
+ */
+static int
+cn10k_ml_io_dequantize_output(struct cn10k_ml_model *model, uint32_t i, uint8_t *qbuffer,
+			      uint8_t *dbuffer)
+{
+	if (model->metadata.output[i].output_type == model->metadata.output[i].model_output_type) {
+		rte_memcpy(dbuffer, qbuffer, model->addr.output[i].sz_q);
+		return 0;
+	}
+
+	if (model->metadata.output[i].output_type != RTE_ML_IO_TYPE_FP32) {
+		plt_err("Unsupported output_type[%u] : %u", i,
+			model->metadata.output[i].output_type);
+		return -ENOTSUP;
+	}
+
+	switch (model->metadata.output[i].model_output_type) {
+	case RTE_ML_IO_TYPE_INT8:
+		return rte_ml_io_int8_to_float32(model->metadata.output[i].dscale,
+						 model->addr.output[i].nb_elements,
+						 qbuffer, dbuffer);
+	case RTE_ML_IO_TYPE_UINT8:
+		return rte_ml_io_uint8_to_float32(model->metadata.output[i].dscale,
+						  model->addr.output[i].nb_elements,
+						  qbuffer, dbuffer);
+	case RTE_ML_IO_TYPE_INT16:
+		return rte_ml_io_int16_to_float32(model->metadata.output[i].dscale,
+						  model->addr.output[i].nb_elements,
+						  qbuffer, dbuffer);
+	case RTE_ML_IO_TYPE_UINT16:
+		return rte_ml_io_uint16_to_float32(model->metadata.output[i].dscale,
+						   model->addr.output[i].nb_elements,
+						   qbuffer, dbuffer);
+	case RTE_ML_IO_TYPE_FP16:
+		return rte_ml_io_float16_to_float32(model->addr.output[i].nb_elements,
+						    qbuffer, dbuffer);
+	default:
+		plt_err("Unsupported model_output_type[%u] : %u", i,
+			model->metadata.output[i].model_output_type);
+		return -ENOTSUP;
+	}
+}
+
+/*
+ * io_dequantize driver op: convert the model's output data in @qbuffer into
+ * the user-visible representation in @dbuffer.
+ *
+ * Both buffers are walked output by output, advancing by each output's sz_q
+ * (quantized) and sz_d (dequantized) size. The full set of outputs is
+ * repeated ceil(nb_batches / model->batch_size) times. (Presumably sz_q and
+ * sz_d cover one group of batch_size batches; confirm against the model
+ * address setup.) nb_batches == 0 converts nothing.
+ *
+ * Returns 0 on success, -EINVAL for an unknown model_id or a model with a
+ * zero batch size, or the negative value from the per-output conversion.
+ */
+static int
+cn10k_ml_io_dequantize(struct rte_ml_dev *dev, uint16_t model_id, uint16_t nb_batches,
+		       void *qbuffer, void *dbuffer)
+{
+	struct cn10k_ml_model *model;
+	uint8_t *lcl_qbuffer;
+	uint8_t *lcl_dbuffer;
+	uint32_t nb_groups;
+	uint32_t batch_id;
+	uint32_t i;
+	int ret;
+
+	/* Bounds-check before indexing the models array. */
+	if (model_id >= dev->data->nb_models || dev->data->models[model_id] == NULL) {
+		plt_err("Invalid model_id = %u", model_id);
+		return -EINVAL;
+	}
+	model = dev->data->models[model_id];
+
+	/* Avoid a division by zero in PLT_DIV_CEIL below. */
+	if (model->batch_size == 0) {
+		plt_err("Invalid batch_size = 0 for model_id = %u", model_id);
+		return -EINVAL;
+	}
+
+	/* A trailing partial group still needs a full pass over the outputs. */
+	nb_groups = PLT_DIV_CEIL(nb_batches, model->batch_size);
+	lcl_qbuffer = qbuffer;
+	lcl_dbuffer = dbuffer;
+
+	for (batch_id = 0; batch_id < nb_groups; batch_id++) {
+		for (i = 0; i < model->metadata.model.num_output; i++) {
+			ret = cn10k_ml_io_dequantize_output(model, i, lcl_qbuffer, lcl_dbuffer);
+			if (ret < 0)
+				return ret;
+
+			lcl_qbuffer += model->addr.output[i].sz_q;
+			lcl_dbuffer += model->addr.output[i].sz_d;
+		}
+	}
+
+	return 0;
+}
+
 struct rte_ml_dev_ops cn10k_ml_ops = {
        /* Device control ops */
        .dev_info_get = cn10k_ml_dev_info_get,
@@ -1006,4 +1155,6 @@ struct rte_ml_dev_ops cn10k_ml_ops = {
        /* I/O ops */
        .io_input_size_get = cn10k_ml_io_input_size_get,
        .io_output_size_get = cn10k_ml_io_output_size_get,
+       .io_quantize = cn10k_ml_io_quantize,
+       .io_dequantize = cn10k_ml_io_dequantize,
 };
-- 
2.17.1

Reply via email to