[https://issues.apache.org/jira/browse/BEAM-14337?focusedWorklogId=776377&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-776377]

ASF GitHub Bot logged work on BEAM-14337:
-----------------------------------------

                Author: ASF GitHub Bot
            Created on: 31/May/22 14:50
            Start Date: 31/May/22 14:50
    Worklog Time Spent: 10m 
      Work Description: yeandy commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r885738174


##########
sdks/python/apache_beam/ml/inference/pytorch_test.py:
##########
@@ -59,6 +63,23 @@ def forward(self, x):
     return out
 
 
+class PytorchLinearRegressionPredictionParams(torch.nn.Module):
+  def __init__(self, input_dim, output_dim):
+    super().__init__()
+    self.linear = torch.nn.Linear(input_dim, output_dim)
+
+  # k1 is the batched input, and prediction_param_array, prediction_param_bool
+  # are non-batchable inputs (typically model-related info) used to configure
+  # the model before its predict call is invoked
+  def forward(self, k1, k2, prediction_param_array, prediction_param_bool):

Review Comment:
   Fixed.
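The `PytorchLinearRegressionPredictionParams` model in this hunk separates per-example inputs (`k1`, `k2`) from non-batchable `prediction_param_*` arguments that configure the model. The sketch below shows one possible way a runner could honor that split. It is a hypothetical illustration, not the code in PR #17470: `run_with_prediction_params`, the toy `forward`, and the meaning given to the params (a weight/bias pair and a bias toggle) are all assumptions. Plain lists stand in for tensors so the sketch runs without PyTorch; with real tensors one would pass `stack=torch.stack`.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence


def run_with_prediction_params(
    model: Callable[..., Any],
    examples: Sequence[Dict[str, Any]],
    prediction_params: Optional[Dict[str, Any]] = None,
    stack: Callable[[List[Any]], Any] = list,
) -> Any:
  """Hypothetical helper: batch per-example keys, pass params once unbatched."""
  prediction_params = prediction_params or {}
  # Each key of the example dicts becomes one batched keyword argument.
  batched = {key: stack([ex[key] for ex in examples]) for key in examples[0]}
  # A name cannot be both a batched input and a prediction parameter.
  overlap = batched.keys() & prediction_params.keys()
  if overlap:
    raise ValueError(f'keys used as inputs and params: {sorted(overlap)}')
  return model(**batched, **prediction_params)


def forward(k1, k2, prediction_param_array, prediction_param_bool):
  # Toy stand-in: the array holds (weight, bias); the bool toggles the bias.
  weight, bias = prediction_param_array
  bias = bias if prediction_param_bool else 0.0
  return [(weight * x1 + bias) + (weight * x2 + bias) for x1, x2 in zip(k1, k2)]


examples = [{'k1': 1.0, 'k2': 1.5}, {'k1': 5.0, 'k2': 5.5}]
params = {'prediction_param_array': (2.0, 0.5), 'prediction_param_bool': True}
print(run_with_prediction_params(forward, examples, params))  # [6.0, 22.0]
```

Keeping the parameters out of the batching step means they reach `forward` exactly once per call, unchanged, regardless of batch size.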



##########
sdks/python/apache_beam/ml/inference/pytorch_test.py:
##########
@@ -117,11 +135,104 @@ def test_inference_runner_multiple_tensor_features(self):
                      ('linear.bias', torch.Tensor([0.5]))]))
     model.eval()
 
+    inference_runner = PytorchInferenceRunner(torch.device('cpu'))
+    predictions = inference_runner.run_inference(examples, model)
+    for actual, expected in zip(predictions, expected_predictions):
+      self.assertEqual(actual, expected)
+
+  def test_inference_runner_kwargs(self):
+    examples = [
+        {
+            'k1': torch.from_numpy(np.array([1], dtype="float32")),
+            'k2': torch.from_numpy(np.array([1.5], dtype="float32"))
+        },
+        {
+            'k1': torch.from_numpy(np.array([5], dtype="float32")),
+            'k2': torch.from_numpy(np.array([5.5], dtype="float32"))
+        },
+        {
+            'k1': torch.from_numpy(np.array([-3], dtype="float32")),
+            'k2': torch.from_numpy(np.array([-3.5], dtype="float32"))
+        },
+        {
+            'k1': torch.from_numpy(np.array([10.0], dtype="float32")),
+            'k2': torch.from_numpy(np.array([10.5], dtype="float32"))
+        },
+    ]
+    expected_predictions = [
+        PredictionResult(ex, pred) for ex,
+        pred in zip(
+            examples,
+            torch.Tensor([(example['k1'] * 2.0 + 0.5) +
+                          (example['k2'] * 2.0 + 0.5)
+                          for example in examples]).reshape(-1, 1))
+    ]
+
+    class PytorchLinearRegressionMultipleArgs(torch.nn.Module):
+      def __init__(self, input_dim, output_dim):
+        super().__init__()
+        self.linear = torch.nn.Linear(input_dim, output_dim)
+
+      def forward(self, k1, k2):
+        out = self.linear(k1) + self.linear(k2)
+        return out
+
+    model = PytorchLinearRegressionMultipleArgs(input_dim=1, output_dim=1)
+    model.load_state_dict(
+        OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
+                     ('linear.bias', torch.Tensor([0.5]))]))
+    model.eval()
+
     inference_runner = PytorchInferenceRunner(torch.device('cpu'))
     predictions = inference_runner.run_inference(examples, model)
     for actual, expected in zip(predictions, expected_predictions):
       self.assertTrue(_compare_prediction_result(actual, expected))
 
+  def test_inference_runner_prediction_params(self):
+    examples = [

Review Comment:
   Refactored.
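The expected values in `test_inference_runner_kwargs` follow from the weights loaded with `load_state_dict`: the same `Linear` layer (weight 2.0, bias 0.5) is applied to `k1` and to `k2`, and the two outputs are summed, so each prediction is (2*k1 + 0.5) + (2*k2 + 0.5) = 2*(k1 + k2) + 1. The plain-Python sketch below recomputes them for the four test examples; the helper names are illustrative only.

```python
# Weight and bias loaded into the test model via load_state_dict.
WEIGHT, BIAS = 2.0, 0.5


def linear(x: float) -> float:
  # Scalar version of torch.nn.Linear(1, 1): y = WEIGHT * x + BIAS.
  return WEIGHT * x + BIAS


def expected_prediction(k1: float, k2: float) -> float:
  # forward(k1, k2) applies the same layer to both inputs and sums them.
  return linear(k1) + linear(k2)


pairs = [(1.0, 1.5), (5.0, 5.5), (-3.0, -3.5), (10.0, 10.5)]
print([expected_prediction(k1, k2) for k1, k2 in pairs])
# [6.0, 22.0, -12.0, 42.0]
```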



##########
sdks/python/apache_beam/ml/inference/pytorch_test.py:
##########
@@ -117,11 +135,104 @@ def test_inference_runner_multiple_tensor_features(self):
                      ('linear.bias', torch.Tensor([0.5]))]))
     model.eval()
 
+    inference_runner = PytorchInferenceRunner(torch.device('cpu'))
+    predictions = inference_runner.run_inference(examples, model)
+    for actual, expected in zip(predictions, expected_predictions):
+      self.assertEqual(actual, expected)
+
+  def test_inference_runner_kwargs(self):

Review Comment:
   Done.





Issue Time Tracking
-------------------

    Worklog Id:     (was: 776377)
    Time Spent: 7h 10m  (was: 7h)

> Support **kwargs for PyTorch models.
> ------------------------------------
>
>                 Key: BEAM-14337
>                 URL: https://issues.apache.org/jira/browse/BEAM-14337
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-py-core
>            Reporter: Anand Inguva
>            Assignee: Andy Ye
>            Priority: P2
>          Time Spent: 7h 10m
>  Remaining Estimate: 0h
>
> Some PyTorch models, which subclass torch.nn.Module, take extra
> parameters in their forward function call. These extra parameters can be
> passed as a dict of keyword arguments or as positional arguments.
> Examples of PyTorch models supported by Hugging Face:
> [https://huggingface.co/bert-base-uncased]
> [Some torch models on Hugging Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]
> For example:
> [https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel]
> {code:python}
> from transformers import BertModel
>
> # Tensor1..Tensor3 are placeholder input tensors (e.g. tokenizer output).
> inputs = {
>     'input_ids': Tensor1,
>     'attention_mask': Tensor2,
>     'token_type_ids': Tensor3,
> }
> # BertModel is a subclass of torch.nn.Module.
> model = BertModel.from_pretrained("bert-base-uncased")
> # forward() receives each key of `inputs` as a keyword argument.
> outputs = model(**inputs){code}
>  
> [Transformers|https://pytorch.org/hub/huggingface_pytorch-transformers/]
> integrated into PyTorch Hub are supported by Hugging Face as well.
>  
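The request amounts to turning a list of per-example dicts into one dict of batched values and calling `model(**batched)`, so each key lands on the matching `forward()` parameter. A minimal sketch under stated assumptions: `collate_kwargs` and `run_keyed_batch` are hypothetical helpers, not Beam's API; every example is assumed to share the same keys; and plain lists stand in for tensors so the sketch runs without PyTorch (with real tensors one would pass `stack=torch.stack`). `toy_forward` is a made-up stand-in for a forward method such as `BertModel.forward`.

```python
from typing import Any, Callable, Dict, List, Sequence


def collate_kwargs(
    examples: Sequence[Dict[str, Any]],
    stack: Callable[[List[Any]], Any] = list,
) -> Dict[str, Any]:
  """Hypothetical helper: per-example dicts -> one dict of batched values."""
  if not examples:
    raise ValueError('at least one example is required')
  keys = examples[0].keys()
  # All examples must expose the same keyword names.
  for ex in examples:
    if ex.keys() != keys:
      raise ValueError(f'inconsistent keys: {sorted(ex)} vs {sorted(keys)}')
  return {key: stack([ex[key] for ex in examples]) for key in keys}


def run_keyed_batch(model: Callable[..., Any], examples, stack=list):
  """Calls model(**batched) so each dict key becomes a forward() keyword."""
  return model(**collate_kwargs(examples, stack))


def toy_forward(input_ids, attention_mask):
  # Stand-in model: sums the unmasked token ids of each example in the batch.
  return [sum(i * m for i, m in zip(ids, mask))
          for ids, mask in zip(input_ids, attention_mask)]


examples = [
    {'input_ids': [101, 7, 102], 'attention_mask': [1, 1, 1]},
    {'input_ids': [101, 9, 0], 'attention_mask': [1, 1, 0]},
]
print(run_keyed_batch(toy_forward, examples))  # [210, 110]
```

Passing the stacking function in keeps the batching logic independent of the tensor library.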



--
This message was sent by Atlassian Jira
(v8.20.7#820007)
