Em qua, 24 de abr de 2019 às 23:15, Guo, Yejun <yejun....@intel.com> escreveu: > > currently, only float is supported as model input, actually, there > are other data types, this patch adds uint8. > > Signed-off-by: Guo, Yejun <yejun....@intel.com> > --- > libavfilter/dnn_backend_native.c | 4 +++- > libavfilter/dnn_backend_tf.c | 28 ++++++++++++++++++++++++---- > libavfilter/dnn_interface.h | 10 +++++++++- > libavfilter/vf_sr.c | 4 +++- > 4 files changed, 39 insertions(+), 7 deletions(-) > > diff --git a/libavfilter/dnn_backend_native.c > b/libavfilter/dnn_backend_native.c > index 8a83c63..06fbdf3 100644 > --- a/libavfilter/dnn_backend_native.c > +++ b/libavfilter/dnn_backend_native.c > @@ -24,8 +24,9 @@ > */ > > #include "dnn_backend_native.h" > +#include "libavutil/avassert.h" > > -static DNNReturnType set_input_output_native(void *model, DNNData *input, > const char *input_name, const char **output_names, uint32_t nb_output) > +static DNNReturnType set_input_output_native(void *model, DNNInputData > *input, const char *input_name, const char **output_names, uint32_t nb_output) > { > ConvolutionalNetwork *network = (ConvolutionalNetwork *)model; > InputParams *input_params; > @@ -45,6 +46,7 @@ static DNNReturnType set_input_output_native(void *model, > DNNData *input, const > if (input->data){ > av_freep(&input->data); > } > + av_assert0(input->dt == DNN_FLOAT); > network->layers[0].output = input->data = av_malloc(cur_height * > cur_width * cur_channels * sizeof(float)); > if (!network->layers[0].output){ > return DNN_ERROR; > diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c > index ca6472d..ba959ae 100644 > --- a/libavfilter/dnn_backend_tf.c > +++ b/libavfilter/dnn_backend_tf.c > @@ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename) > return graph_buf; > } > > -static DNNReturnType set_input_output_tf(void *model, DNNData *input, const > char *input_name, const char **output_names, uint32_t nb_output) > +static TF_Tensor *allocate_input_tensor(const DNNInputData *input) > { > - TFModel *tf_model = (TFModel *)model; > + TF_DataType dt; > + size_t size; > int64_t input_dims[] = {1, input->height, input->width, input->channels}; > + switch (input->dt) { > + case DNN_FLOAT: > + dt = TF_FLOAT; > + size = sizeof(float); > + break; > + case DNN_UINT8: > + dt = TF_UINT8; > + size = sizeof(char); > + break; > + default: > + av_assert0(!"should not reach here"); > + } > + > + return TF_AllocateTensor(dt, input_dims, 4, > + input_dims[1] * input_dims[2] * input_dims[3] * > size); > +} > + > +static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, > const char *input_name, const char **output_names, uint32_t nb_output) > +{ > + TFModel *tf_model = (TFModel *)model; > TF_SessionOptions *sess_opts; > const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, > "init"); > > @@ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, > DNNData *input, const char > if (tf_model->input_tensor){ > TF_DeleteTensor(tf_model->input_tensor); > } > - tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4, > - input_dims[1] * input_dims[2] > * input_dims[3] * sizeof(float)); > + tf_model->input_tensor = allocate_input_tensor(input); > if (!tf_model->input_tensor){ > return DNN_ERROR; > } > diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h > index 73d226e..c24df0e 100644 > --- a/libavfilter/dnn_interface.h > +++ b/libavfilter/dnn_interface.h > @@ -32,6 +32,14 @@ typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType; > > typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType; > > +typedef enum {DNN_FLOAT, DNN_UINT8} DNNDataType; > + > +typedef struct DNNInputData{ > + void *data; > + DNNDataType dt; > + int width, height, channels; > +} DNNInputData; > + > typedef struct DNNData{ > float *data; > int width, height, channels; > @@ -42,7 +50,7 @@ typedef struct DNNModel{ > void *model; > // Sets model input and output. > // Should be called at least once before model execution. > - DNNReturnType (*set_input_output)(void *model, DNNData *input, const > char *input_name, const char **output_names, uint32_t nb_output); > + DNNReturnType (*set_input_output)(void *model, DNNInputData *input, > const char *input_name, const char **output_names, uint32_t nb_output); > } DNNModel; > > // Stores pointers to functions for loading, executing, freeing DNN models > for one of the backends. > diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c > index b4d4165..c0d7126 100644 > --- a/libavfilter/vf_sr.c > +++ b/libavfilter/vf_sr.c > @@ -40,7 +40,8 @@ typedef struct SRContext { > DNNBackendType backend_type; > DNNModule *dnn_module; > DNNModel *model; > - DNNData input, output; > + DNNInputData input; > + DNNData output; > int scale_factor; > struct SwsContext *sws_contexts[3]; > int sws_slice_h, sws_input_linesize, sws_output_linesize; > @@ -87,6 +88,7 @@ static av_cold int init(AVFilterContext *context) > return AVERROR(EIO); > } > > + sr_context->input.dt = DNN_FLOAT; > sr_context->sws_contexts[0] = NULL; > sr_context->sws_contexts[1] = NULL; > sr_context->sws_contexts[2] = NULL; > -- > 2.7.4 >
LGTM. I think it would be valuable to add a few tests covering the features added by this patch series. > _______________________________________________ > ffmpeg-devel mailing list > ffmpeg-devel@ffmpeg.org > https://ffmpeg.org/mailman/listinfo/ffmpeg-devel > > To unsubscribe, visit link above, or email > ffmpeg-devel-requ...@ffmpeg.org with subject "unsubscribe". _______________________________________________ ffmpeg-devel mailing list ffmpeg-devel@ffmpeg.org https://ffmpeg.org/mailman/listinfo/ffmpeg-devel To unsubscribe, visit link above, or email ffmpeg-devel-requ...@ffmpeg.org with subject "unsubscribe".