This is an automated email from the ASF dual-hosted git repository.
zhangliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 8ea0286a040 Add more test cases on MySQLNegotiateHandlerTest (#37446)
8ea0286a040 is described below
commit 8ea0286a04034d2f356601e041076dbd47ea6746
Author: Liang Zhang <[email protected]>
AuthorDate: Sat Dec 20 17:29:32 2025 +0800
Add more test cases on MySQLNegotiateHandlerTest (#37446)
* Add more test cases on MySQLNegotiateHandlerTest
* Add more test cases on MySQLNegotiateHandlerTest
* Add more test cases on MySQLNegotiateHandlerTest
---
.../client/netty/MySQLNegotiateHandlerTest.java | 140 ++++++++++++++++++---
1 file changed, 125 insertions(+), 15 deletions(-)
diff --git a/kernel/data-pipeline/dialect/mysql/src/test/java/org/apache/shardingsphere/data/pipeline/mysql/ingest/incremental/client/netty/MySQLNegotiateHandlerTest.java b/kernel/data-pipeline/dialect/mysql/src/test/java/org/apache/shardingsphere/data/pipeline/mysql/ingest/incremental/client/netty/MySQLNegotiateHandlerTest.java
index e21899daef6..883bcbfdd07 100644
--- a/kernel/data-pipeline/dialect/mysql/src/test/java/org/apache/shardingsphere/data/pipeline/mysql/ingest/incremental/client/netty/MySQLNegotiateHandlerTest.java
+++ b/kernel/data-pipeline/dialect/mysql/src/test/java/org/apache/shardingsphere/data/pipeline/mysql/ingest/incremental/client/netty/MySQLNegotiateHandlerTest.java
@@ -21,43 +21,55 @@ import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.util.concurrent.Promise;
+import lombok.SneakyThrows;
+import
org.apache.shardingsphere.data.pipeline.core.exception.PipelineInternalException;
import
org.apache.shardingsphere.data.pipeline.mysql.ingest.incremental.client.MySQLServerVersion;
+import
org.apache.shardingsphere.data.pipeline.mysql.ingest.incremental.client.PasswordEncryption;
import
org.apache.shardingsphere.database.exception.mysql.vendor.MySQLVendorError;
import
org.apache.shardingsphere.database.protocol.mysql.constant.MySQLAuthenticationMethod;
+import
org.apache.shardingsphere.database.protocol.mysql.constant.MySQLAuthenticationPlugin;
+import
org.apache.shardingsphere.database.protocol.mysql.constant.MySQLCapabilityFlag;
import
org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLErrPacket;
import
org.apache.shardingsphere.database.protocol.mysql.packet.generic.MySQLOKPacket;
+import
org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthMoreDataPacket;
+import
org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthSwitchRequestPacket;
+import
org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthSwitchResponsePacket;
import
org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLAuthenticationPluginData;
import
org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakePacket;
import
org.apache.shardingsphere.database.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.internal.configuration.plugins.Plugins;
import org.mockito.junit.jupiter.MockitoExtension;
-import org.mockito.junit.jupiter.MockitoSettings;
-import org.mockito.quality.Strictness;
+import java.security.NoSuchAlgorithmException;
import java.sql.SQLException;
+import java.util.stream.Stream;
import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.lenient;
+import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
-@MockitoSettings(strictness = Strictness.LENIENT)
class MySQLNegotiateHandlerTest {
private static final String USER_NAME = "username";
private static final String PASSWORD = "password";
- @Mock
- private Promise<Object> authResultCallback;
-
@Mock
private ChannelHandlerContext channelHandlerContext;
@@ -67,25 +79,46 @@ class MySQLNegotiateHandlerTest {
@Mock
private ChannelPipeline pipeline;
+ @Mock
+ private Promise<Object> authResultCallback;
+
private MySQLNegotiateHandler mysqlNegotiateHandler;
@BeforeEach
void setUp() {
- when(channelHandlerContext.channel()).thenReturn(channel);
- when(channel.pipeline()).thenReturn(pipeline);
+ lenient().when(channelHandlerContext.channel()).thenReturn(channel);
+ lenient().when(channel.pipeline()).thenReturn(pipeline);
mysqlNegotiateHandler = new MySQLNegotiateHandler(USER_NAME, PASSWORD,
authResultCallback);
}
- @Test
- void assertChannelReadHandshakeInitPacket() throws
ReflectiveOperationException {
+ @ParameterizedTest
+ @MethodSource("handshakeParams")
+ void assertChannelReadHandshakeInitPacket(final String password, final
boolean expectEmptyAuth) throws ReflectiveOperationException {
+ MySQLNegotiateHandler handler = new MySQLNegotiateHandler(USER_NAME,
password, authResultCallback);
MySQLHandshakePacket handshakePacket = new MySQLHandshakePacket(0,
false, new MySQLAuthenticationPluginData(new byte[8], new byte[12]));
handshakePacket.setAuthPluginName(MySQLAuthenticationMethod.NATIVE);
- mysqlNegotiateHandler.channelRead(channelHandlerContext,
handshakePacket);
-
verify(channel).writeAndFlush(ArgumentMatchers.any(MySQLHandshakeResponse41Packet.class));
- MySQLServerVersion serverVersion = (MySQLServerVersion)
Plugins.getMemberAccessor().get(MySQLNegotiateHandler.class.getDeclaredField("serverVersion"),
mysqlNegotiateHandler);
+ handler.channelRead(channelHandlerContext, handshakePacket);
+ MySQLServerVersion serverVersion = (MySQLServerVersion)
Plugins.getMemberAccessor().get(MySQLNegotiateHandler.class.getDeclaredField("serverVersion"),
handler);
assertThat(Plugins.getMemberAccessor().get(MySQLServerVersion.class.getDeclaredField("major"),
serverVersion), is(5));
assertThat(Plugins.getMemberAccessor().get(MySQLServerVersion.class.getDeclaredField("minor"),
serverVersion), is(7));
assertThat(Plugins.getMemberAccessor().get(MySQLServerVersion.class.getDeclaredField("series"),
serverVersion), is(22));
+ ArgumentCaptor<MySQLHandshakeResponse41Packet> responseCaptor =
ArgumentCaptor.forClass(MySQLHandshakeResponse41Packet.class);
+ verify(channel).writeAndFlush(responseCaptor.capture());
+ MySQLHandshakeResponse41Packet actualResponse =
responseCaptor.getValue();
+ assertThat(actualResponse.getAuthResponse().length, expectEmptyAuth ?
is(0) : not(0));
+ assertThat(actualResponse.getCapabilityFlags(),
is(calculateExpectedCapabilities()));
+ }
+
+ private int calculateExpectedCapabilities() {
+ return MySQLCapabilityFlag.calculateCapabilityFlags(
+ MySQLCapabilityFlag.CLIENT_LONG_PASSWORD,
+ MySQLCapabilityFlag.CLIENT_LONG_FLAG,
+ MySQLCapabilityFlag.CLIENT_PROTOCOL_41,
+ MySQLCapabilityFlag.CLIENT_INTERACTIVE,
+ MySQLCapabilityFlag.CLIENT_TRANSACTIONS,
+ MySQLCapabilityFlag.CLIENT_SECURE_CONNECTION,
+ MySQLCapabilityFlag.CLIENT_MULTI_STATEMENTS,
+ MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH);
}
@Test
@@ -102,6 +135,83 @@ class MySQLNegotiateHandlerTest {
void assertChannelReadErrorPacket() {
MySQLErrPacket errorPacket = new MySQLErrPacket(
new SQLException(MySQLVendorError.ER_NO_DB_ERROR.getReason(),
MySQLVendorError.ER_NO_DB_ERROR.getSqlState().getValue(),
MySQLVendorError.ER_NO_DB_ERROR.getVendorCode()));
- assertThrows(RuntimeException.class, () ->
mysqlNegotiateHandler.channelRead(channelHandlerContext, errorPacket));
+ assertThrows(PipelineInternalException.class, () ->
mysqlNegotiateHandler.channelRead(channelHandlerContext, errorPacket));
+ verify(channel).close();
+ }
+
+ @Test
+ void assertHandleCachingSha2RequestPublicKey() throws
ReflectiveOperationException {
+ mysqlNegotiateHandler.channelRead(channelHandlerContext, new
MySQLAuthMoreDataPacket(new byte[]{4}));
+ assertThat(captureAuthSwitchResponse().getAuthPluginResponse()[0],
is((byte) 2));
+ assertTrue((boolean)
Plugins.getMemberAccessor().get(MySQLNegotiateHandler.class.getDeclaredField("publicKeyRequested"),
mysqlNegotiateHandler));
+ }
+
+ @Test
+ void assertHandleCachingSha2SkipPublicKeyRequest() throws
ReflectiveOperationException {
+ mysqlNegotiateHandler.channelRead(channelHandlerContext, new
MySQLAuthMoreDataPacket(new byte[]{1}));
+ verify(channel, never()).writeAndFlush(ArgumentMatchers.any());
+ assertFalse((Boolean)
Plugins.getMemberAccessor().get(MySQLNegotiateHandler.class.getDeclaredField("publicKeyRequested"),
mysqlNegotiateHandler));
+ }
+
+ @ParameterizedTest
+ @MethodSource("authSwitchParams")
+ void assertChannelReadAuthSwitchRequestForPlugin(final String pluginName,
final byte[] expectedResponse) {
+ mysqlNegotiateHandler.channelRead(channelHandlerContext, new
MySQLAuthSwitchRequestPacket(pluginName, new
MySQLAuthenticationPluginData(seedBytesPart1(), seedBytesPart2())));
+ assertThat(captureAuthSwitchResponse().getAuthPluginResponse(),
is(expectedResponse));
+ }
+
+ @ParameterizedTest
+ @MethodSource("publicKeyEncryptParams")
+ void assertHandleCachingSha2EncryptWithPublicKey(final String
serverVersion) throws ReflectiveOperationException {
+
Plugins.getMemberAccessor().set(MySQLNegotiateHandler.class.getDeclaredField("publicKeyRequested"),
mysqlNegotiateHandler, true);
+
Plugins.getMemberAccessor().set(MySQLNegotiateHandler.class.getDeclaredField("serverVersion"),
mysqlNegotiateHandler, new MySQLServerVersion(serverVersion));
+
Plugins.getMemberAccessor().set(MySQLNegotiateHandler.class.getDeclaredField("seed"),
mysqlNegotiateHandler, authenticationPluginSeed());
+ mysqlNegotiateHandler.channelRead(channelHandlerContext, new
MySQLAuthMoreDataPacket(publicKey().getBytes()));
+ MySQLAuthSwitchResponsePacket response = captureAuthSwitchResponse();
+ assertTrue(response.getAuthPluginResponse().length > 0);
+ }
+
+ private String publicKey() {
+ return "-----BEGIN PUBLIC KEY-----\n"
+ +
"MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAL8eq+i+GtqR4344d18PT9bjK5YfX/8r\n"
+ +
"O8uRAZ3kKQEiC5EvhczHxVn9Yx8RJWb1x1oGf4bm/FYnGV8eK3opg+cCAwEAAQ==\n"
+ + "-----END PUBLIC KEY-----";
+ }
+
+ private MySQLAuthSwitchResponsePacket captureAuthSwitchResponse() {
+ ArgumentCaptor<MySQLAuthSwitchResponsePacket> responseCaptor =
ArgumentCaptor.forClass(MySQLAuthSwitchResponsePacket.class);
+ verify(channel).writeAndFlush(responseCaptor.capture());
+ return responseCaptor.getValue();
+ }
+
+ private static Stream<Arguments> handshakeParams() {
+ return Stream.of(Arguments.of(PASSWORD, false), Arguments.of("",
true));
+ }
+
+ @SneakyThrows(NoSuchAlgorithmException.class)
+ private static Stream<Arguments> authSwitchParams() {
+ return Stream.of(
+ Arguments.of(MySQLAuthenticationPlugin.NATIVE.getPluginName(),
PasswordEncryption.encryptWithMySQL41(PASSWORD.getBytes(),
authenticationPluginSeed())),
+
Arguments.of(MySQLAuthenticationPlugin.CACHING_SHA2.getPluginName(),
PasswordEncryption.encryptWithSha2(PASSWORD.getBytes(),
authenticationPluginSeed())),
+ Arguments.of(MySQLAuthenticationPlugin.SHA256.getPluginName(),
PASSWORD.getBytes()));
+ }
+
+ private static byte[] authenticationPluginSeed() {
+ byte[] result = new byte[seedBytesPart1().length +
seedBytesPart2().length];
+ System.arraycopy(seedBytesPart1(), 0, result, 0,
seedBytesPart1().length);
+ System.arraycopy(seedBytesPart2(), 0, result, seedBytesPart1().length,
seedBytesPart2().length);
+ return result;
+ }
+
+ private static byte[] seedBytesPart1() {
+ return new byte[]{0, 1, 2, 3, 4, 5, 6, 7};
+ }
+
+ private static byte[] seedBytesPart2() {
+ return new byte[]{8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19};
+ }
+
+ private static Stream<Arguments> publicKeyEncryptParams() {
+ return Stream.of(Arguments.of("8.0.5"), Arguments.of("5.7.22"));
}
}