Hi, how should I test this patch?
Em sex., 22 de nov. de 2019 às 04:57, Guo, Yejun <yejun....@intel.com> escreveu: > Signed-off-by: Guo, Yejun <yejun....@intel.com> > --- > doc/filters.texi | 8 ++- > libavfilter/vf_dnn_processing.c | 147 > ++++++++++++++++++++++++++++++---------- > 2 files changed, 118 insertions(+), 37 deletions(-) > > diff --git a/doc/filters.texi b/doc/filters.texi > index 1f86ae1..c3f7997 100644 > --- a/doc/filters.texi > +++ b/doc/filters.texi > @@ -8992,7 +8992,13 @@ Set the input name of the dnn network. > Set the output name of the dnn network. > > @item fmt > -Set the pixel format for the Frame. Allowed values are > @code{AV_PIX_FMT_RGB24}, and @code{AV_PIX_FMT_BGR24}. > +Set the pixel format for the Frame, the value is determined by the input > of the dnn network model. > This sentence is a bit confusing; also, I think this property should be removed (I will explain below). + > +If the model handles RGB (or BGR) image and the data type of model input > is uint8, fmt must be @code{AV_PIX_FMT_RGB24} (or @code{AV_PIX_FMT_BGR24}. > +If the model handles RGB (or BGR) image and the data type of model input > is float, fmt must be @code{AV_PIX_FMT_RGB24} (or @code{AV_PIX_FMT_BGR24}, > and this filter will do data type conversion internally. > +If the model handles GRAY image and the data type of model input is > uint8, fmt must be @code{AV_PIX_FMT_GRAY8}. > +If the model handles GRAY image and the data type of model input is > float, fmt must be @code{AV_PIX_FMT_GRAYF32}. > + > Default value is @code{AV_PIX_FMT_RGB24}.
> > @end table > diff --git a/libavfilter/vf_dnn_processing.c > b/libavfilter/vf_dnn_processing.c > index ce976ec..963dd5e 100644 > --- a/libavfilter/vf_dnn_processing.c > +++ b/libavfilter/vf_dnn_processing.c > @@ -70,10 +70,12 @@ static av_cold int init(AVFilterContext *context) > { > DnnProcessingContext *ctx = context->priv; > int supported = 0; > - // as the first step, only rgb24 and bgr24 are supported > + // to support more formats > const enum AVPixelFormat supported_pixel_fmts[] = { > AV_PIX_FMT_RGB24, > AV_PIX_FMT_BGR24, > + AV_PIX_FMT_GRAY8, > + AV_PIX_FMT_GRAYF32, > }; > for (int i = 0; i < sizeof(supported_pixel_fmts) / sizeof(enum > AVPixelFormat); ++i) { > if (supported_pixel_fmts[i] == ctx->fmt) { > @@ -156,14 +158,38 @@ static int config_input(AVFilterLink *inlink) > return AVERROR(EIO); > } > I think the filter should not check formats manually in the init function (unless I'm missing something). It would be better to advertise all of the supported formats above in query_formats, and then, in config_input, check that the format expected by the model matches the frame format.
> - if (model_input.channels != 3) { > - av_log(ctx, AV_LOG_ERROR, "the model requires input channels > %d\n", > - model_input.channels); > - return AVERROR(EIO); > - } > - if (model_input.dt != DNN_FLOAT && model_input.dt != DNN_UINT8) { > - av_log(ctx, AV_LOG_ERROR, "only support dnn models with input > data type as float32 and uint8.\n"); > - return AVERROR(EIO); > + if (ctx->fmt == AV_PIX_FMT_RGB24 || ctx->fmt == AV_PIX_FMT_BGR24) { > + if (model_input.channels != 3) { > + av_log(ctx, AV_LOG_ERROR, "channel number 3 is required, but > the actual channel number is %d\n", > + model_input.channels); > + return AVERROR(EIO); > + } > + if (model_input.dt != DNN_FLOAT && model_input.dt != DNN_UINT8) { > + av_log(ctx, AV_LOG_ERROR, "only support dnn models with input > data type as float32 and uint8.\n"); > + return AVERROR(EIO); > + } > + } else if (ctx->fmt == AV_PIX_FMT_GRAY8) { > + if (model_input.channels != 1) { > + av_log(ctx, AV_LOG_ERROR, "channel number 1 is required, but > the actual channel number is %d\n", > + model_input.channels); > + return AVERROR(EIO); > + } > + if (model_input.dt != DNN_UINT8) { > + av_log(ctx, AV_LOG_ERROR, "only support dnn models with input > data type as uint8.\n"); > + return AVERROR(EIO); > + } > + } else if (ctx->fmt == AV_PIX_FMT_GRAYF32) { > + if (model_input.channels != 1) { > + av_log(ctx, AV_LOG_ERROR, "channel number 1 is required, but > the actual channel number is %d\n", > + model_input.channels); > + return AVERROR(EIO); > + } > + if (model_input.dt != DNN_FLOAT) { > + av_log(ctx, AV_LOG_ERROR, "only support dnn models with input > data type as float.\n"); > + return AVERROR(EIO); > + } > + } else { > + av_assert0(!"should not reach here."); > } > General comment on the above and the following chained ifs that test pixel formats: personally, I think a switch(fmt) would be more readable.
> > ctx->input.width = inlink->w; > @@ -203,28 +229,49 @@ static int config_output(AVFilterLink *outlink) > > static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame > *frame) > { > - // extend this function to support more formats > - av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == > AV_PIX_FMT_BGR24); > - > - if (dnn_input->dt == DNN_FLOAT) { > - float *dnn_input_data = dnn_input->data; > - for (int i = 0; i < frame->height; i++) { > - for(int j = 0; j < frame->width * 3; j++) { > - int k = i * frame->linesize[0] + j; > - int t = i * frame->width * 3 + j; > - dnn_input_data[t] = frame->data[0][k] / 255.0f; > + if (frame->format == AV_PIX_FMT_RGB24 || frame->format == > AV_PIX_FMT_BGR24) { > + if (dnn_input->dt == DNN_FLOAT) { > + float *dnn_input_data = dnn_input->data; > + for (int i = 0; i < frame->height; i++) { > + for(int j = 0; j < frame->width * 3; j++) { > + int k = i * frame->linesize[0] + j; > + int t = i * frame->width * 3 + j; > + dnn_input_data[t] = frame->data[0][k] / 255.0f; > + } > + } > + } else { > + uint8_t *dnn_input_data = dnn_input->data; > + av_assert0(dnn_input->dt == DNN_UINT8); > + for (int i = 0; i < frame->height; i++) { > + for(int j = 0; j < frame->width * 3; j++) { > + int k = i * frame->linesize[0] + j; > + int t = i * frame->width * 3 + j; > + dnn_input_data[t] = frame->data[0][k]; > + } > } > } > - } else { > + } else if (frame->format == AV_PIX_FMT_GRAY8) { > uint8_t *dnn_input_data = dnn_input->data; > av_assert0(dnn_input->dt == DNN_UINT8); > for (int i = 0; i < frame->height; i++) { > - for(int j = 0; j < frame->width * 3; j++) { > + for(int j = 0; j < frame->width; j++) { > int k = i * frame->linesize[0] + j; > - int t = i * frame->width * 3 + j; > + int t = i * frame->width + j; > dnn_input_data[t] = frame->data[0][k]; > } > } > + } else if (frame->format == AV_PIX_FMT_GRAYF32) { > + float *dnn_input_data = dnn_input->data; > + av_assert0(dnn_input->dt == DNN_FLOAT); > + for (int i = 0; i < 
frame->height; i++) { > + for(int j = 0; j < frame->width; j++) { > + int k = i * frame->linesize[0] + j * sizeof(float); > + int t = i * frame->width + j; > + dnn_input_data[t] = *(float*)(frame->data[0] + k); > + } > + } > + } else { > + av_assert0(!"should not reach here."); > } > > return 0; > @@ -232,28 +279,49 @@ static int copy_from_frame_to_dnn(DNNData > *dnn_input, const AVFrame *frame) > > static int copy_from_dnn_to_frame(AVFrame *frame, const DNNData > *dnn_output) > { > - // extend this function to support more formats > - av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == > AV_PIX_FMT_BGR24); > - > - if (dnn_output->dt == DNN_FLOAT) { > - float *dnn_output_data = dnn_output->data; > - for (int i = 0; i < frame->height; i++) { > - for(int j = 0; j < frame->width * 3; j++) { > - int k = i * frame->linesize[0] + j; > - int t = i * frame->width * 3 + j; > - frame->data[0][k] = > av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8); > + if (frame->format == AV_PIX_FMT_RGB24 || frame->format == > AV_PIX_FMT_BGR24) { > + if (dnn_output->dt == DNN_FLOAT) { > + float *dnn_output_data = dnn_output->data; > + for (int i = 0; i < frame->height; i++) { > + for(int j = 0; j < frame->width * 3; j++) { > + int k = i * frame->linesize[0] + j; > + int t = i * frame->width * 3 + j; > + frame->data[0][k] = > av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8); > + } > + } > + } else { > + uint8_t *dnn_output_data = dnn_output->data; > + av_assert0(dnn_output->dt == DNN_UINT8); > + for (int i = 0; i < frame->height; i++) { > + for(int j = 0; j < frame->width * 3; j++) { > + int k = i * frame->linesize[0] + j; > + int t = i * frame->width * 3 + j; > + frame->data[0][k] = dnn_output_data[t]; > + } > } > } > - } else { > + } else if (frame->format == AV_PIX_FMT_GRAY8) { > uint8_t *dnn_output_data = dnn_output->data; > av_assert0(dnn_output->dt == DNN_UINT8); > for (int i = 0; i < frame->height; i++) { > - for(int j = 0; j < frame->width * 3; j++) { > + 
for(int j = 0; j < frame->width; j++) { > int k = i * frame->linesize[0] + j; > - int t = i * frame->width * 3 + j; > + int t = i * frame->width + j; > frame->data[0][k] = dnn_output_data[t]; > } > } > + } else if (frame->format == AV_PIX_FMT_GRAYF32) { > + float *dnn_output_data = dnn_output->data; > + av_assert0(dnn_output->dt == DNN_FLOAT); > + for (int i = 0; i < frame->height; i++) { > + for(int j = 0; j < frame->width; j++) { > + int k = i * frame->linesize[0] + j * sizeof(float); > + int t = i * frame->width + j; > + *(float*)(frame->data[0] + k) = dnn_output_data[t]; > + } > + } > + } else { > + av_assert0(!"should not reach here."); > } > > return 0; > @@ -275,7 +343,14 @@ static int filter_frame(AVFilterLink *inlink, AVFrame > *in) > av_frame_free(&in); > return AVERROR(EIO); > } > - av_assert0(ctx->output.channels == 3); > + > + if (ctx->fmt == AV_PIX_FMT_RGB24 || ctx->fmt == AV_PIX_FMT_BGR24) { > + av_assert0(ctx->output.channels == 3); > + } else if (ctx->fmt == AV_PIX_FMT_GRAY8 || ctx->fmt == > AV_PIX_FMT_GRAYF32) { > + av_assert0(ctx->output.channels == 1); > + } else { > + av_assert0(!"should not reach here"); > + } > > out = ff_get_video_buffer(outlink, outlink->w, outlink->h); > if (!out) { > -- > 2.7.4 > > _______________________________________________ > ffmpeg-devel mailing list > ffmpeg-devel@ffmpeg.org > https://ffmpeg.org/mailman/listinfo/ffmpeg-devel > > To unsubscribe, visit link above, or email > ffmpeg-devel-requ...@ffmpeg.org with subject "unsubscribe". _______________________________________________ ffmpeg-devel mailing list ffmpeg-devel@ffmpeg.org https://ffmpeg.org/mailman/listinfo/ffmpeg-devel To unsubscribe, visit link above, or email ffmpeg-devel-requ...@ffmpeg.org with subject "unsubscribe".