From: Wenbin Chen <wenbin.c...@intel.com> Add xpu device support to libtorch backend. To enable xpu support, add "-Wl,--no-as-needed -lintel-ext-pt-gpu -Wl,--as-needed" to "--extra-libs" when configuring FFmpeg.
Signed-off-by: Wenbin Chen <wenbin.c...@intel.com> --- libavfilter/dnn/dnn_backend_torch.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp index 2557264713..ea493f5873 100644 --- a/libavfilter/dnn/dnn_backend_torch.cpp +++ b/libavfilter/dnn/dnn_backend_torch.cpp @@ -250,6 +250,10 @@ static int th_start_inference(void *args) av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n"); return DNN_GENERIC_ERROR; } + // Transfer tensor to the same device as model + c10::Device device = (*th_model->jit_model->parameters().begin()).device(); + if (infer_request->input_tensor->device() != device) + *infer_request->input_tensor = infer_request->input_tensor->to(device); inputs.push_back(*infer_request->input_tensor); *infer_request->output = th_model->jit_model->forward(inputs).toTensor(); @@ -285,6 +289,9 @@ static void infer_completion_callback(void *args) { switch (th_model->model.func_type) { case DFT_PROCESS_FRAME: if (task->do_ioproc) { + // Post process can only deal with CPU memory. 
+ if (output->device() != torch::kCPU) + *output = output->to(torch::kCPU); outputs.scale = 255; outputs.data = output->data_ptr(); if (th_model->model.frame_post_proc != NULL) { @@ -424,7 +431,13 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A th_model->ctx = ctx; c10::Device device = c10::Device(device_name); - if (!device.is_cpu()) { + if (device.is_xpu()) { + if (!at::hasXPU()) { + av_log(ctx, AV_LOG_ERROR, "No XPU device found\n"); + goto fail; + } + at::detail::getXPUHooks().initXPU(); + } else if (!device.is_cpu()) { av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", device_name); goto fail; } @@ -432,6 +445,7 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A try { th_model->jit_model = new torch::jit::Module; (*th_model->jit_model) = torch::jit::load(ctx->model_filename); + th_model->jit_model->to(device); } catch (const c10::Error& e) { av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n"); goto fail; -- 2.34.1 _______________________________________________ ffmpeg-devel mailing list ffmpeg-devel@ffmpeg.org https://ffmpeg.org/mailman/listinfo/ffmpeg-devel To unsubscribe, visit link above, or email ffmpeg-devel-requ...@ffmpeg.org with subject "unsubscribe".