Add driver functions to get the input and output buffer sizes
for a given number of batches. These functions compute the buffer
sizes based on the specific requirements of the device.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 52 ++++++++++++++++++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index 92bf1a0854..b5c89bee40 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -935,6 +935,54 @@ cn10k_ml_model_params_update(struct rte_ml_dev *dev, 
uint16_t model_id, void *bu
        return 0;
 }
 
+/*
+ * Get the input buffer sizes needed by model @model_id for @nb_batches batches.
+ *
+ * Each size is the model's total input size (quantized: total_input_sz_q,
+ * dequantized: total_input_sz_d) multiplied by
+ * ceil(nb_batches / model->batch_size). Presumably the total sizes describe
+ * one model->batch_size batch -- confirm against cn10k_ml_model.h.
+ *
+ * @input_qsize: if non-NULL, receives the quantized input size.
+ * @input_dsize: if non-NULL, receives the dequantized input size.
+ *
+ * Returns 0 on success, -EINVAL if dev->data->models[model_id] is NULL.
+ * NOTE(review): model_id is not range-checked here; this assumes the caller
+ * has already validated it.
+ */
+static int
+cn10k_ml_io_input_size_get(struct rte_ml_dev *dev, uint16_t model_id, uint32_t nb_batches,
+                          uint64_t *input_qsize, uint64_t *input_dsize)
+{
+       struct cn10k_ml_model *model;
+       uint64_t nb_model_batches;
+
+       model = dev->data->models[model_id];
+
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+
+       /* Widen to 64 bits *before* the arithmetic. Casting the product
+        * afterwards cannot undo a wrap that already happened in a narrower
+        * type.
+        */
+       nb_model_batches = PLT_DIV_CEIL(PLT_U64_CAST(nb_batches), model->batch_size);
+
+       if (input_qsize != NULL)
+               *input_qsize = PLT_U64_CAST(model->addr.total_input_sz_q) * nb_model_batches;
+
+       if (input_dsize != NULL)
+               *input_dsize = PLT_U64_CAST(model->addr.total_input_sz_d) * nb_model_batches;
+
+       return 0;
+}
+
+/*
+ * Get the output buffer sizes needed by model @model_id for @nb_batches batches.
+ *
+ * Each size is the model's total output size (quantized: total_output_sz_q,
+ * dequantized: total_output_sz_d) multiplied by
+ * ceil(nb_batches / model->batch_size). Presumably the total sizes describe
+ * one model->batch_size batch -- confirm against cn10k_ml_model.h.
+ *
+ * @output_qsize: if non-NULL, receives the quantized output size.
+ * @output_dsize: if non-NULL, receives the dequantized output size.
+ *
+ * Returns 0 on success, -EINVAL if dev->data->models[model_id] is NULL.
+ * NOTE(review): model_id is not range-checked here; this assumes the caller
+ * has already validated it.
+ */
+static int
+cn10k_ml_io_output_size_get(struct rte_ml_dev *dev, uint16_t model_id, uint32_t nb_batches,
+                           uint64_t *output_qsize, uint64_t *output_dsize)
+{
+       struct cn10k_ml_model *model;
+       uint64_t nb_model_batches;
+
+       model = dev->data->models[model_id];
+
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+
+       /* Widen to 64 bits *before* the arithmetic. Casting the product
+        * afterwards cannot undo a wrap that already happened in a narrower
+        * type.
+        */
+       nb_model_batches = PLT_DIV_CEIL(PLT_U64_CAST(nb_batches), model->batch_size);
+
+       if (output_qsize != NULL)
+               *output_qsize = PLT_U64_CAST(model->addr.total_output_sz_q) * nb_model_batches;
+
+       if (output_dsize != NULL)
+               *output_dsize = PLT_U64_CAST(model->addr.total_output_sz_d) * nb_model_batches;
+
+       return 0;
+}
+
 struct rte_ml_dev_ops cn10k_ml_ops = {
        /* Device control ops */
        .dev_info_get = cn10k_ml_dev_info_get,
@@ -954,4 +1002,8 @@ struct rte_ml_dev_ops cn10k_ml_ops = {
        .model_stop = cn10k_ml_model_stop,
        .model_info_get = cn10k_ml_model_info_get,
        .model_params_update = cn10k_ml_model_params_update,
+
+       /* I/O ops */
+       .io_input_size_get = cn10k_ml_io_input_size_get,
+       .io_output_size_get = cn10k_ml_io_output_size_get,
 };
-- 
2.17.1

Reply via email to