This is an automated email from the ASF dual-hosted git repository. twolf pushed a commit to branch dev_3.0 in repository https://gitbox.apache.org/repos/asf/mina-sshd.git
commit 06d3813910e060cf725d231896aab0f87d1a22ae Author: Thomas Wolf <tw...@apache.org> AuthorDate: Wed Apr 2 20:58:29 2025 +0200 Introduce a compound SshTransportFilter Introduce a SshTransportFilter that encapsulates the RFC 4253 protocol up to and including the key exchange. --- .../sshd/client/session/ClientSessionImpl.java | 4 +- .../sshd/common/session/filters/CryptFilter.java | 18 +- .../common/session/filters/SshTransportFilter.java | 228 +++++++++++++++++++++ .../sshd/common/session/filters/kex/KexFilter.java | 10 +- .../common/session/helpers/AbstractSession.java | 114 ++++------- .../sshd/server/session/AbstractServerSession.java | 4 +- 6 files changed, 284 insertions(+), 94 deletions(-) diff --git a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java index cdc5d19ba..d713d386e 100644 --- a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java +++ b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java @@ -98,8 +98,8 @@ public class ClientSessionImpl extends AbstractClientSession { @Override public void setAuthenticated() throws IOException { - getCompressionFilter().enableInput(); - getCompressionFilter().enableOutput(); + getTransport().enableInputCompression(); + getTransport().enableOutputCompression(); super.setAuthenticated(); } diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/CryptFilter.java b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/CryptFilter.java index 7dd7404cc..af9dbfc55 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/CryptFilter.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/CryptFilter.java @@ -55,6 +55,11 @@ public class CryptFilter extends IoFilter implements CryptStatisticsProvider { */ public static final int MAX_PADDING = 127; + /** + * An arbitrary constant >= the largest 
authentication tag size we will ever have. + */ + public static final int MAX_TAG_LENGTH = 64; + private static final Logger LOG = LoggerFactory.getLogger(CryptFilter.class); // The minimum value for the packet length field of a valid SSH packet: @@ -167,19 +172,6 @@ public class CryptFilter extends IoFilter implements CryptStatisticsProvider { return decryption.get().isSecure() && encryption.get().isSecure(); } - /** - * Performs a best-effort precalculation of the needed packet buffer size assuming an a priori known packet length. - * This may help avoid needing to grow the buffer later on. - * - * @param knownPayloadLength expected payload length - * @return a buffer size that will be sufficient to hold the full SSH packet including header, - * padding, and MAC, if any. - */ - public int precomputeBufferLength(int knownPayloadLength) { - Settings out = getOutputSettings(); - return knownPayloadLength + SshConstants.SSH_PACKET_HEADER_LEN + MAX_PADDING + out.getTagSize(); - } - public void addEncryptionListener(EncryptionListener listener) { listeners.addIfAbsent(Objects.requireNonNull(listener)); } diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/SshTransportFilter.java b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/SshTransportFilter.java new file mode 100644 index 000000000..c0180b951 --- /dev/null +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/SshTransportFilter.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sshd.common.session.filters; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.sshd.common.cipher.CipherInformation; +import org.apache.sshd.common.compression.CompressionInformation; +import org.apache.sshd.common.filter.DefaultFilterChain; +import org.apache.sshd.common.filter.FilterChain; +import org.apache.sshd.common.filter.InputHandler; +import org.apache.sshd.common.filter.IoFilter; +import org.apache.sshd.common.filter.OutputHandler; +import org.apache.sshd.common.future.KeyExchangeFuture; +import org.apache.sshd.common.kex.KexProposalOption; +import org.apache.sshd.common.kex.KexState; +import org.apache.sshd.common.mac.MacInformation; +import org.apache.sshd.common.random.Random; +import org.apache.sshd.common.session.SessionListener; +import org.apache.sshd.common.session.filters.CryptFilter.EncryptionListener; +import org.apache.sshd.common.session.filters.kex.KexFilter; +import org.apache.sshd.common.session.filters.kex.KexFilter.HostKeyChecker; +import org.apache.sshd.common.session.filters.kex.KexFilter.Proposer; +import org.apache.sshd.common.session.filters.kex.KexListener; +import org.apache.sshd.common.session.helpers.AbstractSession; + +/** + * A filter encapsulating the basic SSH transport up to and including KEX. 
+ */ +public class SshTransportFilter extends IoFilter { + + private final FilterChain filters = new DefaultFilterChain(); + + private final CryptFilter cryptFilter; + private final CompressionFilter compressionFilter; + private final KexFilter kexFilter; + + /** + * Creates a new SSH transport filter. + * + * @param session {@link AbstractSession} this filter is for + * @param random {@link Random} instance to use + * @param identities {@link SshIdentHandler} for handling the SSH identification string + * @param events {@link SessionListener} to report some events + * @param cryptListener {@link EncryptionListener} called just before a buffer is encrypted + * @param proposer {@link Proposer} to get KEX proposals + * @param checker {@link HostKeyChecker} to check the peer's host key; may be {@code null} if on a server + */ + public SshTransportFilter(AbstractSession session, Random random, SshIdentHandler identities, SessionListener events, + EncryptionListener cryptListener, Proposer proposer, HostKeyChecker checker) { + IdentFilter ident = new IdentFilter(); + ident.setPropertyResolver(session); + ident.setIdentHandler(identities); + filters.addLast(ident); + + cryptFilter = new CryptFilter(); + cryptFilter.setSession(session); + cryptFilter.setRandom(random); + cryptFilter.addEncryptionListener(cryptListener); + filters.addLast(cryptFilter); + + compressionFilter = new CompressionFilter(); + compressionFilter.setSession(session); + filters.addLast(compressionFilter); + + DelayKexInitFilter delayKexFilter = new DelayKexInitFilter(); + delayKexFilter.setSession(session); + filters.addLast(delayKexFilter); + + filters.addLast(new InjectIgnoreFilter(session, random)); + + kexFilter = new KexFilter(session, random, cryptFilter, compressionFilter, events, proposer, checker); + filters.addLast(kexFilter); + + ident.addIdentListener((peer, id) -> { + if (peer == session.isServerSession()) { + kexFilter.setClientIdent(id); + } else { + kexFilter.setServerIdent(id); + } +
}); + filters.addFirst(new InConnector(this)); + filters.addLast(new OutConnector(this)); + } + + @Override + public InputHandler in() { + return filters.getFirst().in(); + } + + @Override + public OutputHandler out() { + return filters.getLast().out(); + } + + public KeyExchangeFuture startKex() throws Exception { + return kexFilter.startKex(); + } + + public void shutdown() { + kexFilter.shutdown(); + } + + public boolean isStrictKex() { + return kexFilter.isStrictKex(); + } + + public boolean isInitialKexDone() { + return kexFilter.isInitialKexDone(); + } + + public AtomicReference<KexState> getKexState() { + return kexFilter.getKexState(); + } + + public Map<KexProposalOption, String> getNegotiated() { + return kexFilter.getNegotiated(); + } + + public Map<KexProposalOption, String> getClientProposal() { + return kexFilter.getClientProposal(); + } + + public Map<KexProposalOption, String> getServerProposal() { + return kexFilter.getServerProposal(); + } + + public byte[] getSessionId() { + return kexFilter.getSessionId(); + } + + public void addKexListener(KexListener listener) { + kexFilter.addKexListener(listener); + } + + public void removeKexListener(KexListener listener) { + kexFilter.removeKexListener(listener); + } + + public boolean isSecure() { + return cryptFilter.isSecure(); + } + + public int getInputSequenceNumber() { + return cryptFilter.getInputSequenceNumber(); + } + + public int getOutputSequenceNumber() { + return cryptFilter.getOutputSequenceNumber(); + } + + public CipherInformation getCipherInformation(boolean incoming) { + return incoming ? cryptFilter.getInputSettings().getCipher() : cryptFilter.getOutputSettings().getCipher(); + } + + public MacInformation getMacInformation(boolean incoming) { + return incoming ? 
cryptFilter.getInputSettings().getMac() : cryptFilter.getOutputSettings().getMac(); + } + + public void enableInputCompression() { + compressionFilter.enableInput(); + } + + public void enableOutputCompression() { + compressionFilter.enableOutput(); + } + + public CompressionInformation getCompressionInformation(boolean incoming) { + return incoming ? compressionFilter.getInputCompression() : compressionFilter.getOutputCompression(); + } + + private static class InConnector extends IoFilter { + + private final SshTransportFilter transport; + + InConnector(SshTransportFilter transport) { + this.transport = transport; + } + + @Override + public InputHandler in() { + return owner()::passOn; + } + + @Override + public OutputHandler out() { + return transport.owner()::send; + } + + } + + private static class OutConnector extends IoFilter { + + private final SshTransportFilter transport; + + OutConnector(SshTransportFilter transport) { + this.transport = transport; + } + + @Override + public InputHandler in() { + return transport.owner()::passOn; + } + + @Override + public OutputHandler out() { + return owner()::send; + } + + } +} diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java index 19f46c657..30cb53ef1 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/filters/kex/KexFilter.java @@ -247,7 +247,11 @@ public class KexFilter extends IoFilter { this.compression = Objects.requireNonNull(compression); this.signals = Objects.requireNonNull(listener); this.proposer = Objects.requireNonNull(proposer); - this.hostKeyChecker = Objects.requireNonNull(checker); + if (!session.isServerSession()) { + Objects.requireNonNull(checker); + } + this.hostKeyChecker = checker; + rekeyAfterBytes = CoreModuleProperties.REKEY_BYTES_LIMIT.getRequired(session); 
rekeyAfterPackets = CoreModuleProperties.REKEY_PACKETS_LIMIT.getRequired(session); rekeyAfterBlocks = rekeyAfterBytes / 16; // Initial setting, will be updated once we know the cipher @@ -1204,7 +1208,9 @@ public class KexFilter extends IoFilter { } if (kex.next(cmd, message)) { // We're done - hostKeyChecker.check(); + if (hostKeyChecker != null) { + hostKeyChecker.check(); + } prepareNewSettings(); lastKexEnd.set(Instant.now()); sendNewKeys(); diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java index b3f860a3b..2639812b0 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java @@ -73,13 +73,10 @@ import org.apache.sshd.common.random.Random; import org.apache.sshd.common.session.ReservedSessionMessagesHandler; import org.apache.sshd.common.session.Session; import org.apache.sshd.common.session.SessionListener; -import org.apache.sshd.common.session.filters.CompressionFilter; import org.apache.sshd.common.session.filters.CryptFilter; -import org.apache.sshd.common.session.filters.DelayKexInitFilter; -import org.apache.sshd.common.session.filters.IdentFilter; -import org.apache.sshd.common.session.filters.InjectIgnoreFilter; +import org.apache.sshd.common.session.filters.CryptFilter.EncryptionListener; import org.apache.sshd.common.session.filters.SshIdentHandler; -import org.apache.sshd.common.session.filters.kex.KexFilter; +import org.apache.sshd.common.session.filters.SshTransportFilter; import org.apache.sshd.common.session.filters.kex.KexListener; import org.apache.sshd.common.util.EventListenerUtils; import org.apache.sshd.common.util.ExceptionUtils; @@ -179,9 +176,7 @@ public abstract class AbstractSession extends SessionHelper { private final FilterChain filters = new DefaultFilterChain(); - private 
CryptFilter cryptFilter; - private CompressionFilter compressionFilter; - private KexFilter kexFilter; + private SshTransportFilter sshTransport; /** * Create a new session. @@ -267,9 +262,7 @@ public abstract class AbstractSession extends SessionHelper { } protected void setupFilterChain() { - IdentFilter ident = new IdentFilter(); - ident.setPropertyResolver(this); - ident.setIdentHandler(new SshIdentHandler() { + SshIdentHandler identities = new SshIdentHandler() { @Override public boolean isServer() { @@ -317,32 +310,8 @@ public abstract class AbstractSession extends SessionHelper { } return lines; } - }); - filters.addLast(ident); - - cryptFilter = new CryptFilter(); - cryptFilter.setSession(this); - cryptFilter.setRandom(random); - cryptFilter.addEncryptionListener((buffer, sequenceNumber) -> { - // SSHD-968 - remember global request outgoing sequence number - LongConsumer setter = globalSequenceNumbers.remove(buffer); - if (setter != null) { - setter.accept(sequenceNumber); - } - }); - filters.addLast(cryptFilter); - - compressionFilter = new CompressionFilter(); - compressionFilter.setSession(this); - filters.addLast(compressionFilter); - - DelayKexInitFilter delayKexFilter = new DelayKexInitFilter(); - delayKexFilter.setSession(this); - filters.addLast(delayKexFilter); - - filters.addLast(new InjectIgnoreFilter(this, random)); - - kexFilter = new KexFilter(this, random, cryptFilter, compressionFilter, new SessionListener() { + }; + SessionListener sessionEvents = new SessionListener() { @Override public void sessionNegotiationStart( @@ -369,16 +338,17 @@ public abstract class AbstractSession extends SessionHelper { throw new RuntimeSshException(e.getMessage(), e); } } - }, this::getKexProposal, this::checkKeys); - filters.addLast(kexFilter); - - ident.addIdentListener((peer, id) -> { - if (peer == isServerSession()) { - kexFilter.setClientIdent(id); - } else { - kexFilter.setServerIdent(id); + }; + EncryptionListener sequenceListener = (buffer, 
sequenceNumber) -> { + // SSHD-968 - remember global request outgoing sequence number + LongConsumer setter = globalSequenceNumbers.remove(buffer); + if (setter != null) { + setter.accept(sequenceNumber); } - }); + }; + sshTransport = new SshTransportFilter(this, random, identities, sessionEvents, sequenceListener, + this::getKexProposal, this::checkKeys); + filters.addLast(sshTransport); } @Override @@ -386,24 +356,24 @@ public abstract class AbstractSession extends SessionHelper { return filters; } - protected boolean isConnectionSecure() { - return cryptFilter.isSecure(); + protected SshTransportFilter getTransport() { + return sshTransport; } - protected CompressionFilter getCompressionFilter() { - return compressionFilter; + protected boolean isConnectionSecure() { + return sshTransport.isSecure(); } public void addKexListener(KexListener listener) { - kexFilter.addKexListener(listener); + sshTransport.addKexListener(listener); } public void removeKexListener(KexListener listener) { - kexFilter.addKexListener(listener); + sshTransport.addKexListener(listener); } protected void initializeKeyExchangePhase() throws Exception { - KeyExchangeFuture future = kexFilter.startKex(); + KeyExchangeFuture future = sshTransport.startKex(); Throwable t = future.getException(); if (t != null) { if (t instanceof Exception) { @@ -415,7 +385,7 @@ public abstract class AbstractSession extends SessionHelper { } protected boolean isStrictKex() { - return kexFilter.isStrictKex(); + return sshTransport.isStrictKex(); } /** @@ -438,7 +408,7 @@ public abstract class AbstractSession extends SessionHelper { @Override public Map<KexProposalOption, String> getServerKexProposals() { - return kexFilter.getServerProposal(); + return sshTransport.getServerProposal(); } @Override @@ -448,42 +418,42 @@ public abstract class AbstractSession extends SessionHelper { @Override public Map<KexProposalOption, String> getClientKexProposals() { - return kexFilter.getClientProposal(); + return 
sshTransport.getClientProposal(); } @Override public KexState getKexState() { - return kexFilter.getKexState().get(); + return sshTransport.getKexState().get(); } @Override public byte[] getSessionId() { - return kexFilter.getSessionId(); + return sshTransport.getSessionId(); } @Override public Map<KexProposalOption, String> getKexNegotiationResult() { - return kexFilter.getNegotiated(); + return sshTransport.getNegotiated(); } @Override public String getNegotiatedKexParameter(KexProposalOption paramType) { - return kexFilter.getNegotiated().get(paramType); + return sshTransport.getNegotiated().get(paramType); } @Override public CipherInformation getCipherInformation(boolean incoming) { - return incoming ? cryptFilter.getInputSettings().getCipher() : cryptFilter.getOutputSettings().getCipher(); + return sshTransport.getCipherInformation(incoming); } @Override public CompressionInformation getCompressionInformation(boolean incoming) { - return incoming ? compressionFilter.getInputCompression() : compressionFilter.getOutputCompression(); + return sshTransport.getCompressionInformation(incoming); } @Override public MacInformation getMacInformation(boolean incoming) { - return incoming ? 
cryptFilter.getInputSettings().getMac() : cryptFilter.getOutputSettings().getMac(); + return sshTransport.getMacInformation(incoming); } /** @@ -498,7 +468,7 @@ public abstract class AbstractSession extends SessionHelper { protected void handleMessage(Buffer buffer) throws Exception { int cmd = buffer.getUByte(); if (log.isDebugEnabled()) { - log.debug("doHandleMessage({}) process #{} {}", this, cryptFilter.getInputSequenceNumber() - 1, + log.debug("doHandleMessage({}) process #{} {}", this, sshTransport.getInputSequenceNumber() - 1, SshConstants.getCommandMessageName(cmd)); } @@ -622,8 +592,8 @@ public abstract class AbstractSession extends SessionHelper { @Override protected void preClose() { - if (kexFilter != null) { - kexFilter.shutdown(); + if (sshTransport != null) { + sshTransport.shutdown(); } // if anyone waiting for global response notify them about the closing session @@ -913,13 +883,7 @@ public abstract class AbstractSession extends SessionHelper { // Since the caller claims to know how many bytes they will need // increase their request to account for our headers/footers if // they actually send exactly this amount. 
- int finalLength; - if (cryptFilter != null) { - finalLength = cryptFilter.precomputeBufferLength(len); - } else { - // Can occur in some tests - finalLength = len + SshConstants.SSH_PACKET_HEADER_LEN + 255 + 32; - } + int finalLength = len + SshConstants.SSH_PACKET_HEADER_LEN + CryptFilter.MAX_PADDING + CryptFilter.MAX_TAG_LENGTH; return prepareBuffer(cmd, new PacketBuffer(new byte[finalLength], false)); } @@ -974,7 +938,7 @@ public abstract class AbstractSession extends SessionHelper { return null; } - int seq = cryptFilter.getInputSequenceNumber() - 1; + int seq = sshTransport.getInputSequenceNumber() - 1; return sendNotImplemented(seq & 0xFFFF_FFFFL); } @@ -1153,7 +1117,7 @@ public abstract class AbstractSession extends SessionHelper { @Override public KeyExchangeFuture reExchangeKeys() throws IOException { try { - return kexFilter.startKex(); + return sshTransport.startKex(); } catch (GeneralSecurityException e) { debug("reExchangeKeys({}) failed ({}) to request new keys: {}", this, e.getClass().getSimpleName(), e.getMessage(), e); diff --git a/sshd-core/src/main/java/org/apache/sshd/server/session/AbstractServerSession.java b/sshd-core/src/main/java/org/apache/sshd/server/session/AbstractServerSession.java index 0a50314ac..5500df52b 100644 --- a/sshd-core/src/main/java/org/apache/sshd/server/session/AbstractServerSession.java +++ b/sshd-core/src/main/java/org/apache/sshd/server/session/AbstractServerSession.java @@ -301,9 +301,9 @@ public abstract class AbstractServerSession extends AbstractSession implements S IoSession networkSession = getIoSession(); setUsername(username); setAuthenticated(); - getCompressionFilter().enableInput(); + getTransport().enableInputCompression(); startService(authService, buffer); - IoWriteFuture future = writePacket(response).addListener(f -> getCompressionFilter().enableOutput()); + IoWriteFuture future = writePacket(response).addListener(f -> getTransport().enableOutputCompression()); resetIdleTimeout(); log.info("Session 
{}@{} authenticated", username, networkSession.getRemoteAddress());