This is an automated email from the ASF dual-hosted git repository.

twolf pushed a commit to branch dev_3.0
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit 3bbb38f3150fecd01ba258bfc17c0776a9ccc077
Author: Thomas Wolf <tw...@apache.org>
AuthorDate: Thu Apr 3 22:37:31 2025 +0200

    Limit buffering during KEX
    
    Introduce CoreModuleProperties.MAX_MSGS_BEFORE_KEX_INIT: the maximum
    number of incoming messages requiring a response that we accept after
    having sent our own KEXINIT and before the peer's KEXINIT arrives. The
    limit concerns only global requests and channel requests with
    want-reply = true, and channel open requests. If we don't receive the
    peer's KEXINIT within that number of such messages, we disconnect.
    
    This is a safeguard against OOM attacks: if the peer never sends its
    KEXINIT but only messages that require us to buffer replies (because
    we're already in KEX), we could run out of memory.
    
    Note that SSH_MSG_CHANNEL_DATA is not affected; we don't send back
    window adjustments during KEX, and we close all RemoteWindows during
    KEX to prevent data pumping threads from overrunning our buffer capacity.
---
 .../common/session/filters/DelayKexInitFilter.java |   2 +-
 .../common/session/filters/InjectIgnoreFilter.java |  22 +-
 .../sshd/common/session/filters/kex/KexFilter.java | 228 +++++++++++++--------
 .../session/filters/kex/KexOutputHandler.java      |  34 +--
 .../session/filters/kex/MessageCodingSettings.java |  10 +-
 .../org/apache/sshd/core/CoreModuleProperties.java |  22 ++
 6 files changed, 196 insertions(+), 122 deletions(-)

diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/DelayKexInitFilter.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/DelayKexInitFilter.java
index ae3e4c69c..b63e028ff 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/DelayKexInitFilter.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/DelayKexInitFilter.java
@@ -97,7 +97,7 @@ public class DelayKexInitFilter extends IoFilter {
 
         @Override
         public IoWriteFuture send(int cmd, Buffer message) throws IOException {
-            if (cmd != SshConstants.SSH_MSG_KEXINIT || output.get() == null) {
+            if (cmd != SshConstants.SSH_MSG_KEXINIT || output.get() == null || 
message == null) {
                 return owner().send(cmd, message);
             }
             boolean first = isFirst.getAndSet(false);
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/InjectIgnoreFilter.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/InjectIgnoreFilter.java
index 2e3e39b4a..9337b3cd3 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/InjectIgnoreFilter.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/InjectIgnoreFilter.java
@@ -78,17 +78,19 @@ public class InjectIgnoreFilter extends IoFilter {
 
         @Override
         public synchronized IoWriteFuture send(int cmd, Buffer message) throws 
IOException {
-            int length = shouldSendIgnore(cmd);
-            if (length > 0) {
-                if (LOG.isDebugEnabled()) {
-                    LOG.debug("Injector.send({}) injecting SSH_MSG_IGNORE", 
resolver);
-                }
-                owner().send(SshConstants.SSH_MSG_IGNORE, 
createIgnoreBuffer(length)).addListener(f -> {
-                    Throwable t = f.getException();
-                    if (t != null && (resolver instanceof Session)) {
-                        ((Session) resolver).exceptionCaught(t);
+            if (message != null) {
+                int length = shouldSendIgnore(cmd);
+                if (length > 0) {
+                    if (LOG.isDebugEnabled()) {
+                        LOG.debug("Injector.send({}) injecting 
SSH_MSG_IGNORE", resolver);
                     }
-                });
+                    owner().send(SshConstants.SSH_MSG_IGNORE, 
createIgnoreBuffer(length)).addListener(f -> {
+                        Throwable t = f.getException();
+                        if (t != null && (resolver instanceof Session)) {
+                            ((Session) resolver).exceptionCaught(t);
+                        }
+                    });
+                }
             }
             return owner().send(cmd, message);
         }
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java
index ac0c65e72..36c15e04b 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java
@@ -33,8 +33,8 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.IntSupplier;
 import java.util.stream.Collectors;
 
 import org.apache.sshd.common.NamedFactory;
@@ -100,14 +100,19 @@ import org.slf4j.LoggerFactory;
  * </p>
  * <ul>
  * <li>SSH_MSG_PING from the {@code p...@openssh.com} extension. This is not 
implemented in Apache MINA sshd (yet).
- * These messages are dropped on input. See CVE-2025-26466 linked below.</li>
- * <li>SSH_MSG_GLOBAL_REQUEST or SSH_MSG_CHANNEL_REQUEST with {@code 
want-reply = true}.</li>
- * <li>SSH_MSG_CHANNEL_OPEN messages. User code can guard against this by 
limiting the number of concurrently open
- * channels.</lI>
+ * These messages are dropped on input during KEX in OpenSSH. See 
CVE-2025-26466 linked below.</li>
+ * <li>SSH_MSG_GLOBAL_REQUEST or SSH_MSG_CHANNEL_REQUEST with {@code 
want-reply = true} should send back a success or
+ * failure reply that would be queued.</li>
+ * <li>SSH_MSG_CHANNEL_OPEN messages should send back success or failure 
messages, which would be queued. User code can
+ * guard against this by limiting the number of concurrently open 
channels.</li>
  * <li>SSH_MSG_SERVICE_REQUEST messages. This is somewhat unlikely to occur, 
since normally there are only two such
  * requests in an SSH connection: a first one for user authentication, then a 
second one to switch to the connection
  * service. There should be no key exchanges running at these times; they're 
both early on in the protocol. The request
- * for user auth is sent right after the first key exchange.</li>
+ * for user auth is sent right after the first key exchange. The failure reply 
to this is SSH_MSG_DISCONNECT, which will
+ * not be queued. The SSH_MSG_SERVICE_ACCEPT reply, however, would be queued. But the 
number of services is limited (in normal SSH
+ * exactly two: a user authentication service and then a connection service), 
and our implementation allows only one
+ * service to be active. So there will be exactly one SSH_MSG_SERVICE_ACCEPT 
queued; further SSH_MSG_SERVICE_REQUESTs
+ * will lead to failure replies and disconnection.</li>
  * <li>SSH_MSG_CHANNEL_DATA: these messages <em>must</em> be passed on and 
handled. LocalWindow needs to listen to the
  * KEX state, too, and not send back SSH_CHANNEL_WINDOW_ADJUST because those 
would get queued. At some point, the
  * channel window will be zero, and if the broken or malicious client keeps 
sending data, the channel will be closed
@@ -115,20 +120,17 @@ import org.slf4j.LoggerFactory;
  * <li>SSH_MSG_CHANNEL_WINDOW_ADJUST: see above. We pass these messages on, 
but make the adjustment take effect in the
  * RemoteWindow only after KEX. Sending a large number of window adjustments 
thus does not cause excessive queueing; at
  * worst (if the peer opens its window too far) it may cause trouble at the 
malicious peer.</li>
- * <li>Unknown messages. We should reply with SSH_MSG_UNIMPLEMENTED except if 
in strict KEX. During strict KEX, we will
- * drop any unknown messages on input.</li>
+ * <li>Unknown messages. We should reply with SSH_MSG_UNIMPLEMENTED, which is 
a low-level message that will not be
+ * queued.</li>
  * </p>
  * <p>
- * As an additional guard against this kind of misbehavior we implement two 
configurable parameters:
+ * As an additional guard against this kind of misbehavior we implement a 
configurable parameter:
  * </p>
- * <li>MAX_PACKETS_UNTIL_KEX_INIT: if we haven't received the peer's KEX_INIT 
with the next MAX_PACKETS_UNTIL_KEX_INIT
+ * <li>MAX_MSGS_BEFORE_KEX_INIT: if we haven't received the peer's KEX_INIT 
within the next MAX_MSGS_BEFORE_KEX_INIT reply-requiring
  * incoming messages after having sent our own KEX_INIT, we disconnect the 
session.</li>
- * <li>MAX_TIME_UNTIL_KEX_INIT: if we haven't received the peer's KEX_INIT 
within MAX_TIME_UNTIL_KEX_INIT after having
- * sent our own KEX_INIT, we disconnect the session.</li>
  * <p>
- * Both settings have rather high defaults (1000 messages or 10min). With 
these settings, we will disconnect even if a
- * peer just keeps sending SSH_MSG_IGNORE packets. If a peer doesn't send any 
messages, the session idle timeout will
- * disconnect the session.
+ * The setting has a rather high default (1000 messages). Only SSH_MSG_GLOBAL_REQUEST and SSH_MSG_CHANNEL_REQUEST 
messages with {@code want-reply = true} and SSH_MSG_CHANNEL_OPEN messages count towards this limit;
+ * other messages, such as SSH_MSG_IGNORE, do not. If a peer doesn't send any messages, the session 
idle timeout will disconnect the session.
  * </p>
  *
  * @see <a 
href="https://www.cve.org/CVERecord?id=CVE-2025-26466";>CVE-2025-26466</a>
@@ -157,10 +159,14 @@ public class KexFilter extends IoFilter {
 
     private final AtomicReference<byte[]> peerData = new AtomicReference<>();
 
+    private final AtomicReference<KeyExchange> kex = new AtomicReference<>();
+
     private final AtomicReference<MessageCodingSettings> inputSettings = new 
AtomicReference<>();
 
     private final AtomicReference<MessageCodingSettings> outputSettings = new 
AtomicReference<>();
 
+    private final int maxMsgsBeforeKexInit;
+
     // Rekeying
 
     private final long rekeyAfterBytes;
@@ -210,12 +216,6 @@ public class KexFilter extends IoFilter {
 
     private final HostKeyChecker hostKeyChecker;
 
-    private enum KexStart {
-        PEER,
-        BOTH,
-        ONGOING
-    }
-
     private volatile String clientIdent;
 
     private volatile String serverIdent;
@@ -223,9 +223,6 @@ public class KexFilter extends IoFilter {
     // Set and checked on the input chain
     private boolean firstKexPacketFollows;
 
-    // Set and checked on the input chain
-    private KeyExchange kex;
-
     // Guarded by synchronized(KexFilter.this)
     private DefaultKeyExchangeFuture myProposalReady;
 
@@ -252,11 +249,13 @@ public class KexFilter extends IoFilter {
         }
         this.hostKeyChecker = checker;
 
+        maxMsgsBeforeKexInit = 
CoreModuleProperties.MAX_MSGS_BEFORE_KEX_INIT.getRequired(session);
+
         rekeyAfterBytes = 
CoreModuleProperties.REKEY_BYTES_LIMIT.getRequired(session);
         rekeyAfterPackets = 
CoreModuleProperties.REKEY_PACKETS_LIMIT.getRequired(session);
         rekeyAfterBlocks = rekeyAfterBytes / 16; // Initial setting, will be 
updated once we know the cipher
         Duration interval = 
CoreModuleProperties.REKEY_TIME_LIMIT.getRequired(session);
-        if (interval.isZero() || interval.isNegative()) {
+        if (interval.compareTo(Duration.ZERO) <= 0) {
             interval = null;
         }
         rekeyAfter = interval;
@@ -351,7 +350,14 @@ public class KexFilter extends IoFilter {
 
     // Receiving
 
+    private enum KexStart {
+        PEER,
+        BOTH,
+        ONGOING
+    }
+
     private void receiveKexInit(Buffer message) throws Exception {
+        input.messagesBeforeKexInit.set(0);
         // Update the KEX state
         KexStart starting = output.updateState(() -> {
             if (kexState.compareAndSet(KexState.DONE, KexState.RUN)) {
@@ -523,29 +529,6 @@ public class KexFilter extends IoFilter {
         }
     }
 
-    public KeyExchangeFuture startKex() throws Exception {
-        boolean start = output.updateState(() -> {
-            if (kexState.compareAndSet(KexState.DONE, KexState.INIT)) {
-                output.initNewKeyExchange();
-                return true;
-            }
-            return false;
-        });
-        DefaultKeyExchangeFuture result = new 
DefaultKeyExchangeFuture(session.toString(), session.getFutureLock());
-        if (start) {
-            listeners.forEach(listener -> listener.event(true));
-            kexFuture.set(result);
-            sendKexInit().addListener(f -> {
-                if (!f.isWritten()) {
-                    exceptionCaught(f.getException());
-                }
-            });
-        } else {
-            result.setValue(new SshException("KEX already ongoing"));
-        }
-        return result;
-    }
-
     // Negotiation
 
     /**
@@ -726,8 +709,9 @@ public class KexFilter extends IoFilter {
         byte[] iS = isServer ? myData.get() : peerData.get();
         byte[] iC = isServer ? peerData.get() : myData.get();
 
-        kex = kexFactory.createKeyExchange(session);
-        kex.init(vS, vC, iS, iC);
+        KeyExchange k = kexFactory.createKeyExchange(session);
+        k.init(vS, vC, iS, iC);
+        kex.set(k);
 
         synchronized (this) {
             myProposalReady = null;
@@ -739,9 +723,13 @@ public class KexFilter extends IoFilter {
 
     @SuppressWarnings("checkstyle:VariableDeclarationUsageDistance")
     private void prepareNewSettings() throws Exception {
-        byte[] k = kex.getK();
-        byte[] h = kex.getH();
-        Digest hash = kex.getHash();
+        KeyExchange exchange = kex.get();
+        if (exchange == null) {
+            throw new 
SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, "No KEX");
+        }
+        byte[] k = exchange.getK();
+        byte[] h = exchange.getH();
+        Digest hash = exchange.getHash();
 
         byte[] sessionIdValue = sessionId.get();
         if (sessionIdValue == null) {
@@ -753,9 +741,10 @@ public class KexFilter extends IoFilter {
         }
 
         Buffer buffer = new ByteArrayBuffer();
-        buffer.putBytes(k);
+        buffer.putBytes(k); // K encoded with length, see RFC 4253, section 7.2
         buffer.putRawBytes(h);
-        buffer.putByte((byte) 0x41);
+        int j = buffer.wpos();
+        buffer.putByte((byte) 0x41); // 'A', see RFC 4253, section 7.2
         buffer.putRawBytes(sessionIdValue);
 
         int pos = buffer.available();
@@ -763,25 +752,24 @@ public class KexFilter extends IoFilter {
         hash.update(buf, 0, pos);
 
         byte[] iv_c2s = hash.digest();
-        int j = pos - sessionIdValue.length - 1;
 
-        buf[j]++;
+        buf[j]++; // 'B'
         hash.update(buf, 0, pos);
         byte[] iv_s2c = hash.digest();
 
-        buf[j]++;
+        buf[j]++; // 'C'
         hash.update(buf, 0, pos);
         byte[] e_c2s = hash.digest();
 
-        buf[j]++;
+        buf[j]++; // 'D'
         hash.update(buf, 0, pos);
         byte[] e_s2c = hash.digest();
 
-        buf[j]++;
+        buf[j]++; // 'E'
         hash.update(buf, 0, pos);
         byte[] mac_c2s = hash.digest();
 
-        buf[j]++;
+        buf[j]++; // 'F'
         hash.update(buf, 0, pos);
         byte[] mac_s2c = hash.digest();
 
@@ -870,7 +858,9 @@ public class KexFilter extends IoFilter {
 
     private IoWriteFuture sendNewKeys() throws Exception {
         Buffer buffer = session.createBuffer(SshConstants.SSH_MSG_NEWKEYS, 1);
+
         IoWriteFuture future = forward.send(SshConstants.SSH_MSG_NEWKEYS, 
buffer);
+
         // Use the new settings from now on for any outgoing packet
         setOutputEncoding();
         output.updateState(() -> kexState.set(KexState.KEYS));
@@ -941,8 +931,6 @@ public class KexFilter extends IoFilter {
 
         lastKexEnd.set(Instant.now());
 
-        forward.sequenceNumberCheckEnabled = false;
-
         if (LOG.isDebugEnabled()) {
             LOG.debug("setOutputEncoding({}): cipher {}; mac {}; compression 
{}; blocks limit {}", session, cipher, mac,
                     comp, maxRekeyBlocks);
@@ -958,9 +946,7 @@ public class KexFilter extends IoFilter {
             throw new SshException(SshConstants.SSH2_DISCONNECT_PROTOCOL_ERROR,
                     "KEX: received SSH_MSG_NEWKEYS in state " + currentState);
         }
-        input.sequenceNumberCheckEnabled = false;
         // It is guaranteed that we handle the peer's SSH_MSG_NEWKEYS after 
having sent our own.
-        // prepareNewKeys() was already called in sendNewKeys().
         //
         // From now on, use the new settings for any incoming message.
         setInputEncoding();
@@ -976,7 +962,7 @@ public class KexFilter extends IoFilter {
         listeners.forEach(listener -> listener.event(false));
 
         output.updateState(() -> {
-            kex = null; // discard and GC since KEX is completed
+            kex.set(null); // discard and GC since KEX is completed
             kexState.set(KexState.DONE);
         });
 
@@ -1009,8 +995,6 @@ public class KexFilter extends IoFilter {
         int outBlockSize = outCipher == null ? 8 : 
outCipher.getCipherBlockSize();
         long maxRekeyBlocks = 
determineRekeyBlockLimit(cipher.getCipherBlockSize(), outBlockSize);
 
-        lastKexEnd.set(Instant.now());
-
         if (LOG.isDebugEnabled()) {
             LOG.debug("setInputEncoding({}): cipher {}; mac {}; compression 
{}; blocks limit {}", session, cipher, mac,
                     comp, maxRekeyBlocks);
@@ -1056,7 +1040,7 @@ public class KexFilter extends IoFilter {
 
     // Starting a KEX
 
-    private boolean isKexNeeded(boolean input) {
+    private boolean isKexNeeded() {
         if (!initialKexDone || !session.isOpen()) {
             return false;
         }
@@ -1076,6 +1060,30 @@ public class KexFilter extends IoFilter {
                 || rekeyAfterPackets > 0 && rekeyAfterPackets <= 
counts.getPackets();
     }
 
+    public KeyExchangeFuture startKex() throws Exception {
+        boolean start = output.updateState(() -> {
+            if (kexState.compareAndSet(KexState.DONE, KexState.INIT)) {
+                output.initNewKeyExchange();
+                return true;
+            }
+            return false;
+        });
+        DefaultKeyExchangeFuture result = new 
DefaultKeyExchangeFuture(session.toString(), session.getFutureLock());
+        if (start) {
+            listeners.forEach(listener -> listener.event(true));
+            kexFuture.set(result);
+            input.messagesBeforeKexInit.set(0);
+            sendKexInit().addListener(f -> {
+                if (!f.isWritten()) {
+                    exceptionCaught(f.getException());
+                }
+            });
+        } else {
+            result.setValue(new SshException("KEX already ongoing"));
+        }
+        return result;
+    }
+
     // Entry points for the KexOutputHandler
     IoWriteFuture write(int cmd, Buffer buffer, boolean checkForKex) throws 
IOException {
         IoWriteFuture result = forward.send(cmd, buffer);
@@ -1087,7 +1095,7 @@ public class KexFilter extends IoFilter {
 
     void startKexIfNeeded() throws IOException {
         KexState state = kexState.get();
-        if (state == KexState.DONE && isKexNeeded(true)) {
+        if (state == KexState.DONE && isKexNeeded()) {
             try {
                 startKex();
             } catch (IOException e) {
@@ -1100,8 +1108,6 @@ public class KexFilter extends IoFilter {
 
     private abstract class WithSequenceNumber {
 
-        volatile boolean sequenceNumberCheckEnabled = true;
-
         private int initialSequenceNumber;
 
         private boolean first = true;
@@ -1110,28 +1116,31 @@ public class KexFilter extends IoFilter {
             super();
         }
 
-        protected void checkSequence(String message, IntSupplier sequence) 
throws SshException {
+        protected void checkSequence() throws SshException {
+            if (initialKexDone) {
+                return;
+            }
             if (first) {
                 first = false;
-                initialSequenceNumber = sequence.getAsInt();
-            } else if (!initialKexDone && initialSequenceNumber == 
sequence.getAsInt()) {
+                initialSequenceNumber = crypt.getInputSequenceNumber();
+            } else if (initialSequenceNumber == 
crypt.getInputSequenceNumber()) {
                 throw new 
SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED,
-                        message + " sequence number wraps around during 
initial KEX");
+                        "Incoming sequence number wraps around during initial 
KEX");
             }
         }
     }
 
     private class KexInputHandler extends WithSequenceNumber implements 
BufferInputHandler {
 
+        final AtomicLong messagesBeforeKexInit = new AtomicLong();
+
         KexInputHandler() {
             super();
         }
 
         @Override
         public void handleMessage(Buffer message) throws Exception {
-            if (sequenceNumberCheckEnabled) {
-                checkSequence("Incoming", crypt::getInputSequenceNumber);
-            }
+            checkSequence();
             int cmd = message.rawByte(message.rpos()) & 0xFF;
             if (LOG.isDebugEnabled()) {
                 LOG.debug("KexFilter.handleMessage({}) {} with packet size 
{}", getSession(),
@@ -1143,7 +1152,7 @@ public class KexFilter extends IoFilter {
                 if (cmd == SshConstants.SSH_MSG_KEXINIT) {
                     receiveKexInit(message);
                 } else {
-                    if (isKexNeeded(false)) {
+                    if (isKexNeeded()) {
                         startKex();
                     }
                     owner().passOn(message);
@@ -1187,8 +1196,48 @@ public class KexFilter extends IoFilter {
             return cmd >= SshConstants.SSH_MSG_KEXINIT && cmd <= 
SshConstants.SSH_MSG_KEX_LAST;
         }
 
+        private boolean isWantReply(Buffer message, boolean isChannelRequest) {
+            boolean wantReply = false;
+            int pos = message.rpos();
+            message.getUByte(); // Skip the command byte
+            if (isChannelRequest) {
+                message.getUInt(); // Skip the channel id
+            }
+            long length = message.getUInt(); // Length of the request name
+            if (length < message.available()) { // want-reply byte follows the name
+                wantReply = message.rawByte(message.rpos() + (int) length) != 0;
+            }
+            message.rpos(pos);
+            return wantReply;
+        }
+
         private void passOnBeforeKexInit(int cmd, Buffer message) throws 
Exception {
-            // TODO: message handling per the class javadoc.
+            if (maxMsgsBeforeKexInit > 0) {
+                long valueNow = 0;
+                switch (cmd) {
+                    case SshConstants.SSH_MSG_GLOBAL_REQUEST:
+                        if (isWantReply(message, false)) {
+                            valueNow = messagesBeforeKexInit.incrementAndGet();
+                        }
+                        break;
+                    case SshConstants.SSH_MSG_CHANNEL_REQUEST:
+                        if (isWantReply(message, true)) {
+                            valueNow = messagesBeforeKexInit.incrementAndGet();
+                        }
+                        break;
+                    case SshConstants.SSH_MSG_CHANNEL_OPEN:
+                        valueNow = messagesBeforeKexInit.incrementAndGet();
+                        break;
+                    default:
+                        // All other messages do not require a reply; see 
class comment.
+                        break;
+                }
+                if (valueNow > maxMsgsBeforeKexInit) {
+                    throw new 
SshException(SshConstants.SSH2_DISCONNECT_PROTOCOL_ERROR,
+                            "KEX: no SSH_MSG_KEX_INIT received from peer 
within MAX_MSGS_BEFORE_KEX_INIT limit "
+                                                                               
         + maxMsgsBeforeKexInit);
+                }
+            }
             owner().passOn(message);
         }
 
@@ -1211,21 +1260,25 @@ public class KexFilter extends IoFilter {
                     }
                 }
             }
-            if (kex.next(cmd, message)) {
+            KeyExchange exchange = kex.get();
+            if (exchange == null) {
+                throw new 
SshException(SshConstants.SSH2_DISCONNECT_PROTOCOL_ERROR, MessageFormat
+                        .format("KEX message {0} received at the wrong time in 
KEX", SshConstants.getCommandMessageName(cmd)));
+            }
+            if (exchange.next(cmd, message)) {
                 // We're done
                 if (hostKeyChecker != null) {
                     hostKeyChecker.check();
                 }
                 prepareNewSettings();
-                lastKexEnd.set(Instant.now());
                 sendNewKeys();
             } else if (LOG.isDebugEnabled()) {
-                LOG.debug("handleKexMessage({})[{}] more KEX packets expected 
after cmd={}", session, kex.getName(), cmd);
+                LOG.debug("handleKexMessage({})[{}] more KEX packets expected 
after cmd={}", session, exchange.getName(), cmd);
             }
         }
     }
 
-    private class Sender extends WithSequenceNumber implements OutputHandler {
+    private class Sender implements OutputHandler {
 
         Sender() {
             super();
@@ -1233,12 +1286,9 @@ public class KexFilter extends IoFilter {
 
         @Override
         public IoWriteFuture send(int cmd, Buffer message) throws IOException {
-            if (sequenceNumberCheckEnabled) {
-                checkSequence("Outgoing", crypt::getOutputSequenceNumber);
-            }
             if (LOG.isDebugEnabled()) {
                 LOG.debug("KexFilter.send({}) {} with packet size {}", 
getSession(),
-                        
SshConstants.getCommandMessageName(message.rawByte(message.rpos()) & 0xFF), 
message.available());
+                        SshConstants.getCommandMessageName(cmd), 
message.available());
             }
             return owner().send(cmd, message);
         }
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexOutputHandler.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexOutputHandler.java
index 24a28db34..d59ee3250 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexOutputHandler.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexOutputHandler.java
@@ -57,7 +57,7 @@ import org.slf4j.Logger;
  *
  * @see <a href="https://tools.ietf.org/html/rfc4253#section-7";>RFC 4253</a>
  */
-public class KexOutputHandler implements OutputHandler {
+class KexOutputHandler implements OutputHandler {
 
     // With asynchronous flushing we get a classic producer-consumer problem. 
The flushing thread is the single
     // consumer, and there is a risk that it might get overrun by the 
producers. The classical solution of using a
@@ -78,39 +78,39 @@ public class KexOutputHandler implements OutputHandler {
      *
      * @see #flushQueue(DefaultKeyExchangeFuture)
      */
-    protected static ExecutorService flushRunner = 
ThreadUtils.newCachedThreadPool("kex-flusher");
+    private static ExecutorService flushRunner = 
ThreadUtils.newCachedThreadPool("kex-flusher");
 
     /**
      * We need the flushing thread to have priority over writing threads. So 
we use a lock that favors writers over
      * readers, and any state updates and the flushing thread are writers, 
while writePacket() is a reader.
      */
-    protected final ReentrantReadWriteLock lock = new 
ReentrantReadWriteLock(false);
+    private final ReentrantReadWriteLock lock = new 
ReentrantReadWriteLock(false);
 
     /**
      * The {@link KexFilter} this {@link KexOutputHandler} belongs to.
      */
-    protected final KexFilter filter;
+    private final KexFilter filter;
 
     /**
      * The {@link Logger} to use.
      */
-    protected final Logger log;
+    private final Logger log;
 
     /**
      * Queues up high-level packets written during an ongoing key exchange.
      */
-    protected final Queue<PendingWriteFuture> pendingPackets = new 
ConcurrentLinkedQueue<>();
+    private final Queue<PendingWriteFuture> pendingPackets = new 
ConcurrentLinkedQueue<>();
 
     /**
      * Indicates that all pending packets have been flushed. Set to {@code 
true} by the flushing thread, or at the end
      * of KEX if there are no packets to be flushed. Set to {@code false} when 
a new KEX starts. Initially {@code true}.
      */
-    protected final AtomicBoolean kexFlushed = new AtomicBoolean(true);
+    private final AtomicBoolean kexFlushed = new AtomicBoolean(true);
 
     /**
      * Indicates that the handler has been shut down.
      */
-    protected final AtomicBoolean shutDown = new AtomicBoolean();
+    private final AtomicBoolean shutDown = new AtomicBoolean();
 
     /**
      * Never {@code null}. Used to block some threads when writing packets 
while pending packets are still being flushed
@@ -118,7 +118,7 @@ public class KexOutputHandler implements OutputHandler {
      * of a KEX a new future is installed, which is fulfilled at the end of 
the KEX once there are no more pending
      * packets to be flushed.
      */
-    protected final AtomicReference<DefaultKeyExchangeFuture> kexFlushedFuture 
= new AtomicReference<>();
+    private final AtomicReference<DefaultKeyExchangeFuture> kexFlushedFuture = 
new AtomicReference<>();
 
     /**
      * Creates a new {@link KexOutputHandler} for the given {@code session}, 
using the given {@code Logger}.
@@ -126,7 +126,7 @@ public class KexOutputHandler implements OutputHandler {
      * @param filter {@link KexFilter} the new instance belongs to
      * @param log    {@link Logger} to use for writing log messages
      */
-    public KexOutputHandler(KexFilter filter, Logger log) {
+    KexOutputHandler(KexFilter filter, Logger log) {
         this.filter = Objects.requireNonNull(filter);
         this.log = Objects.requireNonNull(log);
         // Start with a fulfilled kexFlushed future.
@@ -135,14 +135,14 @@ public class KexOutputHandler implements OutputHandler {
         kexFlushedFuture.set(initialFuture);
     }
 
-    public void updateState(Runnable update) {
+    void updateState(Runnable update) {
         updateState(() -> {
             update.run();
             return null;
         });
     }
 
-    public <V> V updateState(Supplier<V> update) {
+    <V> V updateState(Supplier<V> update) {
         lock.writeLock().lock();
         try {
             return update.get();
@@ -160,7 +160,7 @@ public class KexOutputHandler implements OutputHandler {
      *
      * @return the previous {@link DefaultKeyExchangeFuture} indicating 
whether all pending packets were flushed.
      */
-    public DefaultKeyExchangeFuture initNewKeyExchange() {
+    DefaultKeyExchangeFuture initNewKeyExchange() {
         return updateState(() -> {
             kexFlushed.set(false);
             return kexFlushedFuture.getAndSet(
@@ -175,7 +175,7 @@ public class KexOutputHandler implements OutputHandler {
      *
      * @return the current {@link DefaultKeyExchangeFuture} and the number of 
currently pending packets
      */
-    public SimpleImmutableEntry<Integer, DefaultKeyExchangeFuture> 
terminateKeyExchange() {
+    SimpleImmutableEntry<Integer, DefaultKeyExchangeFuture> 
terminateKeyExchange() {
         return updateState(() -> {
             int numPending = pendingPackets.size();
             if (numPending == 0) {
@@ -248,7 +248,7 @@ public class KexOutputHandler implements OutputHandler {
      * @return             an {@link IoWriteFuture} that will be fulfilled 
once the packet has indeed been written.
      * @throws IOException if an error occurs
      */
-    protected IoWriteFuture writeOrEnqueue(int cmd, Buffer buffer) throws 
IOException {
+    private IoWriteFuture writeOrEnqueue(int cmd, Buffer buffer) throws 
IOException {
         for (;;) {
             // We must decide _and_ write the packet while holding the lock. 
If we'd write the packet outside this
             // lock, there is no guarantee that a concurrently running 
KEX_INIT received from the peer doesn't change
@@ -291,7 +291,7 @@ public class KexOutputHandler implements OutputHandler {
      * @param  buffer the {@link Buffer} containing the packet to be sent
      * @return        the enqueued {@link PendingWriteFuture}
      */
-    protected PendingWriteFuture enqueuePendingPacket(int cmd, Buffer buffer) {
+    private PendingWriteFuture enqueuePendingPacket(int cmd, Buffer buffer) {
         String cmdName = SshConstants.getCommandMessageName(cmd);
         PendingWriteFuture future;
         int numPending;
@@ -321,7 +321,7 @@ public class KexOutputHandler implements OutputHandler {
      * @param flushDone the future obtained from {@code getFlushedFuture}; 
will be fulfilled once all pending packets
      *                  have been written
      */
-    protected void flushQueue(DefaultKeyExchangeFuture flushDone) {
+    void flushQueue(DefaultKeyExchangeFuture flushDone) {
         // kexFlushed must be set to true in all cases when this thread exits, 
**except** if a new KEX has started while
         // flushing.
         flushRunner.submit(() -> {
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/MessageCodingSettings.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/MessageCodingSettings.java
index 85da9cf90..9ef4b6f4b 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/MessageCodingSettings.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/MessageCodingSettings.java
@@ -26,7 +26,7 @@ import org.apache.sshd.common.util.buffer.BufferUtils;
 /**
  * Message encoding or decoding settings as determined at the end of a key 
exchange.
  */
-public class MessageCodingSettings {
+class MessageCodingSettings {
 
     private final Cipher cipher;
 
@@ -40,7 +40,7 @@ public class MessageCodingSettings {
 
     private byte[] iv;
 
-    public MessageCodingSettings(Cipher cipher, Mac mac, Compression 
compression, Cipher.Mode mode, byte[] key, byte[] iv) {
+    MessageCodingSettings(Cipher cipher, Mac mac, Compression compression, 
Cipher.Mode mode, byte[] key, byte[] iv) {
         this.cipher = cipher;
         this.mac = mac;
         this.compression = compression;
@@ -67,16 +67,16 @@ public class MessageCodingSettings {
      * @return                      the fully initialized cipher
      * @throws Exception            if the cipher cannot be initialized
      */
-    public Cipher getCipher(long packetSequenceNumber) throws Exception {
+    Cipher getCipher(long packetSequenceNumber) throws Exception {
         initCipher(packetSequenceNumber);
         return cipher;
     }
 
-    public Mac getMac() {
+    Mac getMac() {
         return mac;
     }
 
-    public Compression getCompression() {
+    Compression getCompression() {
         return compression;
     }
 }
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java 
b/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
index 2bec5c3a3..546447dab 100644
--- a/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
+++ b/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
@@ -799,6 +799,28 @@ public final class CoreModuleProperties {
         }
     });
 
+    public static final int DEFAULT_MAX_MSGS_BEFORE_KEX_INIT = 1_000;
+
+    /**
+     * After having sent our own SSH_MSG_KEXINIT when starting a new key 
exchange, we expect the peer's SSH_MSG_KEXINIT.
+     * But a peer can send any number of other messages first. If those 
messages require replies, we must buffer these
+     * replies because once we've sent SSH_MSG_KEXINIT, we may send only key 
exchange messages until the key exchange is
+     * over. Buffering replies means a broken peer that just doesn't send its 
SSH_MSG_KEXINIT could cause unbounded
+     * memory consumption on our end. This concerns in particular 
SSH_MSG_GLOBAL_REQUEST with the "want-reply" flag
+     * {@code true}, SSH_MSG_CHANNEL_REQUEST and SSH_MSG_CHANNEL_OPEN. This 
property sets an upper limit on the number of
+     * such messages that we'll handle after having sent our own SSH_MSG_KEXINIT but before 
receiving the peer's SSH_MSG_KEXINIT. If
+     * the limit is exceeded, the session disconnects.
+     *
+     * <p>
+     * The default value is {@link #DEFAULT_MAX_MSGS_BEFORE_KEX_INIT} (1000), 
which should be very generous. If zero or
+     * negative, no limit is assumed.
+     * </p>
+     *
+     * @see #DEFAULT_MAX_MSGS_BEFORE_KEX_INIT
+     */
+    public static final Property<Integer> MAX_MSGS_BEFORE_KEX_INIT = 
Property.integer("kex-max-msgs-before-kex-init",
+            DEFAULT_MAX_MSGS_BEFORE_KEX_INIT);
+
     private CoreModuleProperties() {
         throw new UnsupportedOperationException("No instance");
     }


Reply via email to