Implemented cnxk wrapper functions to start and stop ML models.
The wrapper functions invoke the cn10k model start and stop
functions.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ocm.c |  28 ++--
 drivers/ml/cnxk/cn10k_ml_ocm.h |  12 +-
 drivers/ml/cnxk/cn10k_ml_ops.c | 282 ++++++++++++++++++++-------------
 drivers/ml/cnxk/cn10k_ml_ops.h |   8 +-
 drivers/ml/cnxk/cnxk_ml_ops.c  |  48 +++++-
 drivers/ml/cnxk/cnxk_ml_ops.h  |   1 +
 6 files changed, 240 insertions(+), 139 deletions(-)
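
Patch notes: below is a condensed sketch of the start path added by this
patch (the stop path is symmetric). It is assembled from the hunks that
follow, with locking, OCM reservation and error logging trimmed for
brevity; all identifiers match the diff, nothing new is introduced.

  /* cnxk-level model_start op: validate arguments and dispatch to cn10k. */
  static int
  cnxk_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
  {
          struct cnxk_ml_dev *cnxk_mldev;
          struct cnxk_ml_model *model;

          if (dev == NULL)
                  return -EINVAL;

          cnxk_mldev = dev->data->dev_private;
          model = dev->data->models[model_id];
          if (model == NULL)
                  return -EINVAL;

          return cn10k_ml_model_start(cnxk_mldev, model);
  }

  /* cn10k wrapper: start layer 0 via the layer-level entry point, then
   * update device and model state on success.
   */
  int
  cn10k_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
  {
          int ret;

          ret = cn10k_ml_layer_start(cnxk_mldev, model->model_id, model->layer[0].name);
          if (ret != 0)
                  return ret;

          cnxk_mldev->nb_models_started++;
          model->state = ML_CNXK_MODEL_STATE_STARTED;
          return 0;
  }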

diff --git a/drivers/ml/cnxk/cn10k_ml_ocm.c b/drivers/ml/cnxk/cn10k_ml_ocm.c
index d71c36eae6..2197e5e0ed 100644
--- a/drivers/ml/cnxk/cn10k_ml_ocm.c
+++ b/drivers/ml/cnxk/cn10k_ml_ocm.c
@@ -215,11 +215,10 @@ cn10k_ml_ocm_tilecount(uint64_t tilemask, int *start, int *end)
  * scratch & WB pages and OCM allocation mode.
  */
 int
-cn10k_ml_ocm_tilemask_find(struct rte_ml_dev *dev, uint8_t num_tiles, uint16_t wb_pages,
+cn10k_ml_ocm_tilemask_find(struct cnxk_ml_dev *cnxk_mldev, uint8_t num_tiles, uint16_t wb_pages,
                           uint16_t scratch_pages, uint64_t *tilemask)
 {
        struct cn10k_ml_dev *cn10k_mldev;
-       struct cnxk_ml_dev *cnxk_mldev;
        struct cn10k_ml_ocm *ocm;
 
        uint16_t used_scratch_pages_max;
@@ -238,7 +237,6 @@ cn10k_ml_ocm_tilemask_find(struct rte_ml_dev *dev, uint8_t num_tiles, uint16_t w
        int max_slot_sz;
        int page_id;
 
-       cnxk_mldev = dev->data->dev_private;
        cn10k_mldev = &cnxk_mldev->cn10k_mldev;
        ocm = &cn10k_mldev->ocm;
 
@@ -333,12 +331,10 @@ cn10k_ml_ocm_tilemask_find(struct rte_ml_dev *dev, uint8_t num_tiles, uint16_t w
 }
 
 void
-cn10k_ml_ocm_reserve_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t layer_id,
+cn10k_ml_ocm_reserve_pages(struct cnxk_ml_dev *cnxk_mldev, uint16_t model_id, uint16_t layer_id,
                           uint64_t tilemask, int wb_page_start, uint16_t wb_pages,
                           uint16_t scratch_pages)
 {
-       struct cn10k_ml_dev *cn10k_mldev;
-       struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
        struct cnxk_ml_layer *layer;
        struct cn10k_ml_ocm *ocm;
@@ -351,10 +347,8 @@ cn10k_ml_ocm_reserve_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t l
        int tile_id;
        int page_id;
 
-       cnxk_mldev = dev->data->dev_private;
-       cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-       ocm = &cn10k_mldev->ocm;
-       model = dev->data->models[model_id];
+       ocm = &cnxk_mldev->cn10k_mldev.ocm;
+       model = cnxk_mldev->mldev->data->models[model_id];
        layer = &model->layer[layer_id];
 
        /* Get first set bit, tile_start */
@@ -396,12 +390,10 @@ cn10k_ml_ocm_reserve_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t l
 }
 
 void
-cn10k_ml_ocm_free_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t layer_id)
+cn10k_ml_ocm_free_pages(struct cnxk_ml_dev *cnxk_mldev, uint16_t model_id, uint16_t layer_id)
 {
        struct cnxk_ml_model *local_model;
        struct cnxk_ml_layer *local_layer;
-       struct cn10k_ml_dev *cn10k_mldev;
-       struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
        struct cnxk_ml_layer *layer;
        struct cn10k_ml_ocm *ocm;
@@ -416,10 +408,8 @@ cn10k_ml_ocm_free_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t laye
        uint16_t i;
        uint16_t j;
 
-       cnxk_mldev = dev->data->dev_private;
-       cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-       ocm = &cn10k_mldev->ocm;
-       model = dev->data->models[model_id];
+       ocm = &cnxk_mldev->cn10k_mldev.ocm;
+       model = cnxk_mldev->mldev->data->models[model_id];
        layer = &model->layer[layer_id];
 
        /* Update OCM info for WB memory */
@@ -438,8 +428,8 @@ cn10k_ml_ocm_free_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t laye
 
                /* Get max scratch pages required, excluding the current model */
                scratch_resize_pages = 0;
-               for (i = 0; i < dev->data->nb_models; i++) {
-                       local_model = dev->data->models[i];
+               for (i = 0; i < cnxk_mldev->mldev->data->nb_models; i++) {
+                       local_model = cnxk_mldev->mldev->data->models[i];
                        if (local_model == NULL)
                                continue;
 
diff --git a/drivers/ml/cnxk/cn10k_ml_ocm.h b/drivers/ml/cnxk/cn10k_ml_ocm.h
index 720f8caf76..97b723a56a 100644
--- a/drivers/ml/cnxk/cn10k_ml_ocm.h
+++ b/drivers/ml/cnxk/cn10k_ml_ocm.h
@@ -8,6 +8,8 @@
 #include <rte_mldev.h>
 #include <rte_mldev_pmd.h>
 
+struct cnxk_ml_dev;
+
 /* Number of OCM tiles. */
 #define ML_CN10K_OCM_NUMTILES 0x8
 
@@ -75,12 +77,12 @@ struct cn10k_ml_ocm {
 };
 
 int cn10k_ml_ocm_tilecount(uint64_t tilemask, int *start, int *end);
-int cn10k_ml_ocm_tilemask_find(struct rte_ml_dev *dev, uint8_t num_tiles, uint16_t wb_pages,
+int cn10k_ml_ocm_tilemask_find(struct cnxk_ml_dev *cnxk_mldev, uint8_t num_tiles, uint16_t wb_pages,
                               uint16_t scratch_pages, uint64_t *tilemask);
-void cn10k_ml_ocm_reserve_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t layer_id,
-                               uint64_t tilemask, int wb_page_start, uint16_t wb_pages,
-                               uint16_t scratch_pages);
-void cn10k_ml_ocm_free_pages(struct rte_ml_dev *dev, uint16_t model_id, uint16_t layer_id);
+void cn10k_ml_ocm_reserve_pages(struct cnxk_ml_dev *cnxk_mldev, uint16_t model_id,
+                               uint16_t layer_id, uint64_t tilemask, int wb_page_start,
+                               uint16_t wb_pages, uint16_t scratch_pages);
+void cn10k_ml_ocm_free_pages(struct cnxk_ml_dev *cnxk_mldev, uint16_t model_id, uint16_t layer_id);
 void cn10k_ml_ocm_print(struct rte_ml_dev *dev, FILE *fp);
 
 #endif /* _CN10K_ML_OCM_H_ */
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index ab05896b5e..40f484158a 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -248,26 +248,28 @@ cn10k_ml_model_print(struct rte_ml_dev *dev, uint16_t model_id, FILE *fp)
 }
 
 static void
-cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml_model *model,
+cn10k_ml_prep_sp_job_descriptor(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_layer *layer,
                                struct cnxk_ml_req *req, enum cn10k_ml_job_type job_type)
 {
        struct cn10k_ml_model_metadata *metadata;
        struct cn10k_ml_layer_addr *addr;
+       struct cn10k_ml_dev *cn10k_mldev;
 
-       metadata = &model->glow.metadata;
-       addr = &model->layer[0].glow.addr;
+       cn10k_mldev = &cnxk_mldev->cn10k_mldev;
+       metadata = &layer->glow.metadata;
+       addr = &layer->glow.addr;
 
        memset(&req->cn10k_req.jd, 0, sizeof(struct cn10k_ml_jd));
        req->cn10k_req.jd.hdr.jce.w0.u64 = 0;
        req->cn10k_req.jd.hdr.jce.w1.u64 = PLT_U64_CAST(&req->cn10k_req.status);
-       req->cn10k_req.jd.hdr.model_id = model->model_id;
+       req->cn10k_req.jd.hdr.model_id = layer->index;
        req->cn10k_req.jd.hdr.job_type = job_type;
        req->cn10k_req.jd.hdr.fp_flags = 0x0;
        req->cn10k_req.jd.hdr.result =
                roc_ml_addr_ap2mlip(&cn10k_mldev->roc, &req->cn10k_req.result);
 
        if (job_type == ML_CN10K_JOB_TYPE_MODEL_START) {
-               if (!model->glow.metadata.model.ocm_relocatable)
+               if (!layer->glow.metadata.model.ocm_relocatable)
                        req->cn10k_req.jd.hdr.sp_flags = ML_CN10K_SP_FLAGS_OCM_NONRELOCATABLE;
                else
                        req->cn10k_req.jd.hdr.sp_flags = 0x0;
@@ -291,7 +293,7 @@ cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml
                req->cn10k_req.jd.model_start.num_gather_entries = 0;
                req->cn10k_req.jd.model_start.num_scatter_entries = 0;
                req->cn10k_req.jd.model_start.tilemask = 0; /* Updated after reserving pages */
-               req->cn10k_req.jd.model_start.batch_size = model->batch_size;
+               req->cn10k_req.jd.model_start.batch_size = layer->batch_size;
                req->cn10k_req.jd.model_start.ocm_wb_base_address =
                        0; /* Updated after reserving pages */
                req->cn10k_req.jd.model_start.ocm_wb_range_start =
@@ -323,9 +325,13 @@ cn10k_ml_prep_sp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml
 }
 
 static __rte_always_inline void
-cn10k_ml_prep_fp_job_descriptor(struct cn10k_ml_dev *cn10k_mldev, struct cnxk_ml_req *req,
+cn10k_ml_prep_fp_job_descriptor(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_req *req,
                                struct rte_ml_op *op)
 {
+       struct cn10k_ml_dev *cn10k_mldev;
+
+       cn10k_mldev = &cnxk_mldev->cn10k_mldev;
+
        req->cn10k_req.jd.hdr.jce.w0.u64 = 0;
        req->cn10k_req.jd.hdr.jce.w1.u64 = PLT_U64_CAST(req->status);
        req->cn10k_req.jd.hdr.model_id = op->model_id;
@@ -714,10 +720,8 @@ cn10k_ml_model_xstats_reset(struct rte_ml_dev *dev, int32_t model_id, const uint
 }
 
 static int
-cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
+cn10k_ml_cache_model_data(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_layer *layer)
 {
-       struct rte_ml_model_info *info;
-       struct cnxk_ml_model *model;
        struct rte_ml_buff_seg seg[2];
        struct rte_ml_buff_seg *inp;
        struct rte_ml_buff_seg *out;
@@ -730,22 +734,20 @@ cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
        int ret = 0;
        uint32_t i;
 
-       model = dev->data->models[model_id];
-       info = (struct rte_ml_model_info *)model->info;
        inp = &seg[0];
        out = &seg[1];
 
        /* Create input and output buffers. */
-       for (i = 0; i < info->nb_inputs; i++)
-               isize += info->input_info[i].size;
+       for (i = 0; i < layer->info.nb_inputs; i++)
+               isize += layer->info.input[i].sz_q;
 
-       for (i = 0; i < info->nb_outputs; i++)
-               osize += info->output_info[i].size;
+       for (i = 0; i < layer->info.nb_outputs; i++)
+               osize += layer->info.output[i].sz_q;
 
-       isize = model->batch_size * isize;
-       osize = model->batch_size * osize;
+       isize = layer->batch_size * isize;
+       osize = layer->batch_size * osize;
 
-       snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", "ml_dummy_io", model_id);
+       snprintf(str, RTE_MEMZONE_NAMESIZE, "%s_%u", "ml_dummy_io", layer->index);
        mz = plt_memzone_reserve_aligned(str, isize + osize, 0, ML_CN10K_ALIGN_SIZE);
        if (mz == NULL)
                return -ENOMEM;
@@ -761,15 +763,15 @@ cn10k_ml_cache_model_data(struct rte_ml_dev *dev, uint16_t model_id)
        seg[1].length = osize;
        seg[1].next = NULL;
 
-       op.model_id = model_id;
-       op.nb_batches = model->batch_size;
+       op.model_id = layer->index;
+       op.nb_batches = layer->batch_size;
        op.mempool = NULL;
 
        op.input = &inp;
        op.output = &out;
 
-       memset(model->layer[0].glow.req, 0, sizeof(struct cnxk_ml_req));
-       ret = cn10k_ml_inference_sync(dev, &op);
+       memset(layer->glow.req, 0, sizeof(struct cnxk_ml_req));
+       ret = cn10k_ml_inference_sync(cnxk_mldev, &op);
        plt_memzone_free(mz);
 
        return ret;
@@ -1506,14 +1508,16 @@ cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *mode
 }
 
 int
-cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
+cn10k_ml_layer_start(void *device, uint16_t model_id, const char *layer_name)
 {
        struct cn10k_ml_dev *cn10k_mldev;
        struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
+       struct cnxk_ml_layer *layer;
        struct cn10k_ml_ocm *ocm;
        struct cnxk_ml_req *req;
 
+       uint16_t layer_id = 0;
        bool job_enqueued;
        bool job_dequeued;
        uint8_t num_tiles;
@@ -1524,85 +1528,89 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
        bool locked;
        int ret = 0;
 
-       cnxk_mldev = dev->data->dev_private;
-       cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-       ocm = &cn10k_mldev->ocm;
-       model = dev->data->models[model_id];
+       PLT_SET_USED(layer_name);
 
+       cnxk_mldev = (struct cnxk_ml_dev *)device;
+       if (cnxk_mldev == NULL) {
+               plt_err("Invalid device = %p", device);
+               return -EINVAL;
+       }
+
+       model = cnxk_mldev->mldev->data->models[model_id];
        if (model == NULL) {
                plt_err("Invalid model_id = %u", model_id);
                return -EINVAL;
        }
 
+       layer = &model->layer[layer_id];
+       cn10k_mldev = &cnxk_mldev->cn10k_mldev;
+       ocm = &cn10k_mldev->ocm;
+
        /* Prepare JD */
-       req = model->layer[0].glow.req;
-       cn10k_ml_prep_sp_job_descriptor(cn10k_mldev, model, req, ML_CN10K_JOB_TYPE_MODEL_START);
+       req = layer->glow.req;
+       cn10k_ml_prep_sp_job_descriptor(cnxk_mldev, layer, req, ML_CN10K_JOB_TYPE_MODEL_START);
        req->cn10k_req.result.error_code = 0x0;
        req->cn10k_req.result.user_ptr = NULL;
 
        plt_write64(ML_CNXK_POLL_JOB_START, &req->cn10k_req.status);
        plt_wmb();
 
-       num_tiles = model->layer[0].glow.metadata.model.tile_end -
-                   model->layer[0].glow.metadata.model.tile_start + 1;
+       num_tiles = layer->glow.metadata.model.tile_end - layer->glow.metadata.model.tile_start + 1;
 
        locked = false;
        while (!locked) {
                if (plt_spinlock_trylock(&model->lock) != 0) {
-                       if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
-                               plt_ml_dbg("Model already started, model = 
0x%016lx",
-                                          PLT_U64_CAST(model));
+                       if (layer->state == ML_CNXK_LAYER_STATE_STARTED) {
+                               plt_ml_dbg("Layer already started, model_id = 
%u, layer_id = %u",
+                                          model->model_id, layer_id);
                                plt_spinlock_unlock(&model->lock);
                                return 1;
                        }
 
-                       if (model->state == ML_CNXK_MODEL_STATE_JOB_ACTIVE) {
-                               plt_err("A slow-path job is active for the 
model = 0x%016lx",
-                                       PLT_U64_CAST(model));
+                       if (layer->state == ML_CNXK_LAYER_STATE_JOB_ACTIVE) {
+                               plt_err("A slow-path job is active for the 
model_id = %u",
+                                       model->model_id);
                                plt_spinlock_unlock(&model->lock);
                                return -EBUSY;
                        }
 
-                       model->state = ML_CNXK_MODEL_STATE_JOB_ACTIVE;
+                       layer->state = ML_CNXK_LAYER_STATE_JOB_ACTIVE;
                        plt_spinlock_unlock(&model->lock);
                        locked = true;
                }
        }
 
-       while (!model->layer[0].glow.ocm_map.ocm_reserved) {
+       while (!layer->glow.ocm_map.ocm_reserved) {
                if (plt_spinlock_trylock(&ocm->lock) != 0) {
                        wb_page_start = cn10k_ml_ocm_tilemask_find(
-                               dev, num_tiles, model->layer[0].glow.ocm_map.wb_pages,
-                               model->layer[0].glow.ocm_map.scratch_pages, &tilemask);
+                               cnxk_mldev, num_tiles, layer->glow.ocm_map.wb_pages,
+                               layer->glow.ocm_map.scratch_pages, &tilemask);
 
                        if (wb_page_start == -1) {
                                plt_err("Free pages not available on OCM 
tiles");
-                               plt_err("Failed to start model = 0x%016lx, name 
= %s",
-                                       PLT_U64_CAST(model),
-                                       
model->layer[0].glow.metadata.model.name);
-
+                               plt_err("Failed to start layer, model_id = %u, 
layer_id = %u",
+                                       model->model_id, layer_id);
                                plt_spinlock_unlock(&ocm->lock);
                                return -ENOMEM;
                        }
 
-                       model->layer[0].glow.ocm_map.tilemask = tilemask;
-                       model->layer[0].glow.ocm_map.wb_page_start = wb_page_start;
+                       layer->glow.ocm_map.tilemask = tilemask;
+                       layer->glow.ocm_map.wb_page_start = wb_page_start;
 
-                       cn10k_ml_ocm_reserve_pages(dev, model->model_id, 0,
-                                                  model->layer[0].glow.ocm_map.tilemask,
-                                                  model->layer[0].glow.ocm_map.wb_page_start,
-                                                  model->layer[0].glow.ocm_map.wb_pages,
-                                                  model->layer[0].glow.ocm_map.scratch_pages);
-                       model->layer[0].glow.ocm_map.ocm_reserved = true;
+                       cn10k_ml_ocm_reserve_pages(
+                               cnxk_mldev, model->model_id, layer_id, layer->glow.ocm_map.tilemask,
+                               layer->glow.ocm_map.wb_page_start, layer->glow.ocm_map.wb_pages,
+                               layer->glow.ocm_map.scratch_pages);
+                       layer->glow.ocm_map.ocm_reserved = true;
                        plt_spinlock_unlock(&ocm->lock);
                }
        }
 
        /* Update JD */
-       cn10k_ml_ocm_tilecount(model->layer[0].glow.ocm_map.tilemask, &tile_start, &tile_end);
+       cn10k_ml_ocm_tilecount(layer->glow.ocm_map.tilemask, &tile_start, &tile_end);
        req->cn10k_req.jd.model_start.tilemask = GENMASK_ULL(tile_end, tile_start);
        req->cn10k_req.jd.model_start.ocm_wb_base_address =
-               model->layer[0].glow.ocm_map.wb_page_start * ocm->page_size;
+               layer->glow.ocm_map.wb_page_start * ocm->page_size;
 
        job_enqueued = false;
        job_dequeued = false;
@@ -1636,66 +1644,94 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
        locked = false;
        while (!locked) {
                if (plt_spinlock_trylock(&model->lock) != 0) {
-                       if (ret == 0) {
-                               model->state = ML_CNXK_MODEL_STATE_STARTED;
-                               cnxk_mldev->nb_models_started++;
-                       } else {
-                               model->state = ML_CNXK_MODEL_STATE_UNKNOWN;
-                       }
+                       if (ret == 0)
+                               layer->state = ML_CNXK_LAYER_STATE_STARTED;
+                       else
+                               layer->state = ML_CNXK_LAYER_STATE_UNKNOWN;
 
                        plt_spinlock_unlock(&model->lock);
                        locked = true;
                }
        }
 
-       if (model->state == ML_CNXK_MODEL_STATE_UNKNOWN) {
-               while (model->layer[0].glow.ocm_map.ocm_reserved) {
+       if (layer->state == ML_CNXK_LAYER_STATE_UNKNOWN) {
+               while (layer->glow.ocm_map.ocm_reserved) {
                        if (plt_spinlock_trylock(&ocm->lock) != 0) {
-                               cn10k_ml_ocm_free_pages(dev, model->model_id, 0);
-                               model->layer[0].glow.ocm_map.ocm_reserved = false;
-                               model->layer[0].glow.ocm_map.tilemask = 0x0;
+                               cn10k_ml_ocm_free_pages(cnxk_mldev, model->model_id, layer_id);
+                               layer->glow.ocm_map.ocm_reserved = false;
+                               layer->glow.ocm_map.tilemask = 0x0;
                                plt_spinlock_unlock(&ocm->lock);
                        }
                }
        }
 
-       if (ret < 0) { /* Call unload to update model and FW state, ignore error */
-               rte_ml_model_stop(dev->data->dev_id, model_id);
+       if (ret < 0) {
+               cn10k_ml_layer_stop(device, model_id, layer_name);
        } else {
-               if (cn10k_mldev->cache_model_data && roc_model_is_cn10ka())
-                       ret = cn10k_ml_cache_model_data(dev, model_id);
+               if (cn10k_mldev->cache_model_data)
+                       ret = cn10k_ml_cache_model_data(cnxk_mldev, layer);
        }
 
        return ret;
 }
 
 int
-cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
+cn10k_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
+{
+       struct cnxk_ml_layer *layer;
+       int ret;
+
+       layer = &model->layer[0];
+       ret = cn10k_ml_layer_start(cnxk_mldev, model->model_id, layer->name);
+       if (ret != 0) {
+               plt_err("CN10K Model start failed, model_id = %u, error = %d", 
model->model_id,
+                       ret);
+               return ret;
+       }
+
+       cnxk_mldev->nb_models_started++;
+       model->state = ML_CNXK_MODEL_STATE_STARTED;
+
+       return 0;
+}
+
+int
+cn10k_ml_layer_stop(void *device, uint16_t model_id, const char *layer_name)
 {
        struct cn10k_ml_dev *cn10k_mldev;
        struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
+       struct cnxk_ml_layer *layer;
        struct cn10k_ml_ocm *ocm;
        struct cnxk_ml_req *req;
 
+       uint16_t layer_id = 0;
        bool job_enqueued;
        bool job_dequeued;
        bool locked;
        int ret = 0;
 
-       cnxk_mldev = dev->data->dev_private;
-       cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-       ocm = &cn10k_mldev->ocm;
-       model = dev->data->models[model_id];
+       PLT_SET_USED(layer_name);
+
+       cnxk_mldev = (struct cnxk_ml_dev *)device;
+       if (cnxk_mldev == NULL) {
+               plt_err("Invalid device = %p", device);
+               return -EINVAL;
+       }
 
+       model = cnxk_mldev->mldev->data->models[model_id];
        if (model == NULL) {
                plt_err("Invalid model_id = %u", model_id);
                return -EINVAL;
        }
 
+       layer = &model->layer[layer_id];
+       cn10k_mldev = &cnxk_mldev->cn10k_mldev;
+       ocm = &cn10k_mldev->ocm;
+
        /* Prepare JD */
-       req = model->layer[0].glow.req;
-       cn10k_ml_prep_sp_job_descriptor(cn10k_mldev, model, req, ML_CN10K_JOB_TYPE_MODEL_STOP);
+       req = layer->glow.req;
+       cn10k_ml_prep_sp_job_descriptor(cnxk_mldev, layer, req, ML_CN10K_JOB_TYPE_MODEL_STOP);
        req->cn10k_req.result.error_code = 0x0;
        req->cn10k_req.result.user_ptr = NULL;
 
@@ -1705,31 +1741,31 @@ cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
        locked = false;
        while (!locked) {
                if (plt_spinlock_trylock(&model->lock) != 0) {
-                       if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
-                               plt_ml_dbg("Model not started, model = 
0x%016lx",
-                                          PLT_U64_CAST(model));
+                       if (layer->state == ML_CNXK_LAYER_STATE_LOADED) {
+                               plt_ml_dbg("Layer not started, model_id = %u, 
layer_id = %u",
+                                          model->model_id, layer_id);
                                plt_spinlock_unlock(&model->lock);
                                return 1;
                        }
 
-                       if (model->state == ML_CNXK_MODEL_STATE_JOB_ACTIVE) {
-                               plt_err("A slow-path job is active for the 
model = 0x%016lx",
-                                       PLT_U64_CAST(model));
+                       if (layer->state == ML_CNXK_LAYER_STATE_JOB_ACTIVE) {
+                               plt_err("A slow-path job is active for the 
layer, model_id = %u, layer_id = %u",
+                                       model->model_id, layer_id);
                                plt_spinlock_unlock(&model->lock);
                                return -EBUSY;
                        }
 
-                       model->state = ML_CNXK_MODEL_STATE_JOB_ACTIVE;
+                       layer->state = ML_CNXK_LAYER_STATE_JOB_ACTIVE;
                        plt_spinlock_unlock(&model->lock);
                        locked = true;
                }
        }
 
-       while (model->layer[0].glow.ocm_map.ocm_reserved) {
+       while (layer->glow.ocm_map.ocm_reserved) {
                if (plt_spinlock_trylock(&ocm->lock) != 0) {
-                       cn10k_ml_ocm_free_pages(dev, model->model_id, 0);
-                       model->layer[0].glow.ocm_map.ocm_reserved = false;
-                       model->layer[0].glow.ocm_map.tilemask = 0x0;
+                       cn10k_ml_ocm_free_pages(cnxk_mldev, model->model_id, layer_id);
+                       layer->glow.ocm_map.ocm_reserved = false;
+                       layer->glow.ocm_map.tilemask = 0x0;
                        plt_spinlock_unlock(&ocm->lock);
                }
        }
@@ -1766,8 +1802,11 @@ cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
        locked = false;
        while (!locked) {
                if (plt_spinlock_trylock(&model->lock) != 0) {
-                       cnxk_mldev->nb_models_stopped++;
-                       model->state = ML_CNXK_MODEL_STATE_LOADED;
+                       if (ret == 0)
+                               layer->state = ML_CNXK_LAYER_STATE_LOADED;
+                       else
+                               layer->state = ML_CNXK_LAYER_STATE_UNKNOWN;
+
                        plt_spinlock_unlock(&model->lock);
                        locked = true;
                }
@@ -1776,6 +1815,25 @@ cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
        return ret;
 }
 
+int
+cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model)
+{
+       struct cnxk_ml_layer *layer;
+       int ret;
+
+       layer = &model->layer[0];
+       ret = cn10k_ml_layer_stop(cnxk_mldev, model->model_id, layer->name);
+       if (ret != 0) {
+               plt_err("CN10K Model stop failed, model_id = %u, error = %d", 
model->model_id, ret);
+               return ret;
+       }
+
+       cnxk_mldev->nb_models_stopped++;
+       model->state = ML_CNXK_MODEL_STATE_LOADED;
+
+       return 0;
+}
+
 int
 cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
                        struct rte_ml_model_info *model_info)
@@ -2003,30 +2061,35 @@ queue_free_count(uint64_t head, uint64_t tail, uint64_t nb_desc)
 }
 
 static __rte_always_inline void
-cn10k_ml_result_update(struct rte_ml_dev *dev, int qp_id, struct cnxk_ml_req *req)
+cn10k_ml_result_update(struct cnxk_ml_dev *cnxk_mldev, int qp_id, struct cnxk_ml_req *req)
 {
        union cn10k_ml_error_code *error_code;
        struct cn10k_ml_layer_xstats *xstats;
        struct cn10k_ml_dev *cn10k_mldev;
-       struct cnxk_ml_dev *cnxk_mldev;
        struct cn10k_ml_result *result;
        struct cnxk_ml_model *model;
+       struct cnxk_ml_layer *layer;
        struct cnxk_ml_qp *qp;
        struct rte_ml_op *op;
        uint64_t hw_latency;
        uint64_t fw_latency;
+       uint16_t model_id;
+       uint16_t layer_id;
 
        result = &req->cn10k_req.result;
        op = req->op;
 
        if (likely(result->error_code == 0)) {
-               model = dev->data->models[op->model_id];
+               model_id = cnxk_mldev->index_map[op->model_id].model_id;
+               layer_id = cnxk_mldev->index_map[op->model_id].layer_id;
+               model = cnxk_mldev->mldev->data->models[model_id];
+               layer = &model->layer[layer_id];
                if (likely(qp_id >= 0)) {
-                       qp = dev->data->queue_pairs[qp_id];
+                       qp = cnxk_mldev->mldev->data->queue_pairs[qp_id];
                        qp->stats.dequeued_count++;
-                       xstats = &model->layer[0].glow.burst_xstats[qp_id];
+                       xstats = &layer->glow.burst_xstats[qp_id];
                } else {
-                       xstats = model->layer[0].glow.sync_xstats;
+                       xstats = layer->glow.sync_xstats;
                }
 
                if (unlikely(xstats->dequeued_count == xstats->hw_reset_count)) {
@@ -2054,14 +2117,13 @@ cn10k_ml_result_update(struct rte_ml_dev *dev, int qp_id, struct cnxk_ml_req *re
                op->status = RTE_ML_OP_STATUS_SUCCESS;
        } else {
                if (likely(qp_id >= 0)) {
-                       qp = dev->data->queue_pairs[qp_id];
+                       qp = cnxk_mldev->mldev->data->queue_pairs[qp_id];
                        qp->stats.dequeue_err_count++;
                }
 
                /* Handle driver error */
                error_code = (union cn10k_ml_error_code *)&result->error_code;
                if (error_code->s.etype == ML_ETYPE_DRIVER) {
-                       cnxk_mldev = dev->data->dev_private;
                        cn10k_mldev = &cnxk_mldev->cn10k_mldev;
 
                        /* Check for exception */
@@ -2116,7 +2178,7 @@ cn10k_ml_enqueue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
        req = &queue->reqs[head];
 
        cn10k_mldev->set_poll_addr(req);
-       cn10k_ml_prep_fp_job_descriptor(cn10k_mldev, req, op);
+       cn10k_ml_prep_fp_job_descriptor(cnxk_mldev, req, op);
 
        memset(&req->cn10k_req.result, 0, sizeof(struct cn10k_ml_result));
        error_code = (union cn10k_ml_error_code *)&req->cn10k_req.result.error_code;
@@ -2183,7 +2245,7 @@ cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id, struct rte_ml_op
                }
        }
 
-       cn10k_ml_result_update(dev, qp_id, req);
+       cn10k_ml_result_update(cnxk_mldev, qp_id, req);
        ops[count] = req->op;
 
        queue_index_advance(&tail, qp->nb_desc);
@@ -2232,23 +2294,27 @@ cn10k_ml_op_error_get(struct rte_ml_dev *dev, struct rte_ml_op *op, struct rte_m
 }
 
 __rte_hot int
-cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
+cn10k_ml_inference_sync(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_op *op)
 {
        union cn10k_ml_error_code *error_code;
        struct cn10k_ml_dev *cn10k_mldev;
-       struct cnxk_ml_dev *cnxk_mldev;
        struct cnxk_ml_model *model;
+       struct cnxk_ml_layer *layer;
        struct cnxk_ml_req *req;
+       uint16_t model_id;
+       uint16_t layer_id;
        bool timeout;
        int ret = 0;
 
-       cnxk_mldev = dev->data->dev_private;
        cn10k_mldev = &cnxk_mldev->cn10k_mldev;
-       model = dev->data->models[op->model_id];
-       req = model->layer[0].glow.req;
+       model_id = cnxk_mldev->index_map[op->model_id].model_id;
+       layer_id = cnxk_mldev->index_map[op->model_id].layer_id;
+       model = cnxk_mldev->mldev->data->models[model_id];
+       layer = &model->layer[layer_id];
+       req = layer->glow.req;
 
        cn10k_ml_set_poll_addr(req);
-       cn10k_ml_prep_fp_job_descriptor(cn10k_mldev, req, op);
+       cn10k_ml_prep_fp_job_descriptor(cnxk_mldev, req, op);
 
        memset(&req->cn10k_req.result, 0, sizeof(struct cn10k_ml_result));
        error_code = (union cn10k_ml_error_code *)&req->cn10k_req.result.error_code;
@@ -2284,7 +2350,7 @@ cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op)
        if (timeout)
                ret = -ETIME;
        else
-               cn10k_ml_result_update(dev, -1, req);
+               cn10k_ml_result_update(cnxk_mldev, -1, req);
 
 error_enqueue:
        return ret;
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 677219dfdf..a222a43d55 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -315,8 +315,8 @@ cn10k_ml_dev_xstats_reset(struct rte_ml_dev *dev, enum rte_ml_dev_xstats_mod
 int cn10k_ml_model_load(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_model_params *params,
                        struct cnxk_ml_model *model);
 int cn10k_ml_model_unload(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
-int cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id);
-int cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id);
+int cn10k_ml_model_start(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
+int cn10k_ml_model_stop(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_model *model);
 int cn10k_ml_model_info_get(struct rte_ml_dev *dev, uint16_t model_id,
                            struct rte_ml_model_info *model_info);
 int cn10k_ml_model_params_update(struct rte_ml_dev *dev, uint16_t model_id, void *buffer);
@@ -335,7 +335,7 @@ __rte_hot uint16_t cn10k_ml_dequeue_burst(struct rte_ml_dev *dev, uint16_t qp_id
                                          struct rte_ml_op **ops, uint16_t nb_ops);
 __rte_hot int cn10k_ml_op_error_get(struct rte_ml_dev *dev, struct rte_ml_op *op,
                                    struct rte_ml_op_error *error);
-__rte_hot int cn10k_ml_inference_sync(struct rte_ml_dev *dev, struct rte_ml_op *op);
+__rte_hot int cn10k_ml_inference_sync(struct cnxk_ml_dev *cnxk_mldev, struct rte_ml_op *op);
 
 /* Misc ops */
 void cn10k_ml_qp_initialize(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_qp *qp);
@@ -344,5 +344,7 @@ void cn10k_ml_qp_initialize(struct cnxk_ml_dev *cnxk_mldev, struct cnxk_ml_qp *q
 int cn10k_ml_layer_load(void *device, uint16_t model_id, const char *layer_name, uint8_t *buffer,
                        size_t size, uint16_t *index);
 int cn10k_ml_layer_unload(void *device, uint16_t model_id, const char *layer_name);
+int cn10k_ml_layer_start(void *device, uint16_t model_id, const char *layer_name);
+int cn10k_ml_layer_stop(void *device, uint16_t model_id, const char *layer_name);
 
 #endif /* _CN10K_ML_OPS_H_ */
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index 1d8b84269d..b61ed45876 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -240,7 +240,7 @@ cnxk_ml_dev_configure(struct rte_ml_dev *dev, const struct rte_ml_dev_config *co
                        model = dev->data->models[model_id];
                        if (model != NULL) {
                                if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
-                                       if (cn10k_ml_model_stop(dev, model_id) != 0)
+                                       if (cnxk_ml_model_stop(dev, model_id) != 0)
                                                plt_err("Could not stop model %u", model_id);
                                }
                                if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
@@ -332,7 +332,7 @@ cnxk_ml_dev_close(struct rte_ml_dev *dev)
                model = dev->data->models[model_id];
                if (model != NULL) {
                        if (model->state == ML_CNXK_MODEL_STATE_STARTED) {
-                               if (cn10k_ml_model_stop(dev, model_id) != 0)
+                               if (cnxk_ml_model_stop(dev, model_id) != 0)
                                        plt_err("Could not stop model %u", 
model_id);
                        }
                        if (model->state == ML_CNXK_MODEL_STATE_LOADED) {
@@ -564,6 +564,46 @@ cnxk_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id)
        return plt_memzone_free(plt_memzone_lookup(str));
 }
 
+static int
+cnxk_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id)
+{
+       struct cnxk_ml_dev *cnxk_mldev;
+       struct cnxk_ml_model *model;
+
+       if (dev == NULL)
+               return -EINVAL;
+
+       cnxk_mldev = dev->data->dev_private;
+
+       model = dev->data->models[model_id];
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+
+       return cn10k_ml_model_start(cnxk_mldev, model);
+}
+
+int
+cnxk_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
+{
+       struct cnxk_ml_dev *cnxk_mldev;
+       struct cnxk_ml_model *model;
+
+       if (dev == NULL)
+               return -EINVAL;
+
+       cnxk_mldev = dev->data->dev_private;
+
+       model = dev->data->models[model_id];
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+
+       return cn10k_ml_model_stop(cnxk_mldev, model);
+}
+
 struct rte_ml_dev_ops cnxk_ml_ops = {
        /* Device control ops */
        .dev_info_get = cnxk_ml_dev_info_get,
@@ -589,8 +629,8 @@ struct rte_ml_dev_ops cnxk_ml_ops = {
        /* Model ops */
        .model_load = cnxk_ml_model_load,
        .model_unload = cnxk_ml_model_unload,
-       .model_start = cn10k_ml_model_start,
-       .model_stop = cn10k_ml_model_stop,
+       .model_start = cnxk_ml_model_start,
+       .model_stop = cnxk_ml_model_stop,
        .model_info_get = cn10k_ml_model_info_get,
        .model_params_update = cn10k_ml_model_params_update,
 
diff --git a/drivers/ml/cnxk/cnxk_ml_ops.h b/drivers/ml/cnxk/cnxk_ml_ops.h
index bc14f6e5b9..d27ca0d0cb 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.h
+++ b/drivers/ml/cnxk/cnxk_ml_ops.h
@@ -63,5 +63,6 @@ struct cnxk_ml_qp {
 extern struct rte_ml_dev_ops cnxk_ml_ops;
 
 int cnxk_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id);
+int cnxk_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id);
 
 #endif /* _CNXK_ML_OPS_H_ */
-- 
2.42.0
