From: Prince Takkar <ptak...@marvell.com>

Add support for quantize and dequantize callback
functions for TVM models.

Signed-off-by: Prince Takkar <ptak...@marvell.com>
---
 drivers/ml/cnxk/mvtvm_ml_ops.c | 129 +++++++++++++++++++++++++++++++++
 drivers/ml/cnxk/mvtvm_ml_ops.h |   4 +
 2 files changed, 133 insertions(+)

diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.c b/drivers/ml/cnxk/mvtvm_ml_ops.c
index b627355917..776675843a 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.c
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.c
@@ -2,11 +2,15 @@
  * Copyright (c) 2023 Marvell.
  */
 
+#include <dlpack/dlpack.h>
+
 #include <rte_common.h>
 #include <rte_cycles.h>
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+#include <mldev_utils.h>
+
 #include "cnxk_ml_dev.h"
 #include "cnxk_ml_model.h"
 #include "cnxk_ml_ops.h"
@@ -236,6 +240,8 @@ mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct 
rte_ml_model_params *
                callback->tvmrt_io_free = cn10k_ml_io_free;
                callback->tvmrt_malloc = cn10k_ml_malloc;
                callback->tvmrt_free = cn10k_ml_free;
+               callback->tvmrt_quantize = mvtvm_ml_io_quantize;
+               callback->tvmrt_dequantize = mvtvm_ml_io_dequantize;
        } else {
                callback = NULL;
        }
@@ -366,3 +372,126 @@ mvtvm_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, 
struct cnxk_ml_model *model)
 
        return 0;
 }
+
+/*
+ * Look up the I/O info of the layer named @layer_name in @model.
+ *
+ * Returns a pointer to that layer's I/O info, or NULL when no layer of the
+ * model has that name (or, in debug builds only, when the layer is not of
+ * type ML_CNXK_LAYER_TYPE_MRVL).
+ *
+ * The not-found check is deliberately unconditional: without it an unknown
+ * name leaves layer_id == nb_layers, and the caller would then read
+ * model->layer[nb_layers], one past the last layer described by the model
+ * metadata.
+ */
+static struct cnxk_ml_io_info *
+mvtvm_ml_io_layer_info_get(struct cnxk_ml_model *model, const char *layer_name)
+{
+       uint16_t layer_id;
+
+       for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) {
+               if (strcmp(model->layer[layer_id].name, layer_name) == 0)
+                       break;
+       }
+
+       if (layer_id == model->mvtvm.metadata.model.nb_layers) {
+               plt_err("Invalid layer name: %s", layer_name);
+               return NULL;
+       }
+
+#ifdef CNXK_ML_DEV_DEBUG
+       if (model->layer[layer_id].type != ML_CNXK_LAYER_TYPE_MRVL) {
+               plt_err("Invalid layer name / type: %s", layer_name);
+               return NULL;
+       }
+#endif
+
+       return &model->layer[layer_id].info;
+}
+
+/*
+ * Quantize the inputs of one layer of a TVM model into a packed buffer.
+ * Installed as callback->tvmrt_quantize in mvtvm_ml_model_load().
+ *
+ * @device:     opaque pointer to the struct cnxk_ml_dev.
+ * @model_id:   index into cnxk_mldev->mldev->data->models[].
+ * @layer_name: name of the layer whose inputs are quantized.
+ * @deq_tensor: array of info->nb_inputs tensors, one per layer input, in
+ *              input order; element data starts at data + byte_offset.
+ * @qbuffer:    destination; input i is written right after input i-1 and
+ *              occupies info->input[i].sz_q bytes.
+ *
+ * Returns 0 on success, -EINVAL when the layer is not found (and, in debug
+ * builds, for NULL arguments, an empty model slot or a non-MRVL layer), or
+ * the negative value returned by cnxk_ml_io_quantize_single().
+ */
+int
+mvtvm_ml_io_quantize(void *device, uint16_t model_id, const char *layer_name,
+                    const DLTensor **deq_tensor, void *qbuffer)
+{
+       struct cnxk_ml_dev *cnxk_mldev = device;
+       struct cnxk_ml_io_info *info;
+       struct cnxk_ml_model *model;
+       uint8_t *lcl_dbuffer;
+       uint8_t *lcl_qbuffer;
+       uint32_t i;
+       int ret;
+
+#ifdef CNXK_ML_DEV_DEBUG
+       if ((device == NULL) || (layer_name == NULL) || (deq_tensor == NULL) ||
+           (qbuffer == NULL))
+               return -EINVAL;
+#endif
+
+       model = cnxk_mldev->mldev->data->models[model_id];
+#ifdef CNXK_ML_DEV_DEBUG
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+#endif
+
+       info = mvtvm_ml_io_layer_info_get(model, layer_name);
+       if (info == NULL)
+               return -EINVAL;
+
+       lcl_qbuffer = qbuffer;
+       for (i = 0; i < info->nb_inputs; i++) {
+               lcl_dbuffer = PLT_PTR_ADD(deq_tensor[i]->data, deq_tensor[i]->byte_offset);
+
+               ret = cnxk_ml_io_quantize_single(&info->input[i], lcl_dbuffer, lcl_qbuffer);
+               if (ret < 0)
+                       return ret;
+
+               /* Quantized inputs are packed back-to-back in qbuffer. */
+               lcl_qbuffer += info->input[i].sz_q;
+       }
+
+       return 0;
+}
+
+/*
+ * Dequantize the outputs of one layer of a TVM model from a packed buffer.
+ * Installed as callback->tvmrt_dequantize in mvtvm_ml_model_load().
+ *
+ * @device:     opaque pointer to the struct cnxk_ml_dev.
+ * @model_id:   index into cnxk_mldev->mldev->data->models[].
+ * @layer_name: name of the layer whose outputs are dequantized.
+ * @qbuffer:    source; output i is read right after output i-1 and
+ *              occupies info->output[i].sz_q bytes.
+ * @deq_tensor: array of info->nb_outputs tensors, one per layer output, in
+ *              output order; results are written at data + byte_offset
+ *              (only the tensor descriptors are const, not their data).
+ *
+ * Returns 0 on success, -EINVAL when the layer is not found (and, in debug
+ * builds, for NULL arguments, an empty model slot or a non-MRVL layer), or
+ * the negative value returned by cnxk_ml_io_dequantize_single().
+ */
+int
+mvtvm_ml_io_dequantize(void *device, uint16_t model_id, const char *layer_name, void *qbuffer,
+                      const DLTensor **deq_tensor)
+{
+       struct cnxk_ml_dev *cnxk_mldev = device;
+       struct cnxk_ml_io_info *info;
+       struct cnxk_ml_model *model;
+       uint8_t *lcl_dbuffer;
+       uint8_t *lcl_qbuffer;
+       uint32_t i;
+       int ret;
+
+#ifdef CNXK_ML_DEV_DEBUG
+       if ((device == NULL) || (layer_name == NULL) || (deq_tensor == NULL) ||
+           (qbuffer == NULL))
+               return -EINVAL;
+#endif
+
+       model = cnxk_mldev->mldev->data->models[model_id];
+#ifdef CNXK_ML_DEV_DEBUG
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+#endif
+
+       info = mvtvm_ml_io_layer_info_get(model, layer_name);
+       if (info == NULL)
+               return -EINVAL;
+
+       lcl_qbuffer = qbuffer;
+       for (i = 0; i < info->nb_outputs; i++) {
+               lcl_dbuffer = PLT_PTR_ADD(deq_tensor[i]->data, deq_tensor[i]->byte_offset);
+
+               ret = cnxk_ml_io_dequantize_single(&info->output[i], lcl_qbuffer, lcl_dbuffer);
+               if (ret < 0)
+                       return ret;
+
+               /* Quantized outputs are packed back-to-back in qbuffer. */
+               lcl_qbuffer += info->output[i].sz_q;
+       }
+
+       return 0;
+}
diff --git a/drivers/ml/cnxk/mvtvm_ml_ops.h b/drivers/ml/cnxk/mvtvm_ml_ops.h
index 22e0340146..4cabe30a82 100644
--- a/drivers/ml/cnxk/mvtvm_ml_ops.h
+++ b/drivers/ml/cnxk/mvtvm_ml_ops.h
@@ -24,6 +24,10 @@ int mvtvm_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, 
struct rte_ml_model_para
 int mvtvm_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model);
 int mvtvm_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model);
 int mvtvm_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model 
*model);
+int mvtvm_ml_io_quantize(void *device, uint16_t model_id, const char 
*layer_name,
+                        const DLTensor **deq_tensor, void *qbuffer);
+int mvtvm_ml_io_dequantize(void *device, uint16_t model_id, const char 
*layer_name, void *qbuffer,
+                          const DLTensor **deq_tensor);
 
 void mvtvm_ml_model_xstat_name_set(struct cnxk_ml_dev *cnxk_mldev, struct 
cnxk_ml_model *model,
                                   uint16_t stat_id, uint16_t entry, char 
*suffix);
-- 
2.42.0

Reply via email to