This is an automated email from the ASF dual-hosted git repository.

zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 97e6be5a19c Add more test cases on OpenGaussPacketCodecEngine (#38194)
97e6be5a19c is described below

commit 97e6be5a19c0a3dd0c5ab8c142877e9af00568fe
Author: Liang Zhang <[email protected]>
AuthorDate: Wed Feb 25 17:52:13 2026 +0800

    Add more test cases on OpenGaussPacketCodecEngine (#38194)
---
 .../codec/OpenGaussPacketCodecEngineTest.java      | 222 ++++++++++++++++-----
 1 file changed, 169 insertions(+), 53 deletions(-)

diff --git a/database/protocol/dialect/opengauss/src/test/java/org/apache/shardingsphere/database/protocol/opengauss/codec/OpenGaussPacketCodecEngineTest.java b/database/protocol/dialect/opengauss/src/test/java/org/apache/shardingsphere/database/protocol/opengauss/codec/OpenGaussPacketCodecEngineTest.java
index bc510ea4bda..b65faf3c898 100644
--- a/database/protocol/dialect/opengauss/src/test/java/org/apache/shardingsphere/database/protocol/opengauss/codec/OpenGaussPacketCodecEngineTest.java
+++ b/database/protocol/dialect/opengauss/src/test/java/org/apache/shardingsphere/database/protocol/opengauss/codec/OpenGaussPacketCodecEngineTest.java
@@ -18,30 +18,37 @@
 package org.apache.shardingsphere.database.protocol.opengauss.codec;
 
 import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
-import io.netty.util.AttributeKey;
-import 
org.apache.shardingsphere.database.protocol.postgresql.packet.PostgreSQLPacket;
+import lombok.SneakyThrows;
+import org.apache.shardingsphere.database.protocol.constant.CommonConstants;
+import org.apache.shardingsphere.database.protocol.packet.DatabasePacket;
 import 
org.apache.shardingsphere.database.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
 import 
org.apache.shardingsphere.database.protocol.postgresql.packet.identifier.PostgreSQLMessagePacketType;
 import 
org.apache.shardingsphere.database.protocol.postgresql.payload.PostgreSQLPacketPayload;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 import org.mockito.Answers;
 import org.mockito.Mock;
+import org.mockito.internal.configuration.plugins.Plugins;
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.mockito.junit.jupiter.MockitoSettings;
 import org.mockito.quality.Strictness;
 
-import java.nio.charset.Charset;
 import java.nio.charset.StandardCharsets;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.stream.Stream;
 
-import static org.hamcrest.Matchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.hamcrest.Matchers.is;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -51,76 +58,185 @@ import static org.mockito.Mockito.when;
 @MockitoSettings(strictness = Strictness.LENIENT)
 class OpenGaussPacketCodecEngineTest {
     
+    private static final int SSL_REQUEST_PAYLOAD_LENGTH = 8;
+    
+    private static final int SSL_REQUEST_CODE = (1234 << 16) + 5679;
+    
     @Mock(answer = Answers.RETURNS_DEEP_STUBS)
     private ChannelHandlerContext context;
     
-    @Mock
-    private ByteBuf byteBuf;
-    
     @BeforeEach
-    void setup() {
-        
when(context.channel().attr(AttributeKey.<Charset>valueOf(Charset.class.getName())).get()).thenReturn(StandardCharsets.UTF_8);
+    void setUp() {
+        
when(context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).get()).thenReturn(StandardCharsets.UTF_8);
+        when(context.alloc().compositeBuffer(anyInt())).thenAnswer(invocation 
-> Unpooled.compositeBuffer(invocation.getArgument(0)));
     }
     
-    @Test
-    void assertIsValidHeader() {
-        assertTrue(new OpenGaussPacketCodecEngine().isValidHeader(50));
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("validHeaderCases")
+    void assertIsValidHeader(final String name, final int readableBytes, final 
boolean expectedValid) {
+        OpenGaussPacketCodecEngine codecEngine = new 
OpenGaussPacketCodecEngine();
+        assertThat(codecEngine.isValidHeader(readableBytes), 
is(expectedValid));
     }
     
-    @Test
-    void assertIsInvalidHeader() {
-        assertTrue(new OpenGaussPacketCodecEngine().isValidHeader(4));
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("validHeaderWhenStartupCompletedCases")
+    void assertIsValidHeaderWhenStartupCompleted(final String name, final int 
readableBytes, final boolean expectedValid) {
+        OpenGaussPacketCodecEngine codecEngine = new 
OpenGaussPacketCodecEngine();
+        setStartupPhase(codecEngine, false);
+        assertThat(codecEngine.isValidHeader(readableBytes), 
is(expectedValid));
     }
     
-    @Test
-    void assertDecode() {
-        when(byteBuf.readableBytes()).thenReturn(51, 47, 0);
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("decodeStartupPhaseCases")
+    void assertDecodeStartupPhase(final String name, final ByteBuf packet, 
final int expectedOutSize, final boolean expectedStartupPhase) {
+        OpenGaussPacketCodecEngine codecEngine = new 
OpenGaussPacketCodecEngine();
         List<Object> out = new LinkedList<>();
-        new OpenGaussPacketCodecEngine().decode(context, byteBuf, out);
-        assertThat(out.size(), is(1));
+        codecEngine.decode(context, packet, out);
+        assertThat(out.size(), is(expectedOutSize));
+        assertThat(getStartupPhase(codecEngine), is(expectedStartupPhase));
     }
     
-    @Test
-    void assertDecodeWithStickyPacket() {
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("decodeWithPreparedStateCases")
+    void assertDecodeWithPreparedState(final String name, final ByteBuf 
packet, final int initialPendingMessages, final int expectedOutSize,
+                                       final int expectedPendingMessages, 
final boolean expectedComposite, final int expectedResultReadableBytes) {
+        OpenGaussPacketCodecEngine codecEngine = new 
OpenGaussPacketCodecEngine();
+        setStartupPhase(codecEngine, false);
+        for (int i = 0; i < initialPendingMessages; i++) {
+            getPendingMessages(codecEngine).add(createCommandPacket('P', 4));
+        }
         List<Object> out = new LinkedList<>();
-        new OpenGaussPacketCodecEngine().decode(context, byteBuf, out);
-        assertTrue(out.isEmpty());
+        codecEngine.decode(context, packet, out);
+        assertThat(out.size(), is(expectedOutSize));
+        assertThat(getPendingMessages(codecEngine).size(), 
is(expectedPendingMessages));
+        if (0 < expectedOutSize) {
+            assertThat(out.get(0) instanceof CompositeByteBuf, 
is(expectedComposite));
+            assertThat(((ByteBuf) out.get(0)).readableBytes(), 
is(expectedResultReadableBytes));
+        }
     }
     
-    @Test
-    void assertEncodePostgreSQLPacket() {
-        PostgreSQLPacket packet = mock(PostgreSQLPacket.class);
-        new OpenGaussPacketCodecEngine().encode(context, packet, byteBuf);
-        verify(packet).write(any(PostgreSQLPacketPayload.class));
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("decodeWithoutPendingMessagesCases")
+    void assertDecodeWithoutPendingMessages(final String name, final char 
commandType) {
+        OpenGaussPacketCodecEngine codecEngine = new 
OpenGaussPacketCodecEngine();
+        setStartupPhase(codecEngine, false);
+        List<Object> out = new LinkedList<>();
+        codecEngine.decode(context, createCommandPacket(commandType, 4), out);
+        assertThat(out.size(), is(1));
+        assertThat(((ByteBuf) out.get(0)).readableBytes(), is(5));
     }
     
-    @Test
-    void assertEncodePostgreSQLIdentifierPacket() {
-        PostgreSQLIdentifierPacket packet = 
mock(PostgreSQLIdentifierPacket.class);
-        
when(packet.getIdentifier()).thenReturn(PostgreSQLMessagePacketType.AUTHENTICATION_REQUEST);
-        when(byteBuf.readableBytes()).thenReturn(9);
-        new OpenGaussPacketCodecEngine().encode(context, packet, byteBuf);
-        
verify(byteBuf).writeByte(PostgreSQLMessagePacketType.AUTHENTICATION_REQUEST.getValue());
-        verify(byteBuf).writeInt(0);
-        verify(packet).write(any(PostgreSQLPacketPayload.class));
-        verify(byteBuf).setInt(1, 8);
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("encodeCases")
+    void assertEncode(final String name, final boolean identifierPacket, final 
boolean writeException, final boolean expectedHeader, final char 
expectedIdentifier) {
+        OpenGaussPacketCodecEngine codecEngine = new 
OpenGaussPacketCodecEngine();
+        DatabasePacket message = createPacket(identifierPacket);
+        if (writeException) {
+            doThrow(new 
RuntimeException("Error")).when(message).write(any(PostgreSQLPacketPayload.class));
+        }
+        ByteBuf out = Unpooled.buffer();
+        codecEngine.encode(context, message, out);
+        verify(message).write(any(PostgreSQLPacketPayload.class));
+        assertThat(out.readableBytes() > 0, is(expectedHeader));
+        if (expectedHeader) {
+            assertThat((char) out.getByte(0), is(expectedIdentifier));
+            assertThat(out.getInt(1), is(out.readableBytes() - 1));
+        }
     }
     
     @Test
-    void assertEncodeOccursException() {
-        PostgreSQLPacket packet = mock(PostgreSQLPacket.class);
-        RuntimeException ex = mock(RuntimeException.class);
-        when(ex.getMessage()).thenReturn("Error");
-        doThrow(ex).when(packet).write(any(PostgreSQLPacketPayload.class));
-        when(byteBuf.readableBytes()).thenReturn(9);
-        new OpenGaussPacketCodecEngine().encode(context, packet, byteBuf);
-        verify(byteBuf).resetWriterIndex();
-        
verify(byteBuf).writeByte(PostgreSQLMessagePacketType.ERROR_RESPONSE.getValue());
-        verify(byteBuf).setInt(1, 8);
+    void assertCreatePacketPayload() {
+        OpenGaussPacketCodecEngine codecEngine = new 
OpenGaussPacketCodecEngine();
+        ByteBuf message = Unpooled.buffer();
+        assertThat(codecEngine.createPacketPayload(message, 
StandardCharsets.UTF_8).getByteBuf(), is(message));
     }
     
-    @Test
-    void assertCreatePacketPayload() {
-        assertThat(new 
OpenGaussPacketCodecEngine().createPacketPayload(byteBuf, 
StandardCharsets.UTF_8).getByteBuf(), is(byteBuf));
+    private DatabasePacket createPacket(final boolean identifierPacket) {
+        if (identifierPacket) {
+            PostgreSQLIdentifierPacket result = 
mock(PostgreSQLIdentifierPacket.class);
+            
when(result.getIdentifier()).thenReturn(PostgreSQLMessagePacketType.AUTHENTICATION_REQUEST);
+            return result;
+        }
+        return mock(DatabasePacket.class);
+    }
+    
+    @SneakyThrows(ReflectiveOperationException.class)
+    private void setStartupPhase(final OpenGaussPacketCodecEngine codecEngine, 
final boolean startupPhase) {
+        
Plugins.getMemberAccessor().set(OpenGaussPacketCodecEngine.class.getDeclaredField("startupPhase"),
 codecEngine, startupPhase);
+    }
+    
+    @SneakyThrows(ReflectiveOperationException.class)
+    private boolean getStartupPhase(final OpenGaussPacketCodecEngine 
codecEngine) {
+        return (boolean) 
Plugins.getMemberAccessor().get(OpenGaussPacketCodecEngine.class.getDeclaredField("startupPhase"),
 codecEngine);
+    }
+    
+    @SuppressWarnings("unchecked")
+    @SneakyThrows(ReflectiveOperationException.class)
+    private List<ByteBuf> getPendingMessages(final OpenGaussPacketCodecEngine 
codecEngine) {
+        return (List<ByteBuf>) 
Plugins.getMemberAccessor().get(OpenGaussPacketCodecEngine.class.getDeclaredField("pendingMessages"),
 codecEngine);
+    }
+    
+    private static ByteBuf createStartupPacket(final int length, final int 
code) {
+        ByteBuf result = Unpooled.buffer(SSL_REQUEST_PAYLOAD_LENGTH);
+        result.writeInt(length);
+        result.writeInt(code);
+        return result;
+    }
+    
+    private static ByteBuf createStartupPacketWithAdditionalByte(final int 
length, final int code) {
+        ByteBuf result = createStartupPacket(length, code);
+        result.writeByte(0);
+        return result;
+    }
+    
+    private static ByteBuf createCommandPacket(final char commandType, final 
int payloadLength) {
+        ByteBuf result = Unpooled.buffer(1 + Integer.BYTES);
+        result.writeByte(commandType);
+        result.writeInt(payloadLength);
+        return result;
+    }
+    
+    private static Stream<Arguments> validHeaderCases() {
+        return Stream.of(
+                Arguments.of("startup phase: less than minimum header", 3, 
false),
+                Arguments.of("startup phase: equal minimum header", 4, true),
+                Arguments.of("startup phase: greater than minimum header", 8, 
true));
+    }
+    
+    private static Stream<Arguments> validHeaderWhenStartupCompletedCases() {
+        return Stream.of(
+                Arguments.of("non-startup phase: less than minimum header", 4, 
false),
+                Arguments.of("non-startup phase: equal minimum header", 5, 
true),
+                Arguments.of("non-startup phase: greater than minimum header", 
9, true));
+    }
+    
+    private static Stream<Arguments> decodeStartupPhaseCases() {
+        return Stream.of(
+                Arguments.of("decode ssl request packet", 
createStartupPacket(SSL_REQUEST_PAYLOAD_LENGTH, SSL_REQUEST_CODE), 1, true),
+                Arguments.of("decode startup packet and enter command phase", 
createStartupPacket(SSL_REQUEST_PAYLOAD_LENGTH, 1), 1, false),
+                Arguments.of("decode startup packet with mismatched declared 
length", createStartupPacket(SSL_REQUEST_PAYLOAD_LENGTH - 1, 1), 0, true),
+                Arguments.of("decode startup packet with additional payload 
byte", createStartupPacketWithAdditionalByte(SSL_REQUEST_PAYLOAD_LENGTH + 1, 
1), 1, false));
+    }
+    
+    private static Stream<Arguments> decodeWithoutPendingMessagesCases() {
+        return Stream.of(
+                Arguments.of("decode simple query command", 'Q'),
+                Arguments.of("decode sync command", 'S'),
+                Arguments.of("decode flush command", 'H'));
+    }
+    
+    private static Stream<Arguments> decodeWithPreparedStateCases() {
+        return Stream.of(
+                Arguments.of("decode with invalid header", 
Unpooled.wrappedBuffer(new byte[4]), 0, 0, 0, false, 0),
+                Arguments.of("decode with incomplete payload", 
createCommandPacket('Q', 8), 0, 0, 0, false, 0),
+                Arguments.of("decode with aggregation command", 
createCommandPacket('P', 4), 0, 0, 1, false, 0),
+                Arguments.of("decode with pending messages", 
createCommandPacket('Q', 4), 1, 1, 0, true, 10));
+    }
+    
+    private static Stream<Arguments> encodeCases() {
+        return Stream.of(
+                Arguments.of("encode non identifier packet", false, false, 
false, '\0'),
+                Arguments.of("encode identifier packet", true, false, true, 
PostgreSQLMessagePacketType.AUTHENTICATION_REQUEST.getValue()),
+                Arguments.of("encode packet with write exception", false, 
true, true, PostgreSQLMessagePacketType.ERROR_RESPONSE.getValue()));
     }
 }

Reply via email to