Implemented model stop driver function. A model stop job is
enqueued through scratch registers and its completion is
checked synchronously by polling. OCM pages are released
after the model stop completes.

Signed-off-by: Srikanth Yalavarthi <syalavar...@marvell.com>
---
 drivers/ml/cnxk/cn10k_ml_ops.c | 115 ++++++++++++++++++++++++++++++++-
 drivers/ml/cnxk/cn10k_ml_ops.h |   1 +
 2 files changed, 114 insertions(+), 2 deletions(-)

diff --git a/drivers/ml/cnxk/cn10k_ml_ops.c b/drivers/ml/cnxk/cn10k_ml_ops.c
index e8ce65b182..77d3728d8d 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.c
+++ b/drivers/ml/cnxk/cn10k_ml_ops.c
@@ -295,10 +295,14 @@ cn10k_ml_dev_configure(struct rte_ml_dev *dev, const 
struct rte_ml_dev_config *c
                /* Re-configure */
                void **models;
 
-               /* Unload all models */
+               /* Stop and unload all models */
                for (model_id = 0; model_id < dev->data->nb_models; model_id++) 
{
                        model = dev->data->models[model_id];
                        if (model != NULL) {
+                               if (model->state == 
ML_CN10K_MODEL_STATE_STARTED) {
+                                       if (cn10k_ml_model_stop(dev, model_id) 
!= 0)
+                                               plt_err("Could not stop model 
%u", model_id);
+                               }
                                if (model->state == 
ML_CN10K_MODEL_STATE_LOADED) {
                                        if (cn10k_ml_model_unload(dev, 
model_id) != 0)
                                                plt_err("Could not unload model 
%u", model_id);
@@ -362,10 +366,14 @@ cn10k_ml_dev_close(struct rte_ml_dev *dev)
 
        mldev = dev->data->dev_private;
 
-       /* Unload all models */
+       /* Stop and unload all models */
        for (model_id = 0; model_id < dev->data->nb_models; model_id++) {
                model = dev->data->models[model_id];
                if (model != NULL) {
+                       if (model->state == ML_CN10K_MODEL_STATE_STARTED) {
+                               if (cn10k_ml_model_stop(dev, model_id) != 0)
+                                       plt_err("Could not stop model %u", 
model_id);
+                       }
                        if (model->state == ML_CN10K_MODEL_STATE_LOADED) {
                                if (cn10k_ml_model_unload(dev, model_id) != 0)
                                        plt_err("Could not unload model %u", 
model_id);
@@ -767,6 +775,108 @@ cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t 
model_id)
        return ret;
 }
 
+int
+cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id)
+{
+       struct cn10k_ml_model *model;
+       struct cn10k_ml_dev *mldev;
+       struct cn10k_ml_ocm *ocm;
+       struct cn10k_ml_req *req;
+
+       bool job_enqueued;
+       bool job_dequeued;
+       bool locked;
+       int ret = 0;
+
+       mldev = dev->data->dev_private;
+       ocm = &mldev->ocm;
+       model = dev->data->models[model_id];
+
+       if (model == NULL) {
+               plt_err("Invalid model_id = %u", model_id);
+               return -EINVAL;
+       }
+
+       /* Prepare JD */
+       req = model->req;
+       cn10k_ml_prep_sp_job_descriptor(mldev, model, req, 
ML_CN10K_JOB_TYPE_MODEL_STOP);
+       req->result.error_code = 0x0;
+       req->result.user_ptr = NULL;
+
+       plt_write64(ML_CN10K_POLL_JOB_START, &req->status);
+       plt_wmb();
+
+       locked = false;
+       while (!locked) {
+               if (plt_spinlock_trylock(&model->lock) != 0) {
+                       if (model->state == ML_CN10K_MODEL_STATE_LOADED) {
+                               plt_ml_dbg("Model not started, model = 
0x%016lx",
+                                          PLT_U64_CAST(model));
+                               plt_spinlock_unlock(&model->lock);
+                               return 1;
+                       }
+
+                       if (model->state == ML_CN10K_MODEL_STATE_JOB_ACTIVE) {
+                               plt_err("A slow-path job is active for the 
model = 0x%016lx",
+                                       PLT_U64_CAST(model));
+                               plt_spinlock_unlock(&model->lock);
+                               return -EBUSY;
+                       }
+
+                       model->state = ML_CN10K_MODEL_STATE_JOB_ACTIVE;
+                       plt_spinlock_unlock(&model->lock);
+                       locked = true;
+               }
+       }
+
+       while (model->model_mem_map.ocm_reserved) {
+               if (plt_spinlock_trylock(&ocm->lock) != 0) {
+                       cn10k_ml_ocm_free_pages(dev, model->model_id);
+                       model->model_mem_map.ocm_reserved = false;
+                       model->model_mem_map.tilemask = 0x0;
+                       plt_spinlock_unlock(&ocm->lock);
+               }
+       }
+
+       job_enqueued = false;
+       job_dequeued = false;
+       do {
+               if (!job_enqueued) {
+                       req->timeout = plt_tsc_cycles() + ML_CN10K_CMD_TIMEOUT 
* plt_tsc_hz();
+                       job_enqueued = roc_ml_scratch_enqueue(&mldev->roc, 
&req->jd);
+               }
+
+               if (job_enqueued && !job_dequeued)
+                       job_dequeued = roc_ml_scratch_dequeue(&mldev->roc, 
&req->jd);
+
+               if (job_dequeued)
+                       break;
+       } while (plt_tsc_cycles() < req->timeout);
+
+       if (job_dequeued) {
+               if (plt_read64(&req->status) == ML_CN10K_POLL_JOB_FINISH) {
+                       if (req->result.error_code == 0x0)
+                               ret = 0;
+                       else
+                               ret = -1;
+               }
+       } else {
+               roc_ml_scratch_queue_reset(&mldev->roc);
+               ret = -ETIME;
+       }
+
+       locked = false;
+       while (!locked) {
+               if (plt_spinlock_trylock(&model->lock) != 0) {
+                       model->state = ML_CN10K_MODEL_STATE_LOADED;
+                       plt_spinlock_unlock(&model->lock);
+                       locked = true;
+               }
+       }
+
+       return ret;
+}
+
 struct rte_ml_dev_ops cn10k_ml_ops = {
        /* Device control ops */
        .dev_info_get = cn10k_ml_dev_info_get,
@@ -783,4 +893,5 @@ struct rte_ml_dev_ops cn10k_ml_ops = {
        .model_load = cn10k_ml_model_load,
        .model_unload = cn10k_ml_model_unload,
        .model_start = cn10k_ml_model_start,
+       .model_stop = cn10k_ml_model_stop,
 };
diff --git a/drivers/ml/cnxk/cn10k_ml_ops.h b/drivers/ml/cnxk/cn10k_ml_ops.h
index 989af978c4..22576b93c0 100644
--- a/drivers/ml/cnxk/cn10k_ml_ops.h
+++ b/drivers/ml/cnxk/cn10k_ml_ops.h
@@ -65,5 +65,6 @@ int cn10k_ml_model_load(struct rte_ml_dev *dev, struct 
rte_ml_model_params *para
                        uint16_t *model_id);
 int cn10k_ml_model_unload(struct rte_ml_dev *dev, uint16_t model_id);
 int cn10k_ml_model_start(struct rte_ml_dev *dev, uint16_t model_id);
+int cn10k_ml_model_stop(struct rte_ml_dev *dev, uint16_t model_id);
 
 #endif /* _CN10K_ML_OPS_H_ */
-- 
2.17.1

Reply via email to