This is an automated email from the ASF dual-hosted git repository.

twolf pushed a commit to branch dev_3.0
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit 6fc3dd821b953b60bca3ee4322cb8b941976624c
Author: Thomas Wolf <tw...@apache.org>
AuthorDate: Sat Apr 5 11:30:08 2025 +0200

    Support for customizing the filter chain
    
    Simplify how SessionListeners are called, and add a new listener callback
    sessionStarting(). A session listener thus has two ways to customize the
    default filter chain: either define the whole chain in sessionCreated(),
    or add filters to the default chain in sessionStarting().
    
    Add tests for client and server verifying that a custom filter added in
    sessionStarting() is called.
---
 .../sshd/common/session/SessionListener.java       |  17 +-
 .../common/session/helpers/AbstractSession.java    |   1 +
 .../sshd/common/session/helpers/SessionHelper.java | 387 ++++-----------------
 .../java/org/apache/sshd/client/ClientTest.java    |  50 +++
 .../java/org/apache/sshd/server/ServerTest.java    |  46 +++
 5 files changed, 179 insertions(+), 322 deletions(-)

diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/SessionListener.java b/sshd-core/src/main/java/org/apache/sshd/common/session/SessionListener.java
index 74780b996..e652aaa33 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/session/SessionListener.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/SessionListener.java
@@ -63,7 +63,8 @@ public interface SessionListener extends SshdEventListener {
     }
 
     /**
-     * A new session just been created
+     * A new session has just been created. The event is emitted before the session is started. The session's filter
+     * chain is not yet set up.
      *
      * @param session The created {@link Session}
      */
@@ -71,6 +72,20 @@ public interface SessionListener extends SshdEventListener {
         // ignored
     }
 
+    /**
+     * A new session is about to start. The session's filter chain is set up, and it will start the SSH protocol next
+     * by sending its SSH protocol identification. This is the listener's last chance to modify the filter chain.
+     *
+     * <p>
+     * This event could be used, for instance, to insert a proxy filter at the front of the filter chain.
+     * </p>
+     *
+     * @param session the starting {@link Session}
+     */
+    default void sessionStarting(Session session) {
+        // empty
+    }
+
     /**
      * About to send identification to peer
      *
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
index 8f6bf4bf5..f3a33ef7d 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
@@ -223,6 +223,7 @@ public abstract class AbstractSession extends SessionHelper 
{
         if (filters.isEmpty()) {
             setupFilterChain();
         }
+        signalSessionStarting();
 
         IoFilter ioSessionConnector = new IoFilter() {
 
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java
index cf69e94ab..2b8e73467 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java
@@ -36,6 +36,7 @@ import java.util.NavigableSet;
 import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
 import java.util.function.Function;
 
 import org.apache.sshd.common.AttributeRepository;
@@ -47,14 +48,12 @@ import org.apache.sshd.common.SshConstants;
 import org.apache.sshd.common.SshException;
 import org.apache.sshd.common.channel.throttle.ChannelStreamWriterResolver;
 import 
org.apache.sshd.common.channel.throttle.ChannelStreamWriterResolverManager;
-import org.apache.sshd.common.digest.Digest;
 import org.apache.sshd.common.forward.Forwarder;
 import org.apache.sshd.common.io.IoSession;
 import org.apache.sshd.common.io.IoWriteFuture;
 import org.apache.sshd.common.kex.AbstractKexFactoryManager;
 import org.apache.sshd.common.kex.KexProposalOption;
 import org.apache.sshd.common.kex.extension.KexExtensionHandler;
-import org.apache.sshd.common.random.Random;
 import org.apache.sshd.common.session.ConnectionService;
 import org.apache.sshd.common.session.ReservedSessionMessagesHandler;
 import org.apache.sshd.common.session.Session;
@@ -68,9 +67,6 @@ import org.apache.sshd.common.util.GenericUtils;
 import org.apache.sshd.common.util.MapEntryUtils;
 import org.apache.sshd.common.util.ValidateUtils;
 import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
-import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
-import org.apache.sshd.common.util.io.functors.Invoker;
 import org.apache.sshd.common.util.net.SshdSocketAddress;
 import org.apache.sshd.core.CoreModuleProperties;
 
@@ -554,131 +550,44 @@ public abstract class SessionHelper extends 
AbstractKexFactoryManager implements
     }
 
     protected void signalSessionEstablished(IoSession ioSession) throws 
Exception {
-        try {
-            invokeSessionSignaller(l -> {
-                signalSessionEstablished(l);
-                return null;
-            });
-        } catch (Throwable err) {
-            Throwable e = ExceptionUtils.peelException(err);
-            debug("Failed ({}) to announce session={} established: {}",
-                    e.getClass().getSimpleName(), ioSession, e.getMessage(), 
e);
-            if (e instanceof Exception) {
-                throw (Exception) e;
-            } else {
-                throw new RuntimeSshException(e);
-            }
-        }
-    }
-
-    protected void signalSessionEstablished(SessionListener listener) {
-        if (listener == null) {
-            return;
-        }
-        listener.sessionEstablished(this);
+        callSignaller("established", l -> l.sessionEstablished(this));
     }
 
     protected void signalSessionCreated(IoSession ioSession) throws Exception {
-        try {
-            invokeSessionSignaller(l -> {
-                signalSessionCreated(l);
-                return null;
-            });
-        } catch (Throwable err) {
-            Throwable e = ExceptionUtils.peelException(err);
-            debug("Failed ({}) to announce session={} created: {}",
-                    e.getClass().getSimpleName(), ioSession, e.getMessage(), 
e);
-            if (e instanceof Exception) {
-                throw (Exception) e;
-            } else {
-                throw new RuntimeSshException(e);
-            }
-        }
+        callSignaller("created", l -> l.sessionCreated(this));
     }
 
-    protected void signalSessionCreated(SessionListener listener) {
-        if (listener == null) {
-            return;
-        }
-        listener.sessionCreated(this);
+    protected void signalSessionStarting() throws Exception {
+        callSignaller("starting", l -> l.sessionStarting(this));
     }
 
     protected void signalSendIdentification(String version, List<String> 
extraLines) throws Exception {
-        try {
-            invokeSessionSignaller(l -> {
-                signalSendIdentification(l, version, extraLines);
-                return null;
-            });
-        } catch (Throwable err) {
-            Throwable e = ExceptionUtils.peelException(err);
-            if (e instanceof Exception) {
-                throw (Exception) e;
-            } else {
-                throw new RuntimeSshException(e);
-            }
-        }
+        callSignaller("send identification", l -> 
l.sessionPeerIdentificationSend(this, version, extraLines));
     }
 
-    protected void signalSendIdentification(SessionListener listener, String 
version, List<String> extraLines) {
-        if (listener == null) {
-            return;
-        }
-
-        listener.sessionPeerIdentificationSend(this, version, extraLines);
+    protected void signalReadPeerIdentificationLine(String version, 
List<String> extraLines) throws Exception {
+        callSignaller("read peer identification line", l -> 
l.sessionPeerIdentificationLine(this, version, extraLines));
     }
 
-    protected void signalReadPeerIdentificationLine(String line, List<String> 
extraLines) throws Exception {
-        try {
-            invokeSessionSignaller(l -> {
-                signalReadPeerIdentificationLine(l, line, extraLines);
-                return null;
-            });
-        } catch (Throwable err) {
-            Throwable e = ExceptionUtils.peelException(err);
-            debug("signalReadPeerIdentificationLine({}) Failed ({}) to 
announce peer={}: {}",
-                    this, e.getClass().getSimpleName(), line, e.getMessage(), 
e);
-            if (e instanceof Exception) {
-                throw (Exception) e;
-            } else {
-                throw new RuntimeSshException(e);
-            }
-        }
+    protected void signalPeerIdentificationReceived(String version, 
List<String> extraLines) throws Exception {
+        callSignaller("receive peer identification version",
+                l -> l.sessionPeerIdentificationReceived(this, version, 
extraLines));
     }
 
-    protected void signalReadPeerIdentificationLine(
-            SessionListener listener, String version, List<String> extraLines) 
{
-        if (listener == null) {
-            return;
-        }
-
-        listener.sessionPeerIdentificationLine(this, version, extraLines);
+    protected void signalNegotiationOptionsCreated(Map<KexProposalOption, 
String> proposal) {
+        callSignaller(l -> l.sessionNegotiationOptionsCreated(this, proposal));
     }
 
-    protected void signalPeerIdentificationReceived(String version, 
List<String> extraLines) throws Exception {
-        try {
-            invokeSessionSignaller(l -> {
-                signalPeerIdentificationReceived(l, version, extraLines);
-                return null;
-            });
-        } catch (Throwable err) {
-            Throwable e = ExceptionUtils.peelException(err);
-            debug("signalPeerIdentificationReceived({}) Failed ({}) to 
announce peer={}: {}",
-                    this, e.getClass().getSimpleName(), version, 
e.getMessage(), e);
-            if (e instanceof Exception) {
-                throw (Exception) e;
-            } else {
-                throw new RuntimeSshException(e);
-            }
-        }
+    protected void signalNegotiationStart(
+            Map<KexProposalOption, String> c2sOptions,
+            Map<KexProposalOption, String> s2cOptions) {
+        callSignaller(l -> l.sessionNegotiationStart(this, c2sOptions, 
s2cOptions));
     }
 
-    protected void signalPeerIdentificationReceived(
-            SessionListener listener, String version, List<String> extraLines) 
{
-        if (listener == null) {
-            return;
-        }
-
-        listener.sessionPeerIdentificationReceived(this, version, extraLines);
+    protected void signalNegotiationEnd(
+            Map<KexProposalOption, String> c2sOptions, Map<KexProposalOption, 
String> s2cOptions,
+            Map<KexProposalOption, String> negotiatedGuess, Throwable reason) {
+        callSignaller(l -> l.sessionNegotiationEnd(this, c2sOptions, 
s2cOptions, negotiatedGuess, reason));
     }
 
     /**
@@ -689,10 +598,7 @@ public abstract class SessionHelper extends 
AbstractKexFactoryManager implements
      */
     protected void signalSessionEvent(SessionListener.Event event) throws 
Exception {
         try {
-            invokeSessionSignaller(l -> {
-                signalSessionEvent(l, event);
-                return null;
-            });
+            invokeSessionSignaller(l -> l.sessionEvent(this, event));
         } catch (Throwable err) {
             Throwable t = ExceptionUtils.peelException(err);
             debug("sendSessionEvent({})[{}] failed ({}) to inform listeners: 
{}",
@@ -705,15 +611,19 @@ public abstract class SessionHelper extends 
AbstractKexFactoryManager implements
         }
     }
 
-    protected void signalSessionEvent(SessionListener listener, 
SessionListener.Event event) throws IOException {
-        if (listener == null) {
-            return;
-        }
+    protected void signalDisconnect(int code, String msg, String lang, boolean 
initiator) {
+        callSignallerSilently("signalDisconnect", l -> 
l.sessionDisconnect(this, code, msg, lang, initiator));
+    }
 
-        listener.sessionEvent(this, event);
+    protected void signalExceptionCaught(Throwable t) {
+        callSignallerSilently("signalExceptionCaught", l -> 
l.sessionException(this, t));
     }
 
-    protected void invokeSessionSignaller(Invoker<SessionListener, Void> 
invoker) throws Throwable {
+    protected void signalSessionClosed() {
+        callSignallerSilently("signalSessionClosed", l -> 
l.sessionClosed(this));
+    }
+
+    protected void invokeSessionSignaller(Consumer<SessionListener> listener) 
throws Throwable {
         FactoryManager manager = getFactoryManager();
         SessionListener[] listeners = {
                 (manager == null) ? null : manager.getSessionListenerProxy(),
@@ -727,7 +637,7 @@ public abstract class SessionHelper extends 
AbstractKexFactoryManager implements
             }
 
             try {
-                invoker.invoke(l);
+                listener.accept(l);
             } catch (Throwable t) {
                 err = ExceptionUtils.accumulateException(err, t);
             }
@@ -738,39 +648,42 @@ public abstract class SessionHelper extends 
AbstractKexFactoryManager implements
         }
     }
 
-    /**
-     * Method used while putting new keys into use that will resize the key 
used to initialize the cipher to the needed
-     * length.
-     *
-     * @param  e         the key to resize
-     * @param  kdfSize   the cipher key-derivation-factor (in bytes)
-     * @param  hash      the hash algorithm
-     * @param  k         the key exchange k parameter
-     * @param  h         the key exchange h parameter
-     * @return           the resized key
-     * @throws Exception if a problem occur while resizing the key
-     */
-    protected byte[] resizeKey(
-            byte[] e, int kdfSize, Digest hash, byte[] k, byte[] h)
-            throws Exception {
-        for (Buffer buffer = null; kdfSize > e.length; buffer = 
BufferUtils.clear(buffer)) {
-            if (buffer == null) {
-                buffer = new ByteArrayBuffer();
-            }
-
-            buffer.putBytes(k);
-            buffer.putRawBytes(h);
-            buffer.putRawBytes(e);
-            hash.update(buffer.array(), 0, buffer.available());
+    private void callSignallerSilently(String msg, Consumer<SessionListener> 
listener) {
+        try {
+            invokeSessionSignaller(listener);
+        } catch (Throwable t) {
+            Throwable e = ExceptionUtils.peelException(t);
+            debug("{}({}) {} while signal session closed: {}", msg, this, 
e.getClass().getSimpleName(), e.getMessage(), e);
+        }
+    }
 
-            byte[] foo = hash.digest();
-            byte[] bar = new byte[e.length + foo.length];
-            System.arraycopy(e, 0, bar, 0, e.length);
-            System.arraycopy(foo, 0, bar, e.length, foo.length);
-            e = bar;
+    private void callSignaller(Consumer<SessionListener> listener) {
+        try {
+            invokeSessionSignaller(listener);
+        } catch (Throwable t) {
+            Throwable err = ExceptionUtils.peelException(t);
+            if (err instanceof RuntimeException) {
+                throw (RuntimeException) err;
+            } else if (err instanceof Error) {
+                throw (Error) err;
+            } else {
+                throw new IllegalArgumentException(err);
+            }
         }
+    }
 
-        return e;
+    private void callSignaller(String msg, Consumer<SessionListener> listener) 
throws Exception {
+        try {
+            invokeSessionSignaller(listener);
+        } catch (Throwable err) {
+            Throwable e = ExceptionUtils.peelException(err);
+            debug("Failed ({}) to announce session={} {}: {}", 
e.getClass().getSimpleName(), ioSession, msg, e.getMessage(), e);
+            if (e instanceof Exception) {
+                throw (Exception) e;
+            } else {
+                throw new RuntimeSshException(e);
+            }
+        }
     }
 
     /**
@@ -786,24 +699,6 @@ public abstract class SessionHelper extends 
AbstractKexFactoryManager implements
         return (s == null) ? null : s.getRemoteAddress();
     }
 
-    protected long calculateNextIgnorePacketCount(Random r, long freq, int 
variance) {
-        if ((freq <= 0L) || (variance < 0)) {
-            return -1L;
-        }
-
-        if (variance == 0) {
-            return freq;
-        }
-
-        int extra = r.random((variance < 0) ? (0 - variance) : variance);
-        long count = (variance < 0) ? (freq - extra) : (freq + extra);
-        if (log.isTraceEnabled()) {
-            log.trace("calculateNextIgnorePacketCount({}) count={}", this, 
count);
-        }
-
-        return count;
-    }
-
     /**
      * Resolves the identification to send to the peer session by consulting 
the associated {@link FactoryManager}. If a
      * value is set, then it is <U>appended</U> to the standard {@link 
SessionContext#DEFAULT_SSH_VERSION_PREFIX}.
@@ -1027,91 +922,6 @@ public abstract class SessionHelper extends 
AbstractKexFactoryManager implements
         return proposal;
     }
 
-    protected void signalNegotiationOptionsCreated(Map<KexProposalOption, 
String> proposal) {
-        try {
-            invokeSessionSignaller(l -> {
-                signalNegotiationOptionsCreated(l, proposal);
-                return null;
-            });
-        } catch (Throwable t) {
-            Throwable err = ExceptionUtils.peelException(t);
-            if (err instanceof RuntimeException) {
-                throw (RuntimeException) err;
-            } else if (err instanceof Error) {
-                throw (Error) err;
-            } else {
-                throw new IllegalArgumentException(err);
-            }
-        }
-    }
-
-    protected void signalNegotiationOptionsCreated(SessionListener listener, 
Map<KexProposalOption, String> proposal) {
-        if (listener == null) {
-            return;
-        }
-
-        listener.sessionNegotiationOptionsCreated(this, proposal);
-    }
-
-    protected void signalNegotiationStart(
-            Map<KexProposalOption, String> c2sOptions, Map<KexProposalOption, 
String> s2cOptions) {
-        try {
-            invokeSessionSignaller(l -> {
-                signalNegotiationStart(l, c2sOptions, s2cOptions);
-                return null;
-            });
-        } catch (Throwable t) {
-            Throwable err = ExceptionUtils.peelException(t);
-            if (err instanceof RuntimeException) {
-                throw (RuntimeException) err;
-            } else if (err instanceof Error) {
-                throw (Error) err;
-            } else {
-                throw new IllegalArgumentException(err);
-            }
-        }
-    }
-
-    protected void signalNegotiationStart(
-            SessionListener listener, Map<KexProposalOption, String> 
c2sOptions, Map<KexProposalOption, String> s2cOptions) {
-        if (listener == null) {
-            return;
-        }
-
-        listener.sessionNegotiationStart(this, c2sOptions, s2cOptions);
-    }
-
-    protected void signalNegotiationEnd(
-            Map<KexProposalOption, String> c2sOptions, Map<KexProposalOption, 
String> s2cOptions,
-            Map<KexProposalOption, String> negotiatedGuess, Throwable reason) {
-        try {
-            invokeSessionSignaller(l -> {
-                signalNegotiationEnd(l, c2sOptions, s2cOptions, 
negotiatedGuess, reason);
-                return null;
-            });
-        } catch (Throwable t) {
-            Throwable err = ExceptionUtils.peelException(t);
-            if (err instanceof RuntimeException) {
-                throw (RuntimeException) err;
-            } else if (err instanceof Error) {
-                throw (Error) err;
-            } else {
-                throw new IllegalArgumentException(err);
-            }
-        }
-    }
-
-    protected void signalNegotiationEnd(
-            SessionListener listener,
-            Map<KexProposalOption, String> c2sOptions, Map<KexProposalOption, 
String> s2cOptions,
-            Map<KexProposalOption, String> negotiatedGuess, Throwable reason) {
-        if (listener == null) {
-            return;
-        }
-
-        listener.sessionNegotiationEnd(this, c2sOptions, s2cOptions, 
negotiatedGuess, reason);
-    }
-
     @Override
     public void disconnect(int reason, String msg) throws IOException {
         log.info("Disconnecting({}): {} - {}",
@@ -1185,28 +995,6 @@ public abstract class SessionHelper extends 
AbstractKexFactoryManager implements
         close(true);
     }
 
-    protected void signalDisconnect(int code, String msg, String lang, boolean 
initiator) {
-        try {
-            invokeSessionSignaller(l -> {
-                signalDisconnect(l, code, msg, lang, initiator);
-                return null;
-            });
-        } catch (Throwable err) {
-            Throwable e = ExceptionUtils.peelException(err);
-            debug("signalDisconnect({}) {}: {}",
-                    this, e.getClass().getSimpleName(), e.getMessage(), e);
-        }
-    }
-
-    protected void signalDisconnect(
-            SessionListener listener, int code, String msg, String lang, 
boolean initiator) {
-        if (listener == null) {
-            return;
-        }
-
-        listener.sessionDisconnect(this, code, msg, lang, initiator);
-    }
-
     /**
      * Handle any exceptions that occurred on this session. The session will 
be closed and a disconnect packet will be
      * sent before if the given exception is an {@link SshException}.
@@ -1245,49 +1033,6 @@ public abstract class SessionHelper extends 
AbstractKexFactoryManager implements
         close(true);
     }
 
-    protected void signalExceptionCaught(Throwable t) {
-        try {
-            invokeSessionSignaller(l -> {
-                signalExceptionCaught(l, t);
-                return null;
-            });
-        } catch (Throwable err) {
-            Throwable e = ExceptionUtils.peelException(err);
-            debug("signalExceptionCaught({}) {}: {}",
-                    this, e.getClass().getSimpleName(), e.getMessage(), e);
-        }
-    }
-
-    protected void signalExceptionCaught(SessionListener listener, Throwable 
t) {
-        if (listener == null) {
-            return;
-        }
-
-        listener.sessionException(this, t);
-    }
-
-    protected void signalSessionClosed() {
-        try {
-            invokeSessionSignaller(l -> {
-                signalSessionClosed(l);
-                return null;
-            });
-        } catch (Throwable err) {
-            Throwable e = ExceptionUtils.peelException(err);
-            debug("signalSessionClosed({}) {} while signal session closed: {}",
-                    this, e.getClass().getSimpleName(), e.getMessage(), e);
-            // Do not re-throw since session closed anyway
-        }
-    }
-
-    protected void signalSessionClosed(SessionListener listener) {
-        if (listener == null) {
-            return;
-        }
-
-        listener.sessionClosed(this);
-    }
-
     protected abstract ConnectionService getConnectionService();
 
     protected Forwarder getForwarder() {
diff --git a/sshd-core/src/test/java/org/apache/sshd/client/ClientTest.java b/sshd-core/src/test/java/org/apache/sshd/client/ClientTest.java
index 496fe03a4..f82eb40c4 100644
--- a/sshd-core/src/test/java/org/apache/sshd/client/ClientTest.java
+++ b/sshd-core/src/test/java/org/apache/sshd/client/ClientTest.java
@@ -88,6 +88,10 @@ import 
org.apache.sshd.common.channel.exception.SshChannelClosedException;
 import org.apache.sshd.common.config.keys.KeyUtils;
 import 
org.apache.sshd.common.config.keys.writer.openssh.OpenSSHKeyEncryptionContext;
 import 
org.apache.sshd.common.config.keys.writer.openssh.OpenSSHKeyPairResourceWriter;
+import org.apache.sshd.common.filter.FilterChain;
+import org.apache.sshd.common.filter.InputHandler;
+import org.apache.sshd.common.filter.IoFilter;
+import org.apache.sshd.common.filter.OutputHandler;
 import org.apache.sshd.common.future.CancelFuture;
 import org.apache.sshd.common.future.CancelOption;
 import org.apache.sshd.common.future.CloseFuture;
@@ -1893,6 +1897,52 @@ public class ClientTest extends BaseTestSupport {
         }
     }
 
+    @Test
+    void customFilter() throws IOException {
+        
client.setUserAuthFactories(Collections.singletonList(UserAuthPasswordFactory.INSTANCE));
+        AtomicBoolean outCalled = new AtomicBoolean();
+        AtomicBoolean inCalled = new AtomicBoolean();
+        client.addSessionListener(new SessionListener() {
+
+            @Override
+            public void sessionStarting(Session session) {
+                FilterChain filters = session.getFilterChain();
+                filters.addFirst(new IoFilter() {
+
+                    @Override
+                    public OutputHandler out() {
+                        return (cmd, msg) -> {
+                            outCalled.set(true);
+                            return owner().send(cmd, msg);
+                        };
+                    }
+
+                    @Override
+                    public InputHandler in() {
+                        return msg -> {
+                            inCalled.set(true);
+                            owner().passOn(msg);
+                        };
+                    }
+                });
+            }
+        });
+        client.start();
+
+        try (ClientSession session = client.connect(getCurrentTestName(), 
TEST_LOCALHOST, port).verify(CONNECT_TIMEOUT)
+                .getSession()) {
+            assertNotNull(clientSessionHolder.get(), "Client session creation 
not signalled");
+            session.addPasswordIdentity(getClass().getSimpleName());
+            session.addPasswordIdentity(getCurrentTestName());
+            session.auth().verify(AUTH_TIMEOUT);
+        } finally {
+            client.stop();
+        }
+        assertNull(clientSessionHolder.get(), "Session closure not signalled");
+        assertTrue(inCalled.get(), "Custom filter IN should have been called");
+        assertTrue(outCalled.get(), "Custom filter OUT should have been 
called");
+    }
+
     @Test
     @Disabled
     void connectUsingIPv6Address() throws IOException {
diff --git a/sshd-core/src/test/java/org/apache/sshd/server/ServerTest.java b/sshd-core/src/test/java/org/apache/sshd/server/ServerTest.java
index 39ac3f9de..3b9bb89d2 100644
--- a/sshd-core/src/test/java/org/apache/sshd/server/ServerTest.java
+++ b/sshd-core/src/test/java/org/apache/sshd/server/ServerTest.java
@@ -36,6 +36,7 @@ import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -52,6 +53,10 @@ import org.apache.sshd.common.channel.Channel;
 import org.apache.sshd.common.channel.ChannelListener;
 import org.apache.sshd.common.channel.RemoteWindow;
 import org.apache.sshd.common.channel.WindowClosedException;
+import org.apache.sshd.common.filter.FilterChain;
+import org.apache.sshd.common.filter.InputHandler;
+import org.apache.sshd.common.filter.IoFilter;
+import org.apache.sshd.common.filter.OutputHandler;
 import org.apache.sshd.common.io.IoSession;
 import org.apache.sshd.common.kex.KexProposalOption;
 import org.apache.sshd.common.session.ReservedSessionMessagesHandler;
@@ -556,6 +561,47 @@ public class ServerTest extends BaseTestSupport {
         }
     }
 
+    @Test
+    void customFilter() throws Exception {
+        AtomicBoolean outCalled = new AtomicBoolean();
+        AtomicBoolean inCalled = new AtomicBoolean();
+        sshd.addSessionListener(new SessionListener() {
+
+            @Override
+            public void sessionStarting(Session session) {
+                FilterChain filters = session.getFilterChain();
+                filters.addFirst(new IoFilter() {
+
+                    @Override
+                    public OutputHandler out() {
+                        return (cmd, msg) -> {
+                            outCalled.set(true);
+                            return owner().send(cmd, msg);
+                        };
+                    }
+
+                    @Override
+                    public InputHandler in() {
+                        return msg -> {
+                            inCalled.set(true);
+                            owner().passOn(msg);
+                        };
+                    }
+                });
+            }
+        });
+        sshd.start();
+
+        client.start();
+        try (ClientSession s = createTestClientSession(sshd)) {
+            s.close(false);
+        } finally {
+            client.stop();
+        }
+        assertTrue(inCalled.get(), "Custom filter IN should have been called");
+        assertTrue(outCalled.get(), "Custom filter OUT should have been 
called");
+    }
+
     // see SSHD-645
     @Test
     void channelStateChangeNotifications() throws Exception {

Reply via email to