This is an automated email from the ASF dual-hosted git repository.

tuichenchuxin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e715611f14 Fix gsql 3.0 may be stuck when connecting Proxy (#21869)
5e715611f14 is described below

commit 5e715611f140db6fe298d7c2d4ba450c60ab3fa6
Author: 吴伟杰 <[email protected]>
AuthorDate: Tue Nov 1 10:10:08 2022 +0800

    Fix gsql 3.0 may be stuck when connecting Proxy (#21869)
    
    * Fix gsql 3.0 may be stuck when connecting Proxy
    
    * Complete OpenGaussAuthenticationSCRAMSha256PacketTest
    
    * Refactor OpenGaussAuthenticationHandler
    
    * Complete OpenGaussAuthenticationEngineTest
    
    * Generate server signature even if user unknown
---
 .../constant/OpenGaussProtocolVersion.java         | 35 +++++++++++++++++++++
 .../OpenGaussAuthenticationSCRAMSha256Packet.java  | 12 +++++++-
 ...enGaussAuthenticationSCRAMSha256PacketTest.java | 36 +++++++++++++++++++---
 .../handshake/PostgreSQLComStartupPacket.java      |  7 ++++-
 .../OpenGaussAuthenticationEngine.java             | 30 +++++++++++++++---
 .../OpenGaussAuthenticationHandler.java            | 31 +++++++++++++++++++
 .../OpenGaussAuthenticationEngineTest.java         |  8 ++---
 7 files changed, 144 insertions(+), 15 deletions(-)

diff --git 
a/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/constant/OpenGaussProtocolVersion.java
 
b/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/constant/OpenGaussProtocolVersion.java
new file mode 100644
index 00000000000..a09f1fcfd73
--- /dev/null
+++ 
b/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/constant/OpenGaussProtocolVersion.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.db.protocol.opengauss.constant;
+
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+
+/**
+ * Protocol version of openGauss.
+ */
+@RequiredArgsConstructor
+@Getter
+public enum OpenGaussProtocolVersion {
+    
+    PROTOCOL_350(3 << 16 | 50),
+    
+    PROTOCOL_351(3 << 16 | 51);
+    
+    private final int version;
+}
diff --git 
a/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256Packet.java
 
b/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256Packet.java
index 19c1b1649f2..6dcb4e76a15 100644
--- 
a/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256Packet.java
+++ 
b/db-protocol/opengauss/src/main/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256Packet.java
@@ -18,6 +18,7 @@
 package org.apache.shardingsphere.db.protocol.opengauss.packet.authentication;
 
 import lombok.RequiredArgsConstructor;
+import 
org.apache.shardingsphere.db.protocol.opengauss.constant.OpenGaussProtocolVersion;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierPacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLIdentifierTag;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLMessagePacketType;
@@ -33,10 +34,14 @@ public final class OpenGaussAuthenticationSCRAMSha256Packet 
implements PostgreSQ
     
     private static final int PASSWORD_STORED_METHOD_SHA256 = 2;
     
+    private final int version;
+    
     private final byte[] random64Code;
     
     private final byte[] token;
     
+    private final byte[] serverSignature;
+    
     private final int serverIteration;
     
     @Override
@@ -45,7 +50,12 @@ public final class OpenGaussAuthenticationSCRAMSha256Packet 
implements PostgreSQ
         payload.writeInt4(PASSWORD_STORED_METHOD_SHA256);
         payload.writeBytes(random64Code);
         payload.writeBytes(token);
-        payload.writeInt4(serverIteration);
+        if (version < OpenGaussProtocolVersion.PROTOCOL_350.getVersion()) {
+            payload.writeBytes(serverSignature);
+        }
+        if (OpenGaussProtocolVersion.PROTOCOL_351.getVersion() == version) {
+            payload.writeInt4(serverIteration);
+        }
     }
     
     @Override
diff --git 
a/db-protocol/opengauss/src/test/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256PacketTest.java
 
b/db-protocol/opengauss/src/test/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256PacketTest.java
index 79bc71dfa74..a24a36f7301 100644
--- 
a/db-protocol/opengauss/src/test/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256PacketTest.java
+++ 
b/db-protocol/opengauss/src/test/java/org/apache/shardingsphere/db/protocol/opengauss/packet/authentication/OpenGaussAuthenticationSCRAMSha256PacketTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.shardingsphere.db.protocol.opengauss.packet.authentication;
 
+import 
org.apache.shardingsphere.db.protocol.opengauss.constant.OpenGaussProtocolVersion;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.PostgreSQLMessagePacketType;
 import 
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
 import org.junit.Test;
@@ -32,23 +33,50 @@ public final class 
OpenGaussAuthenticationSCRAMSha256PacketTest {
     
     private static final byte[] TOKEN = new byte[8];
     
-    private static final int SERVER_ITERATION = 2048;
+    private static final byte[] SERVER_SIGNATURE = new byte[64];
     
-    private final OpenGaussAuthenticationSCRAMSha256Packet packet = new 
OpenGaussAuthenticationSCRAMSha256Packet(RANDOM_64_CODE, TOKEN, 
SERVER_ITERATION);
+    @Test
+    public void assertWriteProtocol300Packet() {
+        PostgreSQLPacketPayload payload = mock(PostgreSQLPacketPayload.class);
+        OpenGaussAuthenticationSCRAMSha256Packet packet = new 
OpenGaussAuthenticationSCRAMSha256Packet(
+                OpenGaussProtocolVersion.PROTOCOL_350.getVersion() - 1, 
RANDOM_64_CODE, TOKEN, SERVER_SIGNATURE, 2048);
+        packet.write(payload);
+        verify(payload).writeInt4(10);
+        verify(payload).writeInt4(2);
+        verify(payload).writeBytes(RANDOM_64_CODE);
+        verify(payload).writeBytes(TOKEN);
+        verify(payload).writeBytes(SERVER_SIGNATURE);
+    }
+    
+    @Test
+    public void assertWriteProtocol350Packet() {
+        PostgreSQLPacketPayload payload = mock(PostgreSQLPacketPayload.class);
+        OpenGaussAuthenticationSCRAMSha256Packet packet = new 
OpenGaussAuthenticationSCRAMSha256Packet(
+                OpenGaussProtocolVersion.PROTOCOL_350.getVersion(), 
RANDOM_64_CODE, TOKEN, SERVER_SIGNATURE, 2048);
+        packet.write(payload);
+        verify(payload).writeInt4(10);
+        verify(payload).writeInt4(2);
+        verify(payload).writeBytes(RANDOM_64_CODE);
+        verify(payload).writeBytes(TOKEN);
+    }
     
     @Test
-    public void assertWrite() {
+    public void assertWriteProtocol351Packet() {
         PostgreSQLPacketPayload payload = mock(PostgreSQLPacketPayload.class);
+        OpenGaussAuthenticationSCRAMSha256Packet packet = new 
OpenGaussAuthenticationSCRAMSha256Packet(
+                OpenGaussProtocolVersion.PROTOCOL_351.getVersion(), 
RANDOM_64_CODE, TOKEN, SERVER_SIGNATURE, 10000);
         packet.write(payload);
         verify(payload).writeInt4(10);
         verify(payload).writeInt4(2);
         verify(payload).writeBytes(RANDOM_64_CODE);
         verify(payload).writeBytes(TOKEN);
-        verify(payload).writeInt4(SERVER_ITERATION);
+        verify(payload).writeInt4(10000);
     }
     
     @Test
     public void assertIdentifierTag() {
+        OpenGaussAuthenticationSCRAMSha256Packet packet = new 
OpenGaussAuthenticationSCRAMSha256Packet(
+                OpenGaussProtocolVersion.PROTOCOL_351.getVersion(), 
RANDOM_64_CODE, TOKEN, SERVER_SIGNATURE, 10000);
         assertThat(packet.getIdentifier(), 
is(PostgreSQLMessagePacketType.AUTHENTICATION_REQUEST));
     }
 }
diff --git 
a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
 
b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
index 7f5ac5c9fe1..bd9986c583c 100644
--- 
a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
+++ 
b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/handshake/PostgreSQLComStartupPacket.java
@@ -17,6 +17,7 @@
 
 package org.apache.shardingsphere.db.protocol.postgresql.packet.handshake;
 
+import lombok.Getter;
 import 
org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
 import 
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
 
@@ -34,10 +35,14 @@ public final class PostgreSQLComStartupPacket implements 
PostgreSQLPacket {
     
     private static final String CLIENT_ENCODING_KEY = "client_encoding";
     
+    @Getter
+    private final int version;
+    
     private final Map<String, String> parametersMap = new HashMap<>();
     
     public PostgreSQLComStartupPacket(final PostgreSQLPacketPayload payload) {
-        payload.skipReserved(8);
+        payload.skipReserved(4);
+        version = payload.readInt4();
         while (payload.bytesBeforeZero() > 0) {
             parametersMap.put(payload.readStringNul(), 
payload.readStringNul());
         }
diff --git 
a/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
 
b/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
index c7c44590875..a7b70dce8cb 100644
--- 
a/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
+++ 
b/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngine.java
@@ -19,7 +19,9 @@ package 
org.apache.shardingsphere.proxy.frontend.opengauss.authentication;
 
 import com.google.common.base.Strings;
 import io.netty.channel.ChannelHandlerContext;
+import org.apache.shardingsphere.authority.rule.AuthorityRule;
 import org.apache.shardingsphere.db.protocol.CommonConstants;
+import 
org.apache.shardingsphere.db.protocol.opengauss.constant.OpenGaussProtocolVersion;
 import 
org.apache.shardingsphere.db.protocol.opengauss.packet.authentication.OpenGaussAuthenticationSCRAMSha256Packet;
 import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
 import 
org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLServerInfo;
@@ -33,6 +35,9 @@ import 
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.Postgr
 import 
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
 import 
org.apache.shardingsphere.dialect.postgresql.exception.authority.EmptyUsernameException;
 import 
org.apache.shardingsphere.dialect.postgresql.exception.protocol.ProtocolViolationException;
+import org.apache.shardingsphere.infra.metadata.user.Grantee;
+import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
+import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
 import 
org.apache.shardingsphere.proxy.backend.handler.admin.postgresql.PostgreSQLCharacterSets;
 import 
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationEngine;
 import 
org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResult;
@@ -50,22 +55,25 @@ public final class OpenGaussAuthenticationEngine implements 
AuthenticationEngine
     
     private static final int SSL_REQUEST_CODE = 80877103;
     
-    private boolean startupMessageReceived;
+    private static final int PROTOCOL_351_SERVER_ITERATOR = 10000;
     
-    private String clientEncoding;
+    private static final int PROTOCOL_350_SERVER_ITERATOR = 2048;
     
     private final String saltHexString;
     
     private final String nonceHexString;
     
-    private final int serverIteration;
+    private boolean startupMessageReceived;
+    
+    private String clientEncoding;
+    
+    private int serverIteration;
     
     private AuthenticationResult currentAuthResult;
     
     public OpenGaussAuthenticationEngine() {
         saltHexString = generateRandomHexString(64);
         nonceHexString = generateRandomHexString(8);
-        serverIteration = 10000;
     }
     
     private String generateRandomHexString(final int length) {
@@ -101,11 +109,23 @@ public final class OpenGaussAuthenticationEngine 
implements AuthenticationEngine
         if (Strings.isNullOrEmpty(user)) {
             throw new EmptyUsernameException();
         }
-        context.writeAndFlush(new 
OpenGaussAuthenticationSCRAMSha256Packet(saltHexString.getBytes(), 
nonceHexString.getBytes(), serverIteration));
+        serverIteration = comStartupPacket.getVersion() == 
OpenGaussProtocolVersion.PROTOCOL_351.getVersion() ? 
PROTOCOL_351_SERVER_ITERATOR : PROTOCOL_350_SERVER_ITERATOR;
+        String serverSignature = 
calculateServerSignature(comStartupPacket.getVersion(), user);
+        context.writeAndFlush(new OpenGaussAuthenticationSCRAMSha256Packet(
+                comStartupPacket.getVersion(), saltHexString.getBytes(), 
nonceHexString.getBytes(), serverSignature.getBytes(), serverIteration));
         currentAuthResult = AuthenticationResultBuilder.continued(user, "", 
comStartupPacket.getDatabase());
         return currentAuthResult;
     }
     
+    private String calculateServerSignature(final int version, final String 
username) {
+        if (version >= OpenGaussProtocolVersion.PROTOCOL_350.getVersion()) {
+            return "";
+        }
+        String password = 
ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().findSingleRule(AuthorityRule.class)
+                .flatMap(authorityRule -> authorityRule.findUser(new 
Grantee(username, "%"))).map(ShardingSphereUser::getPassword).orElse("");
+        return 
OpenGaussAuthenticationHandler.calculateServerSignature(password, 
saltHexString, nonceHexString, serverIteration);
+    }
+    
     private AuthenticationResult processPasswordMessage(final 
ChannelHandlerContext context, final PostgreSQLPacketPayload payload) {
         char messageType = (char) payload.readInt1();
         if (PostgreSQLMessagePacketType.PASSWORD_MESSAGE.getValue() != 
messageType) {
diff --git 
a/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationHandler.java
 
b/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationHandler.java
index cf51f096436..a066dc45048 100644
--- 
a/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationHandler.java
+++ 
b/proxy/frontend/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationHandler.java
@@ -37,6 +37,7 @@ import javax.crypto.Mac;
 import javax.crypto.SecretKeyFactory;
 import javax.crypto.spec.PBEKeySpec;
 import javax.crypto.spec.SecretKeySpec;
+import java.nio.charset.StandardCharsets;
 import java.security.InvalidKeyException;
 import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
@@ -60,8 +61,38 @@ public final class OpenGaussAuthenticationHandler {
     
     private static final String SHA256_ALGORITHM = "SHA-256";
     
+    private static final String SERVER_KEY = "Server Key";
+    
     private static final String CLIENT_KEY = "Client Key";
     
+    /**
+     * Calculate server signature.
+     *
+     * @param password password
+     * @param salt salt in hex string
+     * @param nonce nonce in hex string
+     * @param serverIteration server iteration
+     * @return server signature
+     */
+    public static String calculateServerSignature(final String password, final 
String salt, final String nonce, final int serverIteration) {
+        byte[] k = generateKFromPBKDF2(password, salt, serverIteration);
+        byte[] serverKey = getKeyFromHmac(k, 
SERVER_KEY.getBytes(StandardCharsets.UTF_8));
+        byte[] result = getKeyFromHmac(serverKey, hexStringToBytes(nonce));
+        return bytesToHexString(result);
+    }
+    
+    private static String bytesToHexString(final byte[] src) {
+        StringBuilder result = new StringBuilder();
+        for (byte each : src) {
+            String hex = Integer.toHexString(each & 255);
+            if (hex.length() < 2) {
+                result.append(0);
+            }
+            result.append(hex);
+        }
+        return result.toString();
+    }
+    
     /**
      * Login with SCRAM SHA-256 password.
      *
diff --git 
a/proxy/frontend/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
 
b/proxy/frontend/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
index 48b9c6aea43..774270a8278 100644
--- 
a/proxy/frontend/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
+++ 
b/proxy/frontend/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/authentication/OpenGaussAuthenticationEngineTest.java
@@ -133,6 +133,10 @@ public final class OpenGaussAuthenticationEngineTest 
extends ProxyContextRestore
     }
     
     private void assertLogin(final String inputPassword) {
+        MetaDataContexts metaDataContexts = getMetaDataContexts(new 
ShardingSphereUser(username, password, ""));
+        ContextManager contextManager = mock(ContextManager.class, 
RETURNS_DEEP_STUBS);
+        
when(contextManager.getMetaDataContexts()).thenReturn(metaDataContexts);
+        ProxyContext.init(contextManager);
         PostgreSQLPacketPayload payload = new 
PostgreSQLPacketPayload(createByteBuf(16, 128), StandardCharsets.UTF_8);
         payload.writeInt4(64);
         payload.writeInt4(196608);
@@ -155,10 +159,6 @@ public final class OpenGaussAuthenticationEngineTest 
extends ProxyContextRestore
         payload.writeInt1('p');
         payload.writeInt4(4 + clientDigest.length() + 1);
         payload.writeStringNul(clientDigest);
-        MetaDataContexts metaDataContexts = getMetaDataContexts(new 
ShardingSphereUser(username, password, ""));
-        ContextManager contextManager = mock(ContextManager.class, 
RETURNS_DEEP_STUBS);
-        
when(contextManager.getMetaDataContexts()).thenReturn(metaDataContexts);
-        ProxyContext.init(contextManager);
         actual = engine.authenticate(channelHandlerContext, payload);
         assertThat(actual.isFinished(), is(password.equals(inputPassword)));
     }

Reply via email to