[ https://issues.apache.org/jira/browse/BEAM-14337?focusedWorklogId=775543&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-775543 ]
ASF GitHub Bot logged work on BEAM-14337:
-----------------------------------------

                Author: ASF GitHub Bot
            Created on: 27/May/22 17:53
            Start Date: 27/May/22 17:53
    Worklog Time Spent: 10m
      Work Description: yeandy commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r883863798


##########
sdks/python/apache_beam/ml/inference/pytorch_test.py:
##########
@@ -43,10 +43,14 @@
   raise unittest.SkipTest('PyTorch dependencies are not installed')
 
 
-def _compare_prediction_result(a, b):
-  return (
-      torch.equal(a.inference, b.inference) and
-      torch.equal(a.example, b.example))
+def _compare_prediction_result(x, y):

Review Comment:
   As opposed to the other pattern
   ```
   for actual, expected in zip(predictions, expected_predictions):
     self.assertTrue(_compare_prediction_result(actual, expected))
   ```
   which won't report anything meaningful beyond saying that `False` is not `True` (which I've already fixed), the call
   ```
   assert_that(
       predictions,
       equal_to(expected_predictions, equals_fn=_compare_prediction_result))
   ```
   will output the details of how `a` and `b` are not equal. So I will keep this function name the same.



Issue Time Tracking
-------------------

    Worklog Id:     (was: 775543)
    Time Spent: 6h 20m  (was: 6h 10m)

> Support **kwargs for PyTorch models.
> ------------------------------------
>
>                 Key: BEAM-14337
>                 URL: https://issues.apache.org/jira/browse/BEAM-14337
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-py-core
>            Reporter: Anand Inguva
>            Assignee: Andy Ye
>            Priority: P2
>          Time Spent: 6h 20m
>  Remaining Estimate: 0h
>
> Some PyTorch models, which subclass torch.nn.Module, take extra parameters in their forward function. These extra parameters can be passed as a dict of keyword arguments or as positional arguments.
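The extra-parameter pattern described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not Beam's implementation: `FakeBertLikeModel` is a hypothetical stand-in for a `torch.nn.Module` subclass (so no PyTorch install is needed), its `__call__` mimics how `nn.Module.__call__` dispatches to `forward()`, and `run_inference_with_kwargs` is a hypothetical helper, not a Beam API.

```python
from typing import Any, Dict, List, Optional


class FakeBertLikeModel:
    """Hypothetical stand-in for a torch.nn.Module with extra forward() params."""

    def __call__(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # Like torch.nn.Module.__call__, dispatch straight to forward().
        return self.forward(*args, **kwargs)

    def forward(self, input_ids: List[int],
                attention_mask: Optional[List[int]] = None,
                token_type_ids: Optional[List[int]] = None) -> Dict[str, Any]:
        # Echo what was received so the calling pattern is visible.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }


def run_inference_with_kwargs(model, batch: List[int],
                              extra_kwargs: Optional[Dict[str, Any]] = None):
    """Hypothetical helper: batch goes positionally, extras go as keywords."""
    return model(batch, **(extra_kwargs or {}))


# Toy token ids; real inputs would be tensors produced by a tokenizer.
inputs = {
    "input_ids": [101, 7592, 102],
    "attention_mask": [1, 1, 1],
    "token_type_ids": [0, 0, 0],
}
model = FakeBertLikeModel()

# Every key in the dict becomes a keyword argument of forward().
outputs = model(**inputs)
print(outputs["attention_mask"])

# Only some extras supplied; the rest fall back to forward()'s defaults.
result = run_inference_with_kwargs(model, [101, 7592, 102],
                                   {"attention_mask": [1, 1, 1]})
print(result["token_type_ids"])
```

The design choice illustrated: the batch is passed positionally and any extra model parameters travel in a dict unpacked with `**`, so a model without extra parameters still works when that dict is empty.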
> Examples of PyTorch models supported by Hugging Face:
> [https://huggingface.co/bert-base-uncased]
> [Some torch models on Hugging Face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]
> E.g.:
> [https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel]
> {code:python}
> from transformers import BertModel
>
> # Tensor1..Tensor3 are placeholder torch.Tensor objects (e.g. tokenizer output).
> inputs = {
>     "input_ids": Tensor1,
>     "attention_mask": Tensor2,
>     "token_type_ids": Tensor3,
> }
> model = BertModel.from_pretrained("bert-base-uncased")  # a subclass of torch.nn.Module
> outputs = model(**inputs)  # forward() receives each key in inputs as a keyword argument
> {code}
>
> [Transformers|https://pytorch.org/hub/huggingface_pytorch-transformers/] integrated in PyTorch is supported by Hugging Face as well.



--
This message was sent by Atlassian Jira
(v8.20.7#820007)