Copilot commented on code in PR #13059:
URL: https://github.com/apache/trafficserver/pull/13059#discussion_r3035855585


##########
src/proxy/http2/Http2ConnectionState.cc:
##########
@@ -76,6 +78,132 @@ const int buffer_size_index[HTTP2_FRAME_TYPE_MAX] = {
   BUFFER_SIZE_INDEX_16K, // HTTP2_FRAME_TYPE_CONTINUATION
 };
 
+/** Fetch a header field value from a decoded HTTP/2 request.
+ *
+ * @param[in] hdr The decoded request header block to inspect.
+ * @param[in] name The field name to look up.
+ * @return The field value, or an empty view if the field is absent.
+ */
+std::string_view
+get_header_field_value(HTTPHdr const &hdr, std::string_view name)
+{
+  if (auto *field = const_cast<HTTPHdr &>(hdr).field_find(name); field != 
nullptr) {

Review Comment:
   `get_header_field_value()` const-casts the `HTTPHdr` to call `field_find()`. 
Since `HTTPHdr` inherits `MIMEHdr`, you can call the `const` overload 
(`hdr.field_find(name)`) or use `hdr.value_get(name)` to avoid the const_cast 
and make the intent clearer.
   ```suggestion
     if (auto *field = hdr.field_find(name); field != nullptr) {
   ```



##########
tests/gold_tests/connect/replays/h2_malformed_request_logging.replay.yaml:
##########
@@ -0,0 +1,90 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+# This replays healthy HTTP/2 requests to verify their logging.
+
+meta:
+  version: "1.0"
+
+sessions:
+  - protocol:
+      stack: http2
+      tls:
+        sni: www.example.com
+    transactions:
+      - client-request:
+          headers:
+            encoding: esc_json
+            fields:
+              - [ :method, GET ]
+              - [ :scheme, https ]
+              - [ :authority, www.example.com ]
+              - [ :path, /valid-get ]
+              - [ uuid, valid-get ]
+          content:
+            encoding: plain
+            size: 0
+
+        proxy-request:
+          method: GET
+
+        server-response:
+          status: 200
+          reason: OK
+          content:
+            encoding: plain
+            data: response_to_valid_get
+            size: 21
+
+        proxy-response:
+          status: 200
+          content:
+            verify: { value: "response_to_valid_get", as: contains }
+  - protocol:
+      stack: http2
+      tls:
+        sni: www.example.com
+    transactions:
+      - client-request:
+          frames:
+            - HEADERS:
+                headers:
+                  fields:
+                    - [:method, CONNECT]
+                    - [:authority, www.example.com:80]
+                    - [uuid, valid-connect]
+                    - [test, connect-request]
+            - DATA:
+                content:
+                  encoding: plain
+                  data: "GET /get HTTP/1.1\r\nuuid: valid-connect\r\ntest: 
real-request\r\n\r\n"
+
+        # Note: the server will received the tunneled GET request.

Review Comment:
   The comment has a grammar typo: “the server will received” should be “the 
server will receive”.
   ```suggestion
           # Note: the server will receive the tunneled GET request.
   ```



##########
src/proxy/logging/LogAccess.cc:
##########
@@ -1777,12 +1855,14 @@ LogAccess::marshal_proxy_protocol_tls_group(char *buf)
 int
 LogAccess::marshal_client_host_port(char *buf)
 {
-  if (m_http_sm) {
+  if (has_http_sm()) {
     auto txn = m_http_sm->get_ua_txn();
     if (txn) {
       uint16_t port = txn->get_client_port();
       marshal_int(buf, port);
     }
+  } else if (auto const *pre = this->get_pre_transaction_log_data(); pre != 
nullptr) {
+    marshal_int(buf, ats_ip_port_host_order(reinterpret_cast<sockaddr const 
*>(&pre->client_addr)));

Review Comment:
   `marshal_client_host_port()` calls `marshal_int(buf, ...)` even when `buf` 
is `nullptr`. Log marshaling commonly calls these `marshal_*` methods with 
`buf == nullptr` to compute the field length, so this will dereference a null 
pointer and crash. Add an `if (buf) { ... }` guard (similar to 
`marshal_remote_host_port()`) before writing.
   ```suggestion
     if (buf) {
       if (has_http_sm()) {
         auto txn = m_http_sm->get_ua_txn();
         if (txn) {
           uint16_t port = txn->get_client_port();
           marshal_int(buf, port);
         }
       } else if (auto const *pre = this->get_pre_transaction_log_data(); pre 
!= nullptr) {
         marshal_int(buf, ats_ip_port_host_order(reinterpret_cast<sockaddr 
const *>(&pre->client_addr)));
       }
   ```



##########
tests/gold_tests/connect/h2_malformed_request_logging.test.py:
##########
@@ -0,0 +1,207 @@
+'''
+Verify malformed HTTP/2 requests are access logged.
+'''
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+import os
+import re
+import sys
+
+Test.Summary = 'Malformed HTTP/2 requests are logged before transaction 
creation'
+
+
+class MalformedH2RequestLoggingTest:
+    """
+    Exercise malformed and valid HTTP/2 request logging paths.
+    """
+
+    REPLAY_FILE = 'replays/h2_malformed_request_logging.replay.yaml'
+    MALFORMED_CLIENT = 'malformed_h2_request_client.py'
+    MALFORMED_CASES = (
+        {
+            'scenario': 'connect-missing-authority',
+            'uuid': 'malformed-connect',
+            'method': 'CONNECT',
+            'pqu': '/',
+            'description': 'Send malformed HTTP/2 CONNECT request',
+        },
+        {
+            'scenario': 'get-missing-path',
+            'uuid': 'malformed-get-missing-path',
+            'method': 'GET',
+            'pqu': 'https://missing-path.example/',
+            'description': 'Send malformed HTTP/2 GET request without :path',
+        },
+        {
+            'scenario': 'get-connection-header',
+            'uuid': 'malformed-get-connection',
+            'method': 'GET',
+            'pqu': 'https://bad-connection.example/bad-connection',
+            'description': 'Send malformed HTTP/2 GET request with Connection 
header',
+        },
+    )
+
+    def __init__(self):
+        self._setup_server()
+        self._setup_ts()
+        self._processes_started = False
+        Test.Setup.CopyAs(self.MALFORMED_CLIENT, Test.RunDirectory)
+
+    @property
+    def _squid_log_path(self) -> str:
+        return os.path.join(self._ts.Variables.LOGDIR, 'squid.log')
+
+    def _setup_server(self):
+        self._server = 
Test.MakeVerifierServerProcess('malformed-request-server', self.REPLAY_FILE)
+        for case in self.MALFORMED_CASES:
+            self._server.Streams.stdout += Testers.ExcludesExpression(
+                f'uuid: {case["uuid"]}',
+                f'{case["description"]} must not reach the origin server.',
+            )
+        self._server.Streams.stdout += Testers.ContainsExpression(
+            'GET /get HTTP/1.1\nuuid: valid-connect',
+            reflags=re.MULTILINE,
+            description='A valid CONNECT tunnel should still reach the 
origin.',
+        )
+        self._server.Streams.stdout += Testers.ContainsExpression(
+            r'GET /valid-get HTTP/1\.1\n(?:.*\n)*uuid: valid-get',
+            reflags=re.MULTILINE,
+            description='A valid non-CONNECT request should still reach the 
origin.',
+        )
+
+    def _setup_ts(self):
+        self._ts = Test.MakeATSProcess('ts', enable_tls=True, 
enable_cache=False)
+        self._ts.addDefaultSSLFiles()
+        self._ts.Disk.File(
+            os.path.join(self._ts.Variables.CONFIGDIR, 'storage.config'),
+            id='storage_config',
+            typename='ats:config',
+        )
+        self._ts.Disk.storage_config.AddLine('')
+        self._ts.Disk.ssl_multicert_yaml.AddLines(
+            """
+ssl_multicert:
+  - dest_ip: "*"
+    ssl_cert_name: server.pem
+    ssl_key_name: server.key
+""".split('\n'))
+        self._ts.Disk.File(
+            os.path.join(self._ts.Variables.CONFIGDIR, 'ssl_multicert.config'),
+            id='ssl_multicert_config',
+            typename='ats:config',
+        )
+        self._ts.Disk.ssl_multicert_config.AddLine('ssl_cert_name=server.pem 
ssl_key_name=server.key dest_ip=*')
+
+        self._ts.Disk.records_config.update(
+            {
+                'proxy.config.diags.debug.enabled': 1,
+                'proxy.config.diags.debug.tags': 'http|hpack|http2',
+                'proxy.config.ssl.server.cert.path': self._ts.Variables.SSLDir,
+                'proxy.config.ssl.server.private_key.path': 
self._ts.Variables.SSLDir,
+                'proxy.config.http.server_ports': 
f'{self._ts.Variables.ssl_port}:ssl',
+                'proxy.config.http.connect_ports': 
self._server.Variables.http_port,
+            })
+        self._ts.Disk.remap_config.AddLine(f'map / 
http://127.0.0.1:{self._server.Variables.http_port}/')
+        self._ts.Disk.logging_yaml.AddLines(
+            """
+logging:
+  formats:
+    - name: malformed_h2_request
+      format: 'uuid=%<{uuid}cqh> cqpv=%<cqpv> cqhm=%<cqhm> crc=%<crc> 
sstc=%<sstc> pqu=%<pqu>'
+  logs:
+    - filename: squid
+      format: malformed_h2_request
+      mode: ascii
+""".split('\n'))
+        self._ts.Disk.diags_log.Content = Testers.ContainsExpression(
+            'recv headers malformed request',
+            'ATS should reject malformed requests at the HTTP/2 layer.',
+        )
+        for index, case in enumerate(self.MALFORMED_CASES):
+            expected = (
+                rf'uuid={case["uuid"]} cqpv=http/2 cqhm={case["method"]} '
+                rf'crc=ERR_INVALID_REQ sstc=0 pqu={re.escape(case["pqu"])}')
+            tester = Testers.ContainsExpression(
+                expected,
+                f'{case["description"]} should be logged with 
ERR_INVALID_REQ.',
+            )
+            if index == 0:
+                self._ts.Disk.squid_log.Content = tester
+            else:
+                self._ts.Disk.squid_log.Content += tester
+        self._ts.Disk.squid_log.Content += Testers.ContainsExpression(
+            r'uuid=valid-connect cqpv=http/2 cqhm=CONNECT ',
+            'A valid HTTP/2 CONNECT should still use the normal transaction 
log path.',
+        )
+        self._ts.Disk.squid_log.Content += Testers.ContainsExpression(
+            r'uuid=valid-get cqpv=http/2 cqhm=GET ',
+            'A valid HTTP/2 GET should still use the normal transaction log 
path.',
+        )
+        self._ts.Disk.squid_log.Content += Testers.ExcludesExpression(
+            r'uuid=valid-connect .*crc=ERR_INVALID_REQ',
+            'Valid HTTP/2 CONNECT logging must not be marked as malformed.',
+        )
+        self._ts.Disk.squid_log.Content += Testers.ExcludesExpression(
+            r'uuid=valid-get .*crc=ERR_INVALID_REQ',
+            'Valid HTTP/2 GET logging must not be marked as malformed.',
+        )
+
+    def _add_malformed_request_runs(self):
+        for case in self.MALFORMED_CASES:
+            tr = Test.AddTestRun(case['description'])
+            tr.Processes.Default.Command = (
+                f'{sys.executable} {self.MALFORMED_CLIENT} 
{self._ts.Variables.ssl_port} {case["scenario"]}')
+            tr.Processes.Default.ReturnCode = 0
+            self._keep_support_processes_running(tr)
+            tr.Processes.Default.Streams.stdout += Testers.ContainsExpression(
+                r'Received (RST_STREAM on stream 1 with error code 1|GOAWAY 
with error code [01])',
+                'ATS should reject the malformed request at the HTTP/2 layer.',
+            )
+
+    def _add_valid_request_run(self):
+        tr = Test.AddTestRun('Send valid HTTP/2 requests')
+        tr.AddVerifierClientProcess('valid-request-client', self.REPLAY_FILE, 
https_ports=[self._ts.Variables.ssl_port])
+        self._keep_support_processes_running(tr)
+
+    def _await_malformed_log_entries(self):
+        tr = Test.AddAwaitFileContainsTestRun(
+            'Await malformed request squid log entries',
+            self._squid_log_path,
+            'crc=ERR_INVALID_REQ',
+            desired_count=len(self.MALFORMED_CASES),
+        )
+        self._keep_support_processes_running(tr)
+
+    def _keep_support_processes_running(self, tr):
+        if self._processes_started:
+            tr.StillRunningAfter = self._server
+            tr.StillRunningAfter = self._ts
+            return
+
+        tr.Processes.Default.StartBefore(self._server)
+        tr.Processes.Default.StartBefore(self._ts)
+        tr.StillRunningAfter = self._server
+        tr.StillRunningAfter = self._ts

Review Comment:
   `StillRunningAfter` is assigned twice, so the second assignment overwrites 
the first and only one support process is kept running. Use 
`StillRunningAfter += ...` when adding more processes (as done in other gold tests).
   ```suggestion
               tr.StillRunningAfter += self._ts
               return
   
           tr.Processes.Default.StartBefore(self._server)
           tr.Processes.Default.StartBefore(self._ts)
           tr.StillRunningAfter = self._server
           tr.StillRunningAfter += self._ts
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to