This is an automated email from the ASF dual-hosted git repository.
tuichenchuxin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 5e715611f14 Fix gsql 3.0 may be stuck when connecting Proxy (#21869)
5e715611f14 is described below
commit 5e715611f140db6fe298d7c2d4ba450c60ab3fa6
Author: 吴伟杰 <[email protected]>
AuthorDate: Tue Nov 1 10:10:08 2022 +0800
Fix gsql 3.0 may be stuck when connecting Proxy (#21869)
* Fix gsql 3.0 may be stuck when connecting Proxy
* Complete OpenGaussAuthenticationSCRAMSha256PacketTest
* Refactor OpenGaussAuthenticationHandler
* Complete OpenGaussAuthenticationEngineTest
* Generate server signature even if user unknown
---
.../constant/OpenGaussProtocolVersion.java | 35 +++++++++++++++++++++
.../OpenGaussAuthenticationSCRAMSha256Packet.java | 12 +++++++-
...enGaussAuthenticationSCRAMSha256PacketTest.java | 36 +++++++++++++++++++---
.../handshake/PostgreSQLComStartupPacket.java | 7 ++++-
.../OpenGaussAuthenticationEngine.java | 30 +++++++++++++++---
.../OpenGaussAuthenticationHandler.java | 31 +++++++++++++++++++
.../OpenGaussAuthenticationEngineTest.java | 8 ++---
7 files changed, 144 insertions(+), 15 deletions(-)
diff --git
a/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/constant/OpenGaussProtocolVersion.java
b/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/constant/OpenGaussProtocolVersion.java
new file mode 100644
index 00000000000..a09f1fcfd73
--- /dev/null
+++
b/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/constant/OpenGaussProtocolVersion.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.db.protocol.opengauss.constant;
+
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+
+/**
+ * Protocol version of openGauss.
+ *
+ * <p>Each constant stores its version as a single int: the value 3 shifted
+ * into the upper 16 bits, OR-ed with 50 or 51 in the lower 16 bits
+ * ({@code <<} binds tighter than {@code |}, so no parentheses are needed).
+ * NOTE(review): presumably this mirrors the major/minor layout of the version
+ * field read from the startup packet; confirm against the openGauss protocol.</p>
+ */
+@RequiredArgsConstructor
+@Getter
+public enum OpenGaussProtocolVersion {
+
+ PROTOCOL_350(3 << 16 | 50), // 196608 + 50 = 196658
+
+ PROTOCOL_351(3 << 16 | 51); // 196608 + 51 = 196659
+
+ private final int version; // exposed through the Lombok-generated getVersion()
+}
diff --git
a/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256Packet.java
b/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256Packet.java
index 19c1b1649f2..6dcb4e76a15 100644
---
a/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256Packet.java
+++
b/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256Packet.java
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.db.protocol.opengauss.packet.authentication;
import lombok.RequiredArgsConstructor;
+import
org.apache.shardingsphere.db.protocol.opengauss.constant.OpenGaussProtocolVersion;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierTag;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLMessagePacketType;
@@ -33,10 +34,14 @@ public final class OpenGaussAuthenticationSCRAMSha256Packet
implements PostgreSQ
private static final int PASSWORD_STORED_METHOD_SHA256 = 2;
+ private final int version;
+
private final byte[] random64Code;
private final byte[] token;
+ private final byte[] serverSignature;
+
private final int serverIteration;
@Override
@@ -45,7 +50,12 @@ public final class OpenGaussAuthenticationSCRAMSha256Packet
implements PostgreSQ
payload.writeInt4(PASSWORD_STORED_METHOD_SHA256);
payload.writeBytes(random64Code);
payload.writeBytes(token);
- payload.writeInt4(serverIteration);
+ if (version < OpenGaussProtocolVersion.PROTOCOL_350.getVersion()) {
+ payload.writeBytes(serverSignature);
+ }
+ if (OpenGaussProtocolVersion.PROTOCOL_351.getVersion() == version) {
+ payload.writeInt4(serverIteration);
+ }
}
@Override
diff --git
a/db-protocol/opengauss/src/test/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256PacketTest.java
b/db-protocol/opengauss/src/test/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256PacketTest.java
index 79bc71dfa74..a24a36f7301 100644
---
a/db-protocol/opengauss/src/test/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256PacketTest.java
+++
b/db-protocol/opengauss/src/test/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256PacketTest.java
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.db.protocol.opengauss.packet.authentication;
+import
org.apache.shardingsphere.db.protocol.opengauss.constant.OpenGaussProtocolVersion;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLMessagePacketType;
import
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
import org.junit.Test;
@@ -32,23 +33,50 @@ public final class
OpenGaussAuthenticationSCRAMSha256PacketTest {
private static final byte[] TOKEN = new byte[8];
- private static final int SERVER_ITERATION = 2048;
+ private static final byte[] SERVER_SIGNATURE = new byte[64];
- private final OpenGaussAuthenticationSCRAMSha256Packet packet = new
OpenGaussAuthenticationSCRAMSha256Packet(RANDOM_64_CODE, TOKEN,
SERVER_ITERATION);
+ @Test
+ public void assertWriteProtocol300Packet() {
+ PostgreSQLPacketPayload payload = mock(PostgreSQLPacketPayload.class);
+ // The packet writes the server signature only when the version is strictly below PROTOCOL_350,
+ // so this test must use a pre-350 version; PROTOCOL_350 itself would never write SERVER_SIGNATURE.
+ OpenGaussAuthenticationSCRAMSha256Packet packet = new
OpenGaussAuthenticationSCRAMSha256Packet(
+ OpenGaussProtocolVersion.PROTOCOL_350.getVersion() - 1,
RANDOM_64_CODE, TOKEN, SERVER_SIGNATURE, 2048);
+ packet.write(payload);
+ verify(payload).writeInt4(10);
+ verify(payload).writeInt4(2);
+ verify(payload).writeBytes(RANDOM_64_CODE);
+ verify(payload).writeBytes(TOKEN);
+ // Legacy clients expect the server signature right after the token.
+ verify(payload).writeBytes(SERVER_SIGNATURE);
+ }
+
+ @Test
+ public void assertWriteProtocol350Packet() {
+ PostgreSQLPacketPayload payload = mock(PostgreSQLPacketPayload.class);
+ OpenGaussAuthenticationSCRAMSha256Packet packet = new
OpenGaussAuthenticationSCRAMSha256Packet(
+ OpenGaussProtocolVersion.PROTOCOL_350.getVersion(),
RANDOM_64_CODE, TOKEN, SERVER_SIGNATURE, 2048);
+ packet.write(payload);
+ verify(payload).writeInt4(10);
+ verify(payload).writeInt4(2);
+ verify(payload).writeBytes(RANDOM_64_CODE);
+ verify(payload).writeBytes(TOKEN); // at exactly PROTOCOL_350 the packet writes neither the server signature (version < 350 only) nor the iteration (351 only), so nothing further is checked
+ }
@Test
- public void assertWrite() {
+ public void assertWriteProtocol351Packet() {
PostgreSQLPacketPayload payload = mock(PostgreSQLPacketPayload.class);
+ OpenGaussAuthenticationSCRAMSha256Packet packet = new
OpenGaussAuthenticationSCRAMSha256Packet(
+ OpenGaussProtocolVersion.PROTOCOL_351.getVersion(),
RANDOM_64_CODE, TOKEN, SERVER_SIGNATURE, 10000);
packet.write(payload);
verify(payload).writeInt4(10);
verify(payload).writeInt4(2);
verify(payload).writeBytes(RANDOM_64_CODE);
verify(payload).writeBytes(TOKEN);
- verify(payload).writeInt4(SERVER_ITERATION);
+ verify(payload).writeInt4(10000); // the iteration count is written only when the version equals PROTOCOL_351
}
@Test
public void assertIdentifierTag() {
+ OpenGaussAuthenticationSCRAMSha256Packet packet = new
OpenGaussAuthenticationSCRAMSha256Packet(
+ OpenGaussProtocolVersion.PROTOCOL_351.getVersion(),
RANDOM_64_CODE, TOKEN, SERVER_SIGNATURE, 10000);
assertThat(packet.getIdentifier(),
is(PostgreSQLMessagePacketType.AUTHENTICATION_REQUEST));
}
}
diff --git
a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
index 7f5ac5c9fe1..bd9986c583c 100644
---
a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
+++
b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
@@ -17,6 +17,7 @@
package org.apache.shardingsphere.db.protocol.postgresql.packet.handshake;
+import lombok.Getter;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
import
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
@@ -34,10 +35,14 @@ public final class PostgreSQLComStartupPacket implements
PostgreSQLPacket {
private static final String CLIENT_ENCODING_KEY = "client_encoding";
+ @Getter
+ private final int version;
+
private final Map<String, String> parametersMap = new HashMap<>();
public PostgreSQLComStartupPacket(final PostgreSQLPacketPayload payload) {
- payload.skipReserved(8);
+ payload.skipReserved(4);
+ version = payload.readInt4();
while (payload.bytesBeforeZero() > 0) {
parametersMap.put(payload.readStringNul(),
payload.readStringNul());
}
diff --git
a/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
b/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
index c7c44590875..a7b70dce8cb 100644
---
a/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
+++
b/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
@@ -19,7 +19,9 @@ package
org.apache.shardingsphere.proxy.frontend.opengauss.authentication;
import com.google.common.base.Strings;
import io.netty.channel.ChannelHandlerContext;
+import org.apache.shardingsphere.authority.rule.AuthorityRule;
import org.apache.shardingsphere.db.protocol.CommonConstants;
+import
org.apache.shardingsphere.db.protocol.opengauss.constant.OpenGaussProtocolVersion;
import
org.apache.shardingsphere.db.protocol.opengauss.packet.authentication.OpenGaussAuthenticationSCRAMSha256Packet;
import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
import
org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLServerInfo;
@@ -33,6 +35,9 @@ import
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.Postgr
import
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
import
org.apache.shardingsphere.dialect.postgresql.exception.authority.EmptyUsernameException;
import
org.apache.shardingsphere.dialect.postgresql.exception.protocol.ProtocolViolationException;
+import org.apache.shardingsphere.infra.metadata.user.Grantee;
+import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
+import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import
org.apache.shardingsphere.proxy.backend.handler.admin.postgresql.PostgreSQLCharacterSets;
import
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationEngine;
import
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResult;
@@ -50,22 +55,25 @@ public final class OpenGaussAuthenticationEngine implements
AuthenticationEngine
private static final int SSL_REQUEST_CODE = 80877103;
- private boolean startupMessageReceived;
+ private static final int PROTOCOL_351_SERVER_ITERATOR = 10000;
- private String clientEncoding;
+ private static final int PROTOCOL_350_SERVER_ITERATOR = 2048;
private final String saltHexString;
private final String nonceHexString;
- private final int serverIteration;
+ private boolean startupMessageReceived;
+
+ private String clientEncoding;
+
+ private int serverIteration;
private AuthenticationResult currentAuthResult;
public OpenGaussAuthenticationEngine() {
saltHexString = generateRandomHexString(64);
nonceHexString = generateRandomHexString(8);
- serverIteration = 10000;
}
private String generateRandomHexString(final int length) {
@@ -101,11 +109,23 @@ public final class OpenGaussAuthenticationEngine
implements AuthenticationEngine
if (Strings.isNullOrEmpty(user)) {
throw new EmptyUsernameException();
}
- context.writeAndFlush(new
OpenGaussAuthenticationSCRAMSha256Packet(saltHexString.getBytes(),
nonceHexString.getBytes(), serverIteration));
+ serverIteration = comStartupPacket.getVersion() ==
OpenGaussProtocolVersion.PROTOCOL_351.getVersion() ?
PROTOCOL_351_SERVER_ITERATOR : PROTOCOL_350_SERVER_ITERATOR;
+ String serverSignature =
calculateServerSignature(comStartupPacket.getVersion(), user);
+ context.writeAndFlush(new OpenGaussAuthenticationSCRAMSha256Packet(
+ comStartupPacket.getVersion(), saltHexString.getBytes(),
nonceHexString.getBytes(), serverSignature.getBytes(), serverIteration));
currentAuthResult = AuthenticationResultBuilder.continued(user, "",
comStartupPacket.getDatabase());
return currentAuthResult;
}
+ private String calculateServerSignature(final int version, final String
username) { // returns the signature text placed in the SCRAM SHA-256 authentication request, or "" when none is needed
+ if (version >= OpenGaussProtocolVersion.PROTOCOL_350.getVersion()) {
+ return ""; // the authentication packet writes the signature only for versions below PROTOCOL_350, so skip the computation
+ }
+ String password =
ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().findSingleRule(AuthorityRule.class)
+ .flatMap(authorityRule -> authorityRule.findUser(new
Grantee(username, "%"))).map(ShardingSphereUser::getPassword).orElse(""); // no authority rule or unknown user: fall back to an empty password so a signature is still produced
+ return
OpenGaussAuthenticationHandler.calculateServerSignature(password,
saltHexString, nonceHexString, serverIteration); // uses this engine's salt, nonce and the iteration count chosen for the client version
+ }
+
private AuthenticationResult processPasswordMessage(final
ChannelHandlerContext context, final PostgreSQLPacketPayload payload) {
char messageType = (char) payload.readInt1();
if (PostgreSQLMessagePacketType.PASSWORD_MESSAGE.getValue() !=
messageType) {
diff --git
a/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationHandler.java
b/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationHandler.java
index cf51f096436..a066dc45048 100644
---
a/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationHandler.java
+++
b/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationHandler.java
@@ -37,6 +37,7 @@ import javax.crypto.Mac;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
+import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
@@ -60,8 +61,38 @@ public final class OpenGaussAuthenticationHandler {
private static final String SHA256_ALGORITHM = "SHA-256";
+ private static final String SERVER_KEY = "Server Key";
+
private static final String CLIENT_KEY = "Client Key";
+ /**
+ * Calculate server signature.
+ *
+ * <p>Steps as written in the body: derive {@code k} from the password, salt
+ * and iteration count via {@code generateKFromPBKDF2}; pass {@code k} and the
+ * UTF-8 bytes of "Server Key" to {@code getKeyFromHmac}; pass that result and
+ * the bytes decoded from the hex nonce to {@code getKeyFromHmac} again; then
+ * hex-encode the output with {@code bytesToHexString}.
+ * NOTE(review): the helper names suggest PBKDF2 and HMAC, but their bodies
+ * are outside this hunk; confirm the exact algorithms there.</p>
+ *
+ * @param password password
+ * @param salt salt in hex string
+ * @param nonce nonce in hex string
+ * @param serverIteration server iteration
+ * @return server signature
+ */
+ public static String calculateServerSignature(final String password, final
String salt, final String nonce, final int serverIteration) {
+ byte[] k = generateKFromPBKDF2(password, salt, serverIteration);
+ byte[] serverKey = getKeyFromHmac(k,
SERVER_KEY.getBytes(StandardCharsets.UTF_8));
+ byte[] result = getKeyFromHmac(serverKey, hexStringToBytes(nonce));
+ return bytesToHexString(result); // lowercase hex, two characters per byte
+ }
+
+ private static String bytesToHexString(final byte[] src) { // encodes each byte as exactly two lowercase hex digits
+ StringBuilder result = new StringBuilder();
+ for (byte each : src) {
+ String hex = Integer.toHexString(each & 255); // mask to 0..255 so negative bytes are not sign-extended to eight hex digits
+ if (hex.length() < 2) {
+ result.append(0); // left-pad single-digit values with '0'
+ }
+ result.append(hex);
+ }
+ return result.toString();
+ }
+
/**
* Login with SCRAM SHA-256 password.
*
diff --git
a/proxy/frontend/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
b/proxy/frontend/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
index 48b9c6aea43..774270a8278 100644
---
a/proxy/frontend/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
+++
b/proxy/frontend/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
@@ -133,6 +133,10 @@ public final class OpenGaussAuthenticationEngineTest
extends ProxyContextRestore
}
private void assertLogin(final String inputPassword) {
+ MetaDataContexts metaDataContexts = getMetaDataContexts(new
ShardingSphereUser(username, password, ""));
+ ContextManager contextManager = mock(ContextManager.class,
RETURNS_DEEP_STUBS);
+
when(contextManager.getMetaDataContexts()).thenReturn(metaDataContexts);
+ ProxyContext.init(contextManager);
PostgreSQLPacketPayload payload = new
PostgreSQLPacketPayload(createByteBuf(16, 128), StandardCharsets.UTF_8);
payload.writeInt4(64);
payload.writeInt4(196608);
@@ -155,10 +159,6 @@ public final class OpenGaussAuthenticationEngineTest
extends ProxyContextRestore
payload.writeInt1('p');
payload.writeInt4(4 + clientDigest.length() + 1);
payload.writeStringNul(clientDigest);
- MetaDataContexts metaDataContexts = getMetaDataContexts(new
ShardingSphereUser(username, password, ""));
- ContextManager contextManager = mock(ContextManager.class,
RETURNS_DEEP_STUBS);
-
when(contextManager.getMetaDataContexts()).thenReturn(metaDataContexts);
- ProxyContext.init(contextManager);
actual = engine.authenticate(channelHandlerContext, payload);
assertThat(actual.isFinished(), is(password.equals(inputPassword)));
}