From: Anup Prabhu <apra...@marvell.com>

Enable a check of the OCM size requirement for multi-layer
TVM models. Compute the OCM scratch and WB requirements for
all layers during the load stage.

Signed-off-by: Anup Prabhu <apra...@marvell.com>
---
 drivers/ml/cnxk/cnxk_ml_ops.c | 60 +++++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c
index cd95a3c7ad..03f4783b3f 100644
--- a/drivers/ml/cnxk/cnxk_ml_ops.c
+++ b/drivers/ml/cnxk/cnxk_ml_ops.c
@@ -1023,8 +1023,12 @@ cnxk_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, u
 
        char str[RTE_MEMZONE_NAMESIZE];
        const struct plt_memzone *mz;
+       uint16_t max_scratch_pages;
+       struct cn10k_ml_ocm *ocm;
        uint64_t model_info_size;
+       uint16_t total_wb_pages;
        uint16_t lcl_model_id;
+       uint16_t layer_id;
        uint64_t mz_size;
        bool found;
        int ret;
@@ -1086,6 +1090,62 @@ cnxk_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, u
        if (ret != 0)
                goto error;
 
+       max_scratch_pages = 0;
+       total_wb_pages = 0;
+       layer_id = 0;
+
+       ocm = &cnxk_mldev->cn10k_mldev.ocm;
+
+       if (model->type == ML_CNXK_MODEL_TYPE_GLOW) {
+               total_wb_pages = total_wb_pages + 
model->layer[layer_id].glow.ocm_map.wb_pages;
+               max_scratch_pages = PLT_MAX(max_scratch_pages,
+                                           
model->layer[layer_id].glow.ocm_map.scratch_pages);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+       } else {
+               for (layer_id = 0; layer_id < 
model->mvtvm.metadata.model.nb_layers; layer_id++) {
+                       if (model->layer[layer_id].type == 
ML_CNXK_LAYER_TYPE_MRVL) {
+                               total_wb_pages = total_wb_pages +
+                                                
model->layer[layer_id].glow.ocm_map.wb_pages;
+                               max_scratch_pages =
+                                       PLT_MAX(max_scratch_pages,
+                                               
model->layer[layer_id].glow.ocm_map.scratch_pages);
+                       }
+               }
+#endif
+       }
+
+       if ((total_wb_pages + max_scratch_pages) > ocm->num_pages) {
+               plt_err("model_id = %u: total_wb_pages (%u) + scratch_pages 
(%u) >  %u\n",
+                       lcl_model_id, total_wb_pages, max_scratch_pages, 
ocm->num_pages);
+
+               if (model->type == ML_CNXK_MODEL_TYPE_GLOW) {
+                       plt_ml_dbg("layer_id = %u: wb_pages = %u, scratch_pages 
= %u\n", layer_id,
+                                  model->layer[layer_id].glow.ocm_map.wb_pages,
+                                  
model->layer[layer_id].glow.ocm_map.scratch_pages);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+               } else {
+                       for (layer_id = 0; layer_id < 
model->mvtvm.metadata.model.nb_layers;
+                            layer_id++) {
+                               if (model->layer[layer_id].type == 
ML_CNXK_LAYER_TYPE_MRVL) {
+                                       plt_ml_dbg(
+                                               "layer_id = %u: wb_pages = %u, 
scratch_pages = %u\n",
+                                               layer_id,
+                                               
model->layer[layer_id].glow.ocm_map.wb_pages,
+                                               
model->layer[layer_id].glow.ocm_map.scratch_pages);
+                               }
+                       }
+#endif
+               }
+
+               if (model->type == ML_CNXK_MODEL_TYPE_GLOW)
+                       cn10k_ml_model_unload(cnxk_mldev, model);
+#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM
+               else {
+                       mvtvm_ml_model_unload(cnxk_mldev, model);
+                       return -ENOMEM;
+               }
+#endif
+       }
        plt_spinlock_init(&model->lock);
        model->state = ML_CNXK_MODEL_STATE_LOADED;
        cnxk_mldev->nb_models_loaded++;
-- 
2.42.0

Reply via email to