[ https://issues.apache.org/jira/browse/BEAM-14068?focusedWorklogId=778252&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-778252 ]
ASF GitHub Bot logged work on BEAM-14068: ----------------------------------------- Author: ASF GitHub Bot Created on: 03/Jun/22 18:40 Start Date: 03/Jun/22 18:40 Worklog Time Spent: 10m Work Description: tvalentyn commented on code in PR #17462: URL: https://github.com/apache/beam/pull/17462#discussion_r889236325 ########## sdks/python/apache_beam/examples/inference/pytorch_image_classification.py: ########## @@ -0,0 +1,160 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""""A pipeline that uses RunInference API to perform image classification.""" + +import argparse +import io +import os +from functools import partial +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Tuple + +import apache_beam as beam +import torch +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.api import PredictionResult +from apache_beam.ml.inference.api import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelLoader +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from PIL import Image +from torchvision import transforms +from torchvision.models.mobilenetv2 import MobileNetV2 + + +def read_image(image_file_name: str, + path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + if path_to_dir is not None: + image_file_name = os.path.join(path_to_dir, image_file_name) + with FileSystems().open(image_file_name, 'r') as file: + data = Image.open(io.BytesIO(file.read())).convert('RGB') + return image_file_name, data + + +def preprocess_image(data: Image.Image) -> torch.Tensor: + image_size = (224, 224) + # Pre-trained PyTorch models expect input images normalized the + # below values ref: https://pytorch.org/vision/stable/models.html# + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + transform = transforms.Compose([ + transforms.Resize(image_size), + transforms.ToTensor(), + normalize, + ]) + return transform(data) + + +class PostProcessor(beam.DoFn): + def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + filename, prediction_result = element + prediction = torch.argmax(prediction_result.inference, dim=0) + yield filename + ',' + str(prediction.item()) + + +def run_pipeline( + options: PipelineOptions, + model_class: Optional[torch.nn.Module], + model_params: Optional[Dict], + args=None): 
+ """ + Args: + options: options used to set up the pipeline. + model_class: Reference to the class definition of the model. + If None, MobilenetV2 will be used as default . + model_params: Parameters passed to the constructor of the model_class. + These will be used to instantiate the model object in the + RunInference API. + args: Command line arguments defined for this example. + """ + if not model_class: + model_class = MobileNetV2 + model_params = {'num_classes': 1000} + + model_loader = PytorchModelLoader( + state_dict_path=args.model_state_dict_path, + model_class=model_class, + model_params=model_params) + + with beam.Pipeline(options=options) as p: + filename_value_pair = ( + p + | 'ReadImageNames' >> beam.io.ReadFromText( + args.input, skip_header_lines=1) + | 'ReadImageData' >> beam.Map( + partial(read_image, path_to_dir=args.images_dir)) + | 'PreprocessImages' >> beam.MapTuple( + lambda file_name, data: (file_name, preprocess_image(data)))) + predictions = ( + filename_value_pair + | 'PyTorchRunInference' >> RunInference(model_loader).with_output_types( + Tuple[str, PredictionResult]) + | 'ProcessOutput' >> beam.ParDo(PostProcessor())) + + if args.output: + predictions | "WriteOutputToGCS" >> beam.io.WriteToText( # pylint: disable=expression-not-assigned + args.output, + shard_name_template='', + append_trailing_newlines=True) + + +def parse_known_args(argv): + """Parses args for the workflow.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--input', + dest='input', + default='gs://apache-beam-ml/testing/inputs/' + 'it_mobilenetv2_imagenet_validation_inputs.txt', + help='Path to the text file containing image names.') + parser.add_argument( + '--output', + dest='output', + help='Predictions are saved to the output' + ' text file.') + parser.add_argument( + '--model_state_dict_path', + dest='model_state_dict_path', + default='gs://apache-beam-ml/' + 'models/imagenet_classification_mobilenet_v2.pt', + help='Path to load the model\'s 
state_dict. ' Review Comment: ```suggestion help="Path to the model's state_dict." ``` Issue Time Tracking ------------------- Worklog Id: (was: 778252) Time Spent: 9h 40m (was: 9.5h) > RunInference Benchmarking tests > ------------------------------- > > Key: BEAM-14068 > URL: https://issues.apache.org/jira/browse/BEAM-14068 > Project: Beam > Issue Type: Sub-task > Components: sdk-py-core > Reporter: Anand Inguva > Assignee: Anand Inguva > Priority: P2 > Time Spent: 9h 40m > Remaining Estimate: 0h > > RunInference benchmarks will evaluate the performance of pipelines that > represent common use cases of Beam + Dataflow with PyTorch, scikit-learn and > possibly TFX. These benchmarks would be the integration tests that exercise > several software components using Beam, PyTorch, scikit-learn and TensorFlow > Extended. > We would use datasets that are publicly available (e.g., Kaggle). > Size: small / 10 GB / 1 TB etc. > The default execution runner would be Dataflow unless specified otherwise. > These tests would be run infrequently (once every release cycle). -- This message was sent by Atlassian Jira (v8.20.7#820007)