This is an automated email from the ASF dual-hosted git repository.

twolf pushed a commit to branch dev_3.0
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit 06d3813910e060cf725d231896aab0f87d1a22ae
Author: Thomas Wolf <tw...@apache.org>
AuthorDate: Wed Apr 2 20:58:29 2025 +0200

    Introduce a compound SshTransportFilter
    
    Introduce an SshTransportFilter that encapsulates the RFC 4253 transport
    protocol up to and including the key exchange.
---
 .../sshd/client/session/ClientSessionImpl.java     |   4 +-
 .../sshd/common/session/filters/CryptFilter.java   |  18 +-
 .../common/session/filters/SshTransportFilter.java | 228 +++++++++++++++++++++
 .../sshd/common/session/filters/kex/KexFilter.java |  10 +-
 .../common/session/helpers/AbstractSession.java    | 114 ++++-------
 .../sshd/server/session/AbstractServerSession.java |   4 +-
 6 files changed, 284 insertions(+), 94 deletions(-)

diff --git 
a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java 
b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java
index cdc5d19ba..d713d386e 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java
@@ -98,8 +98,8 @@ public class ClientSessionImpl extends AbstractClientSession {
 
     @Override
     public void setAuthenticated() throws IOException {
-        getCompressionFilter().enableInput();
-        getCompressionFilter().enableOutput();
+        getTransport().enableInputCompression(); // enable compression in both directions
+        getTransport().enableOutputCompression(); // before delegating to super.setAuthenticated()
         super.setAuthenticated();
     }
 
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/CryptFilter.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/CryptFilter.java
index 7dd7404cc..af9dbfc55 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/CryptFilter.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/CryptFilter.java
@@ -55,6 +55,11 @@ public class CryptFilter extends IoFilter implements 
CryptStatisticsProvider {
      */
     public static final int MAX_PADDING = 127;
 
+    /**
+     * An arbitrary constant >= the largest authentication tag size (in bytes) we will 
ever have; being a fixed bound, it can be used to size packet buffers independently of the negotiated settings.
+     */
+    public static final int MAX_TAG_LENGTH = 64;
+
     private static final Logger LOG = 
LoggerFactory.getLogger(CryptFilter.class);
 
     // The minimum value for the packet length field of a valid SSH packet:
@@ -167,19 +172,6 @@ public class CryptFilter extends IoFilter implements 
CryptStatisticsProvider {
         return decryption.get().isSecure() && encryption.get().isSecure();
     }
 
-    /**
-     * Performs a best-effort precalculation of the needed packet buffer size 
assuming an a priori known packet length.
-     * This may help avoid needing to grow the buffer later on.
-     *
-     * @param  knownPayloadLength expected payload length
-     * @return                    a buffer size that will be sufficient to 
hold the full SSH packet including header,
-     *                            padding, and MAC, if any.
-     */
-    public int precomputeBufferLength(int knownPayloadLength) {
-        Settings out = getOutputSettings();
-        return knownPayloadLength + SshConstants.SSH_PACKET_HEADER_LEN + 
MAX_PADDING + out.getTagSize();
-    }
-
     public void addEncryptionListener(EncryptionListener listener) {
         listeners.addIfAbsent(Objects.requireNonNull(listener));
     }
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/SshTransportFilter.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/SshTransportFilter.java
new file mode 100644
index 000000000..c0180b951
--- /dev/null
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/SshTransportFilter.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sshd.common.session.filters;
+
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.sshd.common.cipher.CipherInformation;
+import org.apache.sshd.common.compression.CompressionInformation;
+import org.apache.sshd.common.filter.DefaultFilterChain;
+import org.apache.sshd.common.filter.FilterChain;
+import org.apache.sshd.common.filter.InputHandler;
+import org.apache.sshd.common.filter.IoFilter;
+import org.apache.sshd.common.filter.OutputHandler;
+import org.apache.sshd.common.future.KeyExchangeFuture;
+import org.apache.sshd.common.kex.KexProposalOption;
+import org.apache.sshd.common.kex.KexState;
+import org.apache.sshd.common.mac.MacInformation;
+import org.apache.sshd.common.random.Random;
+import org.apache.sshd.common.session.SessionListener;
+import org.apache.sshd.common.session.filters.CryptFilter.EncryptionListener;
+import org.apache.sshd.common.session.filters.kex.KexFilter;
+import org.apache.sshd.common.session.filters.kex.KexFilter.HostKeyChecker;
+import org.apache.sshd.common.session.filters.kex.KexFilter.Proposer;
+import org.apache.sshd.common.session.filters.kex.KexListener;
+import org.apache.sshd.common.session.helpers.AbstractSession;
+
+/**
+ * A filter encapsulating the basic SSH transport up to and including KEX.
+ * <p>
+ * Internally, this filter runs its own {@link FilterChain} consisting of, in order, an {@link IdentFilter}, a
+ * {@link CryptFilter}, a {@link CompressionFilter}, a {@link DelayKexInitFilter}, an {@link InjectIgnoreFilter}, and a
+ * {@link KexFilter}. Two private connector filters at either end of that inner chain link it to the chain this filter
+ * itself is part of.
+ * </p>
+ */
+public class SshTransportFilter extends IoFilter {
+
+    /** The inner chain holding the transport's sub-filters. */
+    private final FilterChain filters = new DefaultFilterChain();
+
+    private final CryptFilter cryptFilter;
+    private final CompressionFilter compressionFilter;
+    private final KexFilter kexFilter;
+
+    /**
+     * Creates a new SSH transport filter.
+     *
+     * @param session       {@link AbstractSession} this filter is for
+     * @param random        {@link Random} instance to use
+     * @param identities    {@link SshIdentHandler} for handling the SSH identification string
+     * @param events        {@link SessionListener} to report some events
+     * @param cryptListener {@link EncryptionListener} called just before a buffer is encrypted
+     * @param proposer      {@link Proposer} to get KEX proposals
+     * @param checker       {@link HostKeyChecker} to check the peer's host key; may be {@code null} if on a server
+     */
+    public SshTransportFilter(AbstractSession session, Random random, SshIdentHandler identities, SessionListener events,
+                              EncryptionListener cryptListener, Proposer proposer, HostKeyChecker checker) {
+        IdentFilter ident = new IdentFilter();
+        ident.setPropertyResolver(session);
+        ident.setIdentHandler(identities);
+        filters.addLast(ident);
+
+        cryptFilter = new CryptFilter();
+        cryptFilter.setSession(session);
+        cryptFilter.setRandom(random);
+        cryptFilter.addEncryptionListener(cryptListener);
+        filters.addLast(cryptFilter);
+
+        compressionFilter = new CompressionFilter();
+        compressionFilter.setSession(session);
+        filters.addLast(compressionFilter);
+
+        DelayKexInitFilter delayKexFilter = new DelayKexInitFilter();
+        delayKexFilter.setSession(session);
+        filters.addLast(delayKexFilter);
+
+        filters.addLast(new InjectIgnoreFilter(session, random));
+
+        kexFilter = new KexFilter(session, random, cryptFilter, compressionFilter, events, proposer, checker);
+        filters.addLast(kexFilter);
+
+        // Hand each identification string to the KexFilter as either the client's or the server's.
+        ident.addIdentListener((peer, id) -> {
+            if (peer == session.isServerSession()) {
+                kexFilter.setClientIdent(id);
+            } else {
+                kexFilter.setServerIdent(id);
+            }
+        });
+        // Bracket the inner chain with connectors linking it to the chain this filter is part of.
+        filters.addFirst(new InConnector(this));
+        filters.addLast(new OutConnector(this));
+    }
+
+    /** Returns the input handler of the first filter of the inner chain. */
+    @Override
+    public InputHandler in() {
+        return filters.getFirst().in();
+    }
+
+    /** Returns the output handler of the last filter of the inner chain. */
+    @Override
+    public OutputHandler out() {
+        return filters.getLast().out();
+    }
+
+    /** Starts a (re-)key exchange; see {@link KexFilter#startKex()}. */
+    public KeyExchangeFuture startKex() throws Exception {
+        return kexFilter.startKex();
+    }
+
+    /** Shuts down the inner {@link KexFilter}; see {@link KexFilter#shutdown()}. */
+    public void shutdown() {
+        kexFilter.shutdown();
+    }
+
+    /** Whether strict KEX is in use; see {@link KexFilter#isStrictKex()}. */
+    public boolean isStrictKex() {
+        return kexFilter.isStrictKex();
+    }
+
+    /** Whether the initial key exchange has completed; see {@link KexFilter#isInitialKexDone()}. */
+    public boolean isInitialKexDone() {
+        return kexFilter.isInitialKexDone();
+    }
+
+    /** The {@link KexFilter}'s KEX state holder; see {@link KexFilter#getKexState()}. */
+    public AtomicReference<KexState> getKexState() {
+        return kexFilter.getKexState();
+    }
+
+    /** The negotiated KEX parameters; see {@link KexFilter#getNegotiated()}. */
+    public Map<KexProposalOption, String> getNegotiated() {
+        return kexFilter.getNegotiated();
+    }
+
+    /** The client's KEX proposal; see {@link KexFilter#getClientProposal()}. */
+    public Map<KexProposalOption, String> getClientProposal() {
+        return kexFilter.getClientProposal();
+    }
+
+    /** The server's KEX proposal; see {@link KexFilter#getServerProposal()}. */
+    public Map<KexProposalOption, String> getServerProposal() {
+        return kexFilter.getServerProposal();
+    }
+
+    /** The session identifier; see {@link KexFilter#getSessionId()}. */
+    public byte[] getSessionId() {
+        return kexFilter.getSessionId();
+    }
+
+    /** Registers a {@link KexListener} with the {@link KexFilter}. */
+    public void addKexListener(KexListener listener) {
+        kexFilter.addKexListener(listener);
+    }
+
+    /** Unregisters a {@link KexListener} from the {@link KexFilter}. */
+    public void removeKexListener(KexListener listener) {
+        kexFilter.removeKexListener(listener);
+    }
+
+    /** Whether the connection is secure; see {@link CryptFilter#isSecure()}. */
+    public boolean isSecure() {
+        return cryptFilter.isSecure();
+    }
+
+    /** The {@link CryptFilter}'s current input sequence number. */
+    public int getInputSequenceNumber() {
+        return cryptFilter.getInputSequenceNumber();
+    }
+
+    /** The {@link CryptFilter}'s current output sequence number. */
+    public int getOutputSequenceNumber() {
+        return cryptFilter.getOutputSequenceNumber();
+    }
+
+    /** The cipher currently in use for incoming ({@code true}) or outgoing ({@code false}) packets. */
+    public CipherInformation getCipherInformation(boolean incoming) {
+        return incoming ? cryptFilter.getInputSettings().getCipher() : cryptFilter.getOutputSettings().getCipher();
+    }
+
+    /** The MAC currently in use for incoming ({@code true}) or outgoing ({@code false}) packets. */
+    public MacInformation getMacInformation(boolean incoming) {
+        return incoming ? cryptFilter.getInputSettings().getMac() : cryptFilter.getOutputSettings().getMac();
+    }
+
+    /** Enables input compression; see {@link CompressionFilter#enableInput()}. */
+    public void enableInputCompression() {
+        compressionFilter.enableInput();
+    }
+
+    /** Enables output compression; see {@link CompressionFilter#enableOutput()}. */
+    public void enableOutputCompression() {
+        compressionFilter.enableOutput();
+    }
+
+    /** The compression currently in use for incoming ({@code true}) or outgoing ({@code false}) data. */
+    public CompressionInformation getCompressionInformation(boolean incoming) {
+        return incoming ? compressionFilter.getInputCompression() : compressionFilter.getOutputCompression();
+    }
+
+    /**
+     * First filter of the inner chain: passes input on within the inner chain, and hands output that has traversed
+     * the inner chain to the outer chain via the owning {@link SshTransportFilter}.
+     */
+    private static class InConnector extends IoFilter {
+
+        private final SshTransportFilter transport;
+
+        InConnector(SshTransportFilter transport) {
+            this.transport = transport;
+        }
+
+        @Override
+        public InputHandler in() {
+            return owner()::passOn;
+        }
+
+        @Override
+        public OutputHandler out() {
+            return transport.owner()::send;
+        }
+
+    }
+
+    /**
+     * Last filter of the inner chain: hands input that has traversed the inner chain on to the outer chain via the
+     * owning {@link SshTransportFilter}, and passes output on within the inner chain.
+     */
+    private static class OutConnector extends IoFilter {
+
+        private final SshTransportFilter transport;
+
+        OutConnector(SshTransportFilter transport) {
+            this.transport = transport;
+        }
+
+        @Override
+        public InputHandler in() {
+            return transport.owner()::passOn;
+        }
+
+        @Override
+        public OutputHandler out() {
+            return owner()::send;
+        }
+
+    }
+}
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java
index 19f46c657..30cb53ef1 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java
@@ -247,7 +247,11 @@ public class KexFilter extends IoFilter {
         this.compression = Objects.requireNonNull(compression);
         this.signals = Objects.requireNonNull(listener);
         this.proposer = Objects.requireNonNull(proposer);
-        this.hostKeyChecker = Objects.requireNonNull(checker);
+        if (!session.isServerSession()) { // servers may pass a null checker
+            Objects.requireNonNull(checker, "checker (HostKeyChecker) is required for client sessions");
+        }
+        this.hostKeyChecker = checker;
+
         rekeyAfterBytes = CoreModuleProperties.REKEY_BYTES_LIMIT.getRequired(session);
         rekeyAfterPackets = CoreModuleProperties.REKEY_PACKETS_LIMIT.getRequired(session);
         rekeyAfterBlocks = rekeyAfterBytes / 16; // Initial setting, will be updated once we know the cipher
@@ -1204,7 +1208,9 @@ public class KexFilter extends IoFilter {
             }
             if (kex.next(cmd, message)) {
                 // We're done
-                hostKeyChecker.check();
+                if (hostKeyChecker != null) { // can only be null on a server session
+                    hostKeyChecker.check();
+                }
                 prepareNewSettings();
                 lastKexEnd.set(Instant.now());
                 sendNewKeys();
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
index b3f860a3b..2639812b0 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
@@ -73,13 +73,10 @@ import org.apache.sshd.common.random.Random;
 import org.apache.sshd.common.session.ReservedSessionMessagesHandler;
 import org.apache.sshd.common.session.Session;
 import org.apache.sshd.common.session.SessionListener;
-import org.apache.sshd.common.session.filters.CompressionFilter;
 import org.apache.sshd.common.session.filters.CryptFilter;
-import org.apache.sshd.common.session.filters.DelayKexInitFilter;
-import org.apache.sshd.common.session.filters.IdentFilter;
-import org.apache.sshd.common.session.filters.InjectIgnoreFilter;
+import org.apache.sshd.common.session.filters.CryptFilter.EncryptionListener;
 import org.apache.sshd.common.session.filters.SshIdentHandler;
-import org.apache.sshd.common.session.filters.kex.KexFilter;
+import org.apache.sshd.common.session.filters.SshTransportFilter;
 import org.apache.sshd.common.session.filters.kex.KexListener;
 import org.apache.sshd.common.util.EventListenerUtils;
 import org.apache.sshd.common.util.ExceptionUtils;
@@ -179,9 +176,7 @@ public abstract class AbstractSession extends SessionHelper 
{
 
     private final FilterChain filters = new DefaultFilterChain();
 
-    private CryptFilter cryptFilter;
-    private CompressionFilter compressionFilter;
-    private KexFilter kexFilter;
+    private SshTransportFilter sshTransport; // created in setupFilterChain()
 
     /**
      * Create a new session.
@@ -267,9 +262,7 @@ public abstract class AbstractSession extends SessionHelper 
{
     }
 
     protected void setupFilterChain() {
-        IdentFilter ident = new IdentFilter();
-        ident.setPropertyResolver(this);
-        ident.setIdentHandler(new SshIdentHandler() {
+        SshIdentHandler identities = new SshIdentHandler() {
 
             @Override
             public boolean isServer() {
@@ -317,32 +310,8 @@ public abstract class AbstractSession extends 
SessionHelper {
                 }
                 return lines;
             }
-        });
-        filters.addLast(ident);
-
-        cryptFilter = new CryptFilter();
-        cryptFilter.setSession(this);
-        cryptFilter.setRandom(random);
-        cryptFilter.addEncryptionListener((buffer, sequenceNumber) -> {
-            // SSHD-968 - remember global request outgoing sequence number
-            LongConsumer setter = globalSequenceNumbers.remove(buffer);
-            if (setter != null) {
-                setter.accept(sequenceNumber);
-            }
-        });
-        filters.addLast(cryptFilter);
-
-        compressionFilter = new CompressionFilter();
-        compressionFilter.setSession(this);
-        filters.addLast(compressionFilter);
-
-        DelayKexInitFilter delayKexFilter = new DelayKexInitFilter();
-        delayKexFilter.setSession(this);
-        filters.addLast(delayKexFilter);
-
-        filters.addLast(new InjectIgnoreFilter(this, random));
-
-        kexFilter = new KexFilter(this, random, cryptFilter, 
compressionFilter, new SessionListener() {
+        };
+        SessionListener sessionEvents = new SessionListener() {
 
             @Override
             public void sessionNegotiationStart(
@@ -369,16 +338,17 @@ public abstract class AbstractSession extends 
SessionHelper {
                     throw new RuntimeSshException(e.getMessage(), e);
                 }
             }
-        }, this::getKexProposal, this::checkKeys);
-        filters.addLast(kexFilter);
-
-        ident.addIdentListener((peer, id) -> {
-            if (peer == isServerSession()) {
-                kexFilter.setClientIdent(id);
-            } else {
-                kexFilter.setServerIdent(id);
+        };
+        EncryptionListener sequenceListener = (buffer, sequenceNumber) -> {
+            // SSHD-968 - remember global request outgoing sequence number
+            LongConsumer setter = globalSequenceNumbers.remove(buffer);
+            if (setter != null) {
+                setter.accept(sequenceNumber);
             }
-        });
+        };
+        sshTransport = new SshTransportFilter(this, random, identities, 
sessionEvents, sequenceListener,
+                this::getKexProposal, this::checkKeys);
+        filters.addLast(sshTransport);
     }
 
     @Override
@@ -386,24 +356,24 @@ public abstract class AbstractSession extends SessionHelper {
         return filters;
     }
 
-    protected boolean isConnectionSecure() {
-        return cryptFilter.isSecure();
+    protected SshTransportFilter getTransport() {
+        return sshTransport;
     }
 
-    protected CompressionFilter getCompressionFilter() {
-        return compressionFilter;
+    protected boolean isConnectionSecure() {
+        return sshTransport.isSecure();
     }
 
     public void addKexListener(KexListener listener) {
-        kexFilter.addKexListener(listener);
+        sshTransport.addKexListener(listener);
     }
 
     public void removeKexListener(KexListener listener) {
-        kexFilter.addKexListener(listener);
+        sshTransport.removeKexListener(listener);
     }
 
     protected void initializeKeyExchangePhase() throws Exception {
-        KeyExchangeFuture future = kexFilter.startKex();
+        KeyExchangeFuture future = sshTransport.startKex();
         Throwable t = future.getException();
         if (t != null) {
             if (t instanceof Exception) {
@@ -415,7 +385,7 @@ public abstract class AbstractSession extends SessionHelper 
{
     }
 
     protected boolean isStrictKex() {
-        return kexFilter.isStrictKex();
+        return sshTransport.isStrictKex();
     }
 
     /**
@@ -438,7 +408,7 @@ public abstract class AbstractSession extends SessionHelper 
{
 
     @Override
     public Map<KexProposalOption, String> getServerKexProposals() {
-        return kexFilter.getServerProposal();
+        return sshTransport.getServerProposal();
     }
 
     @Override
@@ -448,42 +418,42 @@ public abstract class AbstractSession extends 
SessionHelper {
 
     @Override
     public Map<KexProposalOption, String> getClientKexProposals() {
-        return kexFilter.getClientProposal();
+        return sshTransport.getClientProposal();
     }
 
     @Override
     public KexState getKexState() {
-        return kexFilter.getKexState().get();
+        return sshTransport.getKexState().get();
     }
 
     @Override
     public byte[] getSessionId() {
-        return kexFilter.getSessionId();
+        return sshTransport.getSessionId();
     }
 
     @Override
     public Map<KexProposalOption, String> getKexNegotiationResult() {
-        return kexFilter.getNegotiated();
+        return sshTransport.getNegotiated();
     }
 
     @Override
     public String getNegotiatedKexParameter(KexProposalOption paramType) {
-        return kexFilter.getNegotiated().get(paramType);
+        return sshTransport.getNegotiated().get(paramType);
     }
 
     @Override
     public CipherInformation getCipherInformation(boolean incoming) {
-        return incoming ? cryptFilter.getInputSettings().getCipher() : 
cryptFilter.getOutputSettings().getCipher();
+        return sshTransport.getCipherInformation(incoming);
     }
 
     @Override
     public CompressionInformation getCompressionInformation(boolean incoming) {
-        return incoming ? compressionFilter.getInputCompression() : 
compressionFilter.getOutputCompression();
+        return sshTransport.getCompressionInformation(incoming);
     }
 
     @Override
     public MacInformation getMacInformation(boolean incoming) {
-        return incoming ? cryptFilter.getInputSettings().getMac() : 
cryptFilter.getOutputSettings().getMac();
+        return sshTransport.getMacInformation(incoming);
     }
 
     /**
@@ -498,7 +468,7 @@ public abstract class AbstractSession extends SessionHelper 
{
     protected void handleMessage(Buffer buffer) throws Exception {
         int cmd = buffer.getUByte();
         if (log.isDebugEnabled()) {
-            log.debug("doHandleMessage({}) process #{} {}", this, 
cryptFilter.getInputSequenceNumber() - 1,
+            log.debug("doHandleMessage({}) process #{} {}", this, 
sshTransport.getInputSequenceNumber() - 1,
                     SshConstants.getCommandMessageName(cmd));
         }
 
@@ -622,8 +592,8 @@ public abstract class AbstractSession extends SessionHelper 
{
 
     @Override
     protected void preClose() {
-        if (kexFilter != null) {
-            kexFilter.shutdown();
+        if (sshTransport != null) {
+            sshTransport.shutdown();
         }
 
         // if anyone waiting for global response notify them about the closing 
session
@@ -913,13 +883,7 @@ public abstract class AbstractSession extends 
SessionHelper {
         // Since the caller claims to know how many bytes they will need
         // increase their request to account for our headers/footers if
         // they actually send exactly this amount.
-        int finalLength;
-        if (cryptFilter != null) {
-            finalLength = cryptFilter.precomputeBufferLength(len);
-        } else {
-            // Can occur in some tests
-            finalLength = len + SshConstants.SSH_PACKET_HEADER_LEN + 255 + 32;
-        }
+        int finalLength = len + SshConstants.SSH_PACKET_HEADER_LEN + 
CryptFilter.MAX_PADDING + CryptFilter.MAX_TAG_LENGTH; // fixed upper bound: header + max padding + max tag
         return prepareBuffer(cmd, new PacketBuffer(new byte[finalLength], 
false));
     }
 
@@ -974,7 +938,7 @@ public abstract class AbstractSession extends SessionHelper 
{
             return null;
         }
 
-        int seq = cryptFilter.getInputSequenceNumber() - 1;
+        int seq = sshTransport.getInputSequenceNumber() - 1;
         return sendNotImplemented(seq & 0xFFFF_FFFFL);
     }
 
@@ -1153,7 +1117,7 @@ public abstract class AbstractSession extends 
SessionHelper {
     @Override
     public KeyExchangeFuture reExchangeKeys() throws IOException {
         try {
-            return kexFilter.startKex();
+            return sshTransport.startKex();
         } catch (GeneralSecurityException e) {
             debug("reExchangeKeys({}) failed ({}) to request new keys: {}",
                     this, e.getClass().getSimpleName(), e.getMessage(), e);
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/server/session/AbstractServerSession.java
 
b/sshd-core/src/main/java/org/apache/sshd/server/session/AbstractServerSession.java
index 0a50314ac..5500df52b 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/server/session/AbstractServerSession.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/server/session/AbstractServerSession.java
@@ -301,9 +301,9 @@ public abstract class AbstractServerSession extends 
AbstractSession implements S
         IoSession networkSession = getIoSession();
         setUsername(username);
         setAuthenticated();
-        getCompressionFilter().enableInput();
+        getTransport().enableInputCompression(); // output compression is enabled below, from the response's write-future listener
         startService(authService, buffer);
-        IoWriteFuture future = writePacket(response).addListener(f -> 
getCompressionFilter().enableOutput());
+        IoWriteFuture future = writePacket(response).addListener(f -> 
getTransport().enableOutputCompression());
 
         resetIdleTimeout();
         log.info("Session {}@{} authenticated", username, 
networkSession.getRemoteAddress());

Reply via email to