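This commit splits the existing execute_model_tf() code into two separate functions: fill_model_input_tf(), which fills the request with execution data, and infer_completion_callback(), which handles the results once inference has completed. It also replaces the undefined input_name in the "Could not find ... in model" error message with task->input_name.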
Signed-off-by: Shubhanshu Saxena <shubhanshu....@gmail.com>
---
 libavfilter/dnn/dnn_backend_tf.c | 95 ++++++++++++++++++++++++--------
 1 file changed, 71 insertions(+), 24 deletions(-)

diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index 793b108e55..5d34da5db1 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -826,20 +826,21 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, DNNFunctionType func_
     return model;
 }
 
-static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_queue)
-{
-    TFModel *tf_model;
-    TFContext *ctx;
-    tf_infer_request *infer_request;
+/**
+ * Pop the next inference from tf_model->inference_queue, attach it to
+ * request and prepare request->infer_request for TF_SessionRun().
+ */
+static DNNReturnType fill_model_input_tf(TFModel *tf_model, RequestItem *request)
+{
+    DNNData input;
     InferenceItem *inference;
     TaskItem *task;
-    DNNData input, *outputs;
+    tf_infer_request *infer_request;
+    TFContext *ctx = &tf_model->ctx;
 
-    inference = ff_queue_pop_front(inference_queue);
+    inference = ff_queue_pop_front(tf_model->inference_queue);
     av_assert0(inference);
     task = inference->task;
-    tf_model = task->model;
-    ctx = &tf_model->ctx;
     request->inference = inference;
 
     if (get_input_tf(tf_model, &input, task->input_name) != DNN_SUCCESS)
@@ -852,7 +853,7 @@ static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_que
     infer_request->tf_input = av_malloc(sizeof(TF_Output));
     infer_request->tf_input->oper = TF_GraphOperationByName(tf_model->graph, task->input_name);
     if (!infer_request->tf_input->oper){
-        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
+        av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", task->input_name);
         return DNN_ERROR;
     }
     infer_request->tf_input->index = 0;
@@ -902,22 +903,29 @@ static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_que
         infer_request->tf_outputs[i].index = 0;
     }
 
-    TF_SessionRun(tf_model->session, NULL,
-                  infer_request->tf_input, &infer_request->input_tensor, 1,
-                  infer_request->tf_outputs, infer_request->output_tensors,
-                  task->nb_output, NULL, 0, NULL,
-                  tf_model->status);
-    if (TF_GetCode(tf_model->status) != TF_OK) {
-        tf_free_request(infer_request);
-        av_log(ctx, AV_LOG_ERROR, "Failed to run session when executing model\n");
-        return DNN_ERROR;
-    }
+    return DNN_SUCCESS;
+}
+
+/**
+ * Handle the results of a finished TF_SessionRun() for request: run the
+ * post-processing for the model function type, free request->infer_request
+ * and, on success, push request back to tf_model->request_queue.
+ */
+static void infer_completion_callback(void *args)
+{
+    RequestItem *request = args;
+    InferenceItem *inference = request->inference;
+    TaskItem *task = inference->task;
+    DNNData *outputs;
+    tf_infer_request *infer_request = request->infer_request;
+    TFModel *tf_model = task->model;
+    TFContext *ctx = &tf_model->ctx;
 
     outputs = av_malloc_array(task->nb_output, sizeof(*outputs));
     if (!outputs) {
         tf_free_request(infer_request);
         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for *outputs\n");
-        return DNN_ERROR;
+        return;
     }
 
     for (uint32_t i = 0; i < task->nb_output; ++i) {
@@ -944,7 +952,7 @@ static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_que
     case DFT_ANALYTICS_DETECT:
         if (!tf_model->model->detect_post_proc) {
             av_log(ctx, AV_LOG_ERROR, "Detect filter needs provide post proc\n");
-            return DNN_ERROR;
+            return;
         }
         tf_model->model->detect_post_proc(task->out_frame, outputs, task->nb_output, tf_model->model->filter_ctx);
         break;
@@ -955,7 +963,7 @@ static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_que
             }
         }
         av_log(ctx, AV_LOG_ERROR, "Tensorflow backend does not support this kind of dnn filter now\n");
-        return DNN_ERROR;
+        return;
     }
     for (uint32_t i = 0; i < task->nb_output; ++i) {
         if (infer_request->output_tensors[i]) {
@@ -966,7 +974,46 @@ static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_que
     tf_free_request(infer_request);
     av_freep(&outputs);
     ff_safe_queue_push_back(tf_model->request_queue, request);
-    return (task->inference_done == task->inference_todo) ? DNN_SUCCESS : DNN_ERROR;
+}
+
+static DNNReturnType execute_model_tf(RequestItem *request, Queue *inference_queue)
+{
+    TFModel *tf_model;
+    TFContext *ctx;
+    tf_infer_request *infer_request;
+    InferenceItem *inference;
+    TaskItem *task;
+
+    /* Only peek here: fill_model_input_tf() pops the item from
+     * tf_model->inference_queue, assumed to be the same queue. */
+    inference = ff_queue_peek_front(inference_queue);
+    av_assert0(inference);
+    task = inference->task;
+    tf_model = task->model;
+    ctx = &tf_model->ctx;
+
+    if (task->async) {
+        avpriv_report_missing_feature(ctx, "Async execution not supported");
+        return DNN_ERROR;
+    } else {
+        if (fill_model_input_tf(tf_model, request) != DNN_SUCCESS) {
+            return DNN_ERROR;
+        }
+
+        infer_request = request->infer_request;
+        TF_SessionRun(tf_model->session, NULL,
+                      infer_request->tf_input, &infer_request->input_tensor, 1,
+                      infer_request->tf_outputs, infer_request->output_tensors,
+                      task->nb_output, NULL, 0, NULL,
+                      tf_model->status);
+        if (TF_GetCode(tf_model->status) != TF_OK) {
+            tf_free_request(infer_request);
+            av_log(ctx, AV_LOG_ERROR, "Failed to run session when executing model\n");
+            return DNN_ERROR;
+        }
+        infer_completion_callback(request);
+        return (task->inference_done == task->inference_todo) ? DNN_SUCCESS : DNN_ERROR;
+    }
 }
 
 DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNExecBaseParams *exec_params)
-- 
2.25.1

_______________________________________________
ffmpeg-devel mailing list
ffmpeg-devel@ffmpeg.org
https://ffmpeg.org/mailman/listinfo/ffmpeg-devel

To unsubscribe, visit link above, or email
ffmpeg-devel-requ...@ffmpeg.org with subject "unsubscribe".