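From: Anup Prabhu <apra...@marvell.com> Enable the check for the OCM size requirement of multi-layer TVM models. During the load stage, compute the total number of WB pages and the maximum number of scratch pages required across all layers, and unload the model if their sum exceeds the number of OCM pages available.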
Signed-off-by: Anup Prabhu <apra...@marvell.com> --- drivers/ml/cnxk/cnxk_ml_ops.c | 60 +++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/drivers/ml/cnxk/cnxk_ml_ops.c b/drivers/ml/cnxk/cnxk_ml_ops.c index cd95a3c7ad..03f4783b3f 100644 --- a/drivers/ml/cnxk/cnxk_ml_ops.c +++ b/drivers/ml/cnxk/cnxk_ml_ops.c @@ -1023,8 +1023,12 @@ cnxk_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, u char str[RTE_MEMZONE_NAMESIZE]; const struct plt_memzone *mz; + uint16_t max_scratch_pages; + struct cn10k_ml_ocm *ocm; uint64_t model_info_size; + uint16_t total_wb_pages; uint16_t lcl_model_id; + uint16_t layer_id; uint64_t mz_size; bool found; int ret; @@ -1086,6 +1090,62 @@ cnxk_ml_model_load(struct rte_ml_dev *dev, struct rte_ml_model_params *params, u if (ret != 0) goto error; + max_scratch_pages = 0; + total_wb_pages = 0; + layer_id = 0; + + ocm = &cnxk_mldev->cn10k_mldev.ocm; + + if (model->type == ML_CNXK_MODEL_TYPE_GLOW) { + total_wb_pages = total_wb_pages + model->layer[layer_id].glow.ocm_map.wb_pages; + max_scratch_pages = PLT_MAX(max_scratch_pages, + model->layer[layer_id].glow.ocm_map.scratch_pages); +#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM + } else { + for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; layer_id++) { + if (model->layer[layer_id].type == ML_CNXK_LAYER_TYPE_MRVL) { + total_wb_pages = total_wb_pages + + model->layer[layer_id].glow.ocm_map.wb_pages; + max_scratch_pages = + PLT_MAX(max_scratch_pages, + model->layer[layer_id].glow.ocm_map.scratch_pages); + } + } +#endif + } + + if ((total_wb_pages + max_scratch_pages) > ocm->num_pages) { + plt_err("model_id = %u: total_wb_pages (%u) + scratch_pages (%u) > %u\n", + lcl_model_id, total_wb_pages, max_scratch_pages, ocm->num_pages); + + if (model->type == ML_CNXK_MODEL_TYPE_GLOW) { + plt_ml_dbg("layer_id = %u: wb_pages = %u, scratch_pages = %u\n", layer_id, + model->layer[layer_id].glow.ocm_map.wb_pages, + 
model->layer[layer_id].glow.ocm_map.scratch_pages); +#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM + } else { + for (layer_id = 0; layer_id < model->mvtvm.metadata.model.nb_layers; + layer_id++) { + if (model->layer[layer_id].type == ML_CNXK_LAYER_TYPE_MRVL) { + plt_ml_dbg( + "layer_id = %u: wb_pages = %u, scratch_pages = %u\n", + layer_id, + model->layer[layer_id].glow.ocm_map.wb_pages, + model->layer[layer_id].glow.ocm_map.scratch_pages); + } + } +#endif + } + + if (model->type == ML_CNXK_MODEL_TYPE_GLOW) + cn10k_ml_model_unload(cnxk_mldev, model); +#ifdef RTE_MLDEV_CNXK_ENABLE_MVTVM + else { + mvtvm_ml_model_unload(cnxk_mldev, model); + return -ENOMEM; + } +#endif + } plt_spinlock_init(&model->lock); model->state = ML_CNXK_MODEL_STATE_LOADED; cnxk_mldev->nb_models_loaded++; -- 2.42.0