This is an automated email from the ASF dual-hosted git repository. twolf pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/mina-sshd.git
commit e2e4e11ec667f9d7536b1d51f401ff47f44eed51 Author: Thomas Wolf <tw...@apache.org> AuthorDate: Thu Feb 20 23:42:57 2025 +0100 [SSHD-1161] Pubkey auth: source-address restriction in certificates OpenSSH certificates may contain a critical option "source-address". If present, it contains a comma-separated list of CIDR ranges from where the certificate is accepted for authentication.[1] Extract and parse this option, and do not accept the certificate if the client's IP address is not contained in at least one CIDR range given. [1] https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.certkeys --- CHANGES.md | 3 + .../apache/sshd/common/net/InetAddressRange.java | 441 +++++++++++++++++++++ .../sshd/common/net/InetAddressRangeTest.java | 148 +++++++ .../sshd/server/auth/pubkey/UserAuthPublicKey.java | 32 ++ .../common/auth/PublicKeyAuthenticationTest.java | 70 +++- 5 files changed, 685 insertions(+), 9 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 9ec1e4455..f163fdeaf 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -32,6 +32,9 @@ ## New Features +* [SSHD-1161](https://issues.apache.org/jira/projects/SSHD/issues/SSHD-1161) Support pubkey auth with user certificates (server-side) + * Client-side support was introduced in version 2.8.0 already + ## Potential Compatibility Issues ## Major Code Re-factoring diff --git a/sshd-common/src/main/java/org/apache/sshd/common/net/InetAddressRange.java b/sshd-common/src/main/java/org/apache/sshd/common/net/InetAddressRange.java new file mode 100644 index 000000000..ca5f46265 --- /dev/null +++ b/sshd-common/src/main/java/org/apache/sshd/common/net/InetAddressRange.java @@ -0,0 +1,441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sshd.common.net;
+
+import java.net.InetAddress;
+import java.util.Arrays;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * Describes a range of IP addresses specified in CIDR notation.
+ * <p>
+ * IPv4 ranges may be written with fewer than four bytes; missing trailing bytes are zero, so {@code 192.168/16} is the
+ * same range as {@code 192.168.0.0/16}. The special form {@code /0} denotes the whole IPv4 address space. IPv6 ranges
+ * use hexadecimal groups with an optional {@code ::}; without {@code ::}, missing trailing groups are zero. Embedded
+ * IPv4 notation such as {@code ::ffff:1.2.3.4} is not supported. Host bits set in the given address are ignored.
+ * Instances are immutable.
+ * </p>
+ */
+public final class InetAddressRange {
+
+    // IPv4: one to four dot-separated decimal bytes (range-checked when parsed), followed by "/bits".
+    private static final String IP4_BYTE = "(\\d{1,3})";
+
+    private static final String IP4_DOT_BYTE = "(?:\\." + IP4_BYTE + ')';
+
+    private static final String IP4_PREFIX = IP4_BYTE + "(?:" + IP4_DOT_BYTE + "(?:" + IP4_DOT_BYTE + "(?:" + IP4_DOT_BYTE
+                                             + ")?)?)?";
+
+    private static final String BITS = "(\\d{1,3})";
+
+    // Groups 1-4: the bytes (null if absent); group 5: the prefix length.
+    private static final Pattern IP4_CIDR = Pattern.compile('^' + IP4_PREFIX + "/" + BITS + '$');
+
+    // IPv6: colon-separated groups of 1 to 4 hex digits before and/or after an optional "::", followed by "/bits".
+    private static final String IP6_WORD = "(?:[0-9a-fA-F]{1,4})";
+
+    private static final String IP6_PART = "(" + IP6_WORD + "(?::" + IP6_WORD + ")*+)";
+
+    // Group 1: groups before "::" (all groups if there is no "::"); group 2: groups after "::"; group 3: prefix length.
+    private static final Pattern IP6_CIDR = Pattern.compile('^' + IP6_PART + "?(?:::" + IP6_PART + "?)?"
+                                                            + '/' + BITS + '$');
+
+    /** First address of the range (all host bits clear), MSB first. */
+    private final byte[] base;
+
+    /** Netmask: the first {@code networkZoneBits} bits set, all others clear. */
+    private final byte[] mask;
+
+    /** Last address of the range (all host bits set), MSB first. */
+    private final byte[] broadcast;
+
+    /** The prefix length. */
+    private final int networkZoneBits;
+
+    private InetAddressRange(byte[] base, int bits) {
+        byte[] netmask = new byte[base.length];
+        Builder.computeMask(netmask, bits);
+        // Drop host bits given in the textual form: 192.168.17.17/16 becomes 192.168.0.0/16
+        byte[] net = Builder.and(base, netmask);
+
+        this.broadcast = Builder.invertedOr(net, netmask);
+        this.base = net;
+        this.mask = netmask;
+        this.networkZoneBits = bits;
+    }
+
+    /**
+     * Creates an {@link InetAddressRange} for a CIDR.
+     *
+     * @param cidr the CIDR, for instance {@code 192.168.0.0/16} or {@code 2001:db8::/32}
+     * @return an {@link InetAddressRange}
+     * @throws IllegalArgumentException if the {@code cidr} is {@code null} or cannot be parsed as a CIDR.
+     */
+    public static InetAddressRange fromCIDR(String cidr) {
+        return Builder.build(cidr);
+    }
+
+    /**
+     * Tests whether a given string is a valid CIDR.
+     *
+     * @param cidr the string to test; may be {@code null}
+     * @return {@code true} if the string can be parsed as a CIDR; {@code false} otherwise
+     */
+    public static boolean isCIDR(String cidr) {
+        try {
+            return fromCIDR(cidr) != null;
+        } catch (IllegalArgumentException e) {
+            return false;
+        }
+    }
+
+    /**
+     * Tells whether this is an IPv4 address range.
+     *
+     * @return {@code true} if this is an IPv4 address range, {@code false} otherwise
+     */
+    public boolean isIpV4() {
+        return base.length == 4;
+    }
+
+    /**
+     * Tells whether this is an IPv6 address range.
+     *
+     * @return {@code true} if this is an IPv6 address range, {@code false} otherwise
+     */
+    public boolean isIpV6() {
+        return base.length == 16;
+    }
+
+    /**
+     * Retrieves the first address of this range as a MSB-first byte array.
+     *
+     * <p>
+     * If {@code subnetBits() <= 1}, the address returned is always the zeroth address.
+     * </p>
+     *
+     * @param inclusive whether to consider the zeroth address the first.
+     * @return the first address of the range (a fresh copy)
+     */
+    public byte[] first(boolean inclusive) {
+        if (inclusive || networkZoneBits + 1 >= base.length * 8) {
+            return base.clone();
+        }
+        byte[] result = base.clone();
+        result[result.length - 1] |= 1;
+        return result;
+    }
+
+    /**
+     * Retrieves the last address of this range as a MSB-first byte array.
+     *
+     * <p>
+     * If {@code subnetBits() <= 1}, the address returned is always the {@link #broadcastAddress()}.
+     * </p>
+     *
+     * @param inclusive whether to consider the direct broadcast address the last.
+     * @return the last address of the range (a fresh copy)
+     */
+    public byte[] last(boolean inclusive) {
+        if (inclusive || networkZoneBits + 1 >= base.length * 8) {
+            return broadcast.clone();
+        }
+        byte[] result = broadcast.clone();
+        result[result.length - 1] &= ~1;
+        return result;
+    }
+
+    /**
+     * Retrieves the broadcast address of this range as a MSB-first byte array.
+     *
+     * @return the broadcast address of the range (a fresh copy)
+     */
+    public byte[] broadcastAddress() {
+        return broadcast.clone();
+    }
+
+    /**
+     * Tests whether this range contains the given {@link InetAddress}.
+     *
+     * @param address {@link InetAddress} to test; must not be {@code null}
+     * @return {@code true} if the address is in the range; {@code false} otherwise
+     */
+    public boolean contains(InetAddress address) {
+        return contains(address.getAddress());
+    }
+
+    /**
+     * Tests whether this range contains the given IP address. An address of the other IP version (i.e., of a
+     * different length) is never contained.
+     *
+     * @param address the IP address to test, as an MSB-first byte array
+     * @return {@code true} if the address is in the range; {@code false} otherwise
+     */
+    public boolean contains(byte[] address) {
+        if (address.length != mask.length) {
+            return false;
+        }
+        return Arrays.equals(base, Builder.and(address, mask));
+    }
+
+    /**
+     * Tests whether this range completely contains a given other range.
+     *
+     * @param other {@link InetAddressRange} to test
+     * @return {@code true} if the other range is completely contained in this range; {@code false} otherwise
+     */
+    public boolean contains(InetAddressRange other) {
+        return contains(other.first(true)) && contains(other.last(true));
+    }
+
+    /**
+     * Tests whether this range overlaps a given other range.
+     *
+     * @param other {@link InetAddressRange} to test
+     * @return {@code true} if this range overlaps with the other range; {@code false} otherwise
+     */
+    public boolean overlaps(InetAddressRange other) {
+        // Two CIDR ranges are either nested or disjoint, so they overlap exactly if one of them contains the first
+        // address of the other. Testing only this.contains(...) would miss a range nested inside the other one.
+        return contains(other.base) || other.contains(base);
+    }
+
+    /**
+     * Retrieves the number of bits for the network zone.
+     *
+     * @return the number of bits for the network zone
+     */
+    public int networkZoneBits() {
+        return networkZoneBits;
+    }
+
+    /**
+     * Retrieves the number of bits for the subnet.
+     *
+     * @return the number of bits for the subnet
+     */
+    public int subnetBits() {
+        return base.length * 8 - networkZoneBits;
+    }
+
+    /**
+     * Determines the number of IP addresses in the range.
+     *
+     * <p>
+     * If {@code subnetBits() <= 1}, the count always includes the first and last address. Counts that do not fit into
+     * a {@code long} (IPv6 ranges with 63 or more subnet bits) are reported as {@link Long#MAX_VALUE}.
+     * </p>
+     *
+     * @param inclusive whether to include the first and last (broadcast) addresses in the count
+     * @return the number of addresses in the subnet
+     */
+    public long numberOfAddresses(boolean inclusive) {
+        int bits = subnetBits();
+        if (bits < Long.SIZE - 1) {
+            long n = 1L << bits;
+            return (inclusive || n <= 2) ? n : n - 2;
+        }
+        // 2^63 or more addresses. A plain shift would be wrong here: 1L << 63 is negative, and Java uses only the
+        // low 6 bits of the shift distance, so 1L << 96 would be 1L << 32.
+        if (bits == Long.SIZE - 1 && !inclusive) {
+            return Long.MAX_VALUE - 1; // exactly 2^63 - 2
+        }
+        return Long.MAX_VALUE;
+    }
+
+    @Override
+    public int hashCode() {
+        final int prime = 31;
+        int result = 1;
+        result = prime * result + Arrays.hashCode(base);
+        result = prime * result + Integer.hashCode(networkZoneBits);
+        return result;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+        InetAddressRange other = (InetAddressRange) obj;
+        return Arrays.equals(base, other.base) && networkZoneBits == other.networkZoneBits;
+    }
+
+    /**
+     * Returns the range in CIDR notation; the result can be parsed again by {@link #fromCIDR(String)}.
+     */
+    @Override
+    public String toString() {
+        if (base.length == 4) {
+            return toStringIp4();
+        }
+        return toStringIp6();
+    }
+
+    private String toStringIp4() {
+        StringBuilder b = new StringBuilder();
+        // Omit trailing zero bytes ("10.0.0.0/8" -> "10/8"), but keep at least one byte unless the prefix length is
+        // zero: the parser accepts a bare "/bits" only as "/0", so "0.0.0.0/8" must become "0/8", not "/8".
+        int minBytes = networkZoneBits == 0 ? 0 : 1;
+        int j = base.length;
+        while (j > minBytes && base[j - 1] == 0) {
+            j--;
+        }
+        for (int i = 0; i < j; i++) {
+            if (i > 0) {
+                b.append('.');
+            }
+            b.append(base[i] & 0xFF);
+        }
+        return b.append('/').append(networkZoneBits).toString();
+    }
+
+    private String toStringIp6() {
+        StringBuilder b = new StringBuilder();
+        int[] w = new int[base.length / 2];
+        for (int i = 0; i < w.length; i++) {
+            w[i] = ((base[2 * i] & 0xFF) << 8) + (base[2 * i + 1] & 0xFF);
+        }
+        // Find the longest run of zero words; it is written as "::" (even if it is a single word). On a tie the
+        // earlier run is used, except that a run reaching the end of the address wins a tie.
+        int longest = -1;
+        int longestStart = -1;
+        int length = -1;
+        int start = -1;
+        for (int i = 0; i < w.length; i++) {
+            if (w[i] != 0) {
+                if (length > longest) {
+                    longest = length;
+                    longestStart = start;
+                }
+                length = -1;
+                start = -1;
+            } else if (start < 0) {
+                start = i;
+                length = 1;
+            } else {
+                length++;
+            }
+        }
+        if (length >= longest) {
+            longest = length;
+            longestStart = start;
+        }
+        if (longestStart < 0) {
+            // No zero word at all: print all words
+            longestStart = w.length;
+            longest = 1;
+        }
+        for (int i = 0; i < longestStart; i++) {
+            if (i > 0) {
+                b.append(':');
+            }
+            b.append(Integer.toHexString(w[i]));
+        }
+        if (longestStart < w.length) {
+            b.append(':');
+            if (longestStart + longest >= w.length) {
+                b.append(':');
+            } else {
+                for (int i = longestStart + longest; i < w.length; i++) {
+                    b.append(':').append(Integer.toHexString(w[i]));
+                }
+            }
+        }
+        return b.append('/').append(networkZoneBits).toString();
+    }
+
+    /**
+     * Parsing and bit manipulation helpers.
+     */
+    private static final class Builder {
+
+        private Builder() {
+            throw new IllegalStateException();
+        }
+
+        static InetAddressRange build(String cidr) {
+            IllegalArgumentException ex = null;
+            if (cidr != null && !cidr.isEmpty()) {
+                try {
+                    if ("/0".equals(cidr)) {
+                        return new InetAddressRange(new byte[4], 0);
+                    }
+                    Matcher m = IP4_CIDR.matcher(cidr);
+                    if (m.matches()) {
+                        return fromIp4(m);
+                    }
+                    m = IP6_CIDR.matcher(cidr);
+                    // The IPv6 pattern also matches a bare "/bits"; that form is only valid as "/0", handled above.
+                    if (m.matches() && cidr.charAt(0) != '/') {
+                        return fromIp6(m);
+                    }
+                } catch (IllegalArgumentException e) {
+                    // A byte, word, or prefix length out of range, or too many IPv6 groups
+                    ex = e;
+                }
+            }
+            throw new IllegalArgumentException(cidr + " is not a CIDR", ex);
+        }
+
+        private static InetAddressRange fromIp4(Matcher m) {
+            byte[] base = new byte[4];
+            for (int i = 0; i < 4; i++) {
+                String s = m.group(i + 1);
+                if (s != null && !s.isEmpty()) {
+                    base[i] = byteRange(Integer.parseInt(s));
+                }
+            }
+            int bits = Integer.parseInt(m.group(5));
+            return new InetAddressRange(base, bits);
+        }
+
+        private static InetAddressRange fromIp6(Matcher m) {
+            String prefix = m.group(1);
+            String suffix = m.group(2);
+            String[] pre = prefix == null ? new String[0] : prefix.split(":");
+            String[] post = suffix == null ? new String[0] : suffix.split(":");
+            if (pre.length + post.length > 8) {
+                throw new IllegalArgumentException("Too many components");
+            }
+            byte[] base = new byte[16];
+            // Groups before "::" fill the address from the start ...
+            for (int i = 0; i < pre.length; i++) {
+                int w = wordRange(Integer.parseInt(pre[i], 16));
+                base[2 * i] = (byte) (w >>> 8);
+                base[2 * i + 1] = (byte) w;
+            }
+            // ... and groups after "::" fill it from the end; everything in between stays zero.
+            for (int i = post.length - 1, j = base.length / 2 - 1; i >= 0; i--, j--) {
+                int w = wordRange(Integer.parseInt(post[i], 16));
+                base[2 * j] = (byte) (w >>> 8);
+                base[2 * j + 1] = (byte) w;
+            }
+            int bits = Integer.parseInt(m.group(3));
+            return new InetAddressRange(base, bits);
+        }
+
+        private static byte byteRange(int x) {
+            rangeCheck(x, 0, 255);
+            return (byte) x;
+        }
+
+        private static int wordRange(int x) {
+            // 16-bit maximum. Note that "1 << 16 - 1" would be 1 << 15, since '-' binds tighter than '<<'.
+            rangeCheck(x, 0, 0xFFFF);
+            return x;
+        }
+
+        private static void rangeCheck(int x, int min, int max) {
+            if (x < min || x > max) {
+                throw new IllegalArgumentException(x + " not in range [" + min + ',' + max + ']');
+            }
+        }
+
+        /** Sets the first {@code bits} bits (MSB first) of the zero-initialized {@code mask}. */
+        static void computeMask(byte[] mask, int bits) {
+            rangeCheck(bits, 0, mask.length * 8);
+            for (int i = 0; i < bits; i++) {
+                int j = i / 8;
+                int b = 1 << 7 - (i % 8);
+                mask[j] |= b;
+            }
+        }
+
+        /** Returns a new array holding the bitwise AND of two arrays of equal length. */
+        static byte[] and(byte[] a, byte[] b) {
+            if (a.length != b.length) {
+                throw new IllegalArgumentException();
+            }
+            byte[] r = a.clone();
+            for (int i = 0; i < a.length; i++) {
+                r[i] &= b[i];
+            }
+            return r;
+        }
+
+        /** Returns a new array holding {@code a | ~b} of two arrays of equal length. */
+        static byte[] invertedOr(byte[] a, byte[] b) {
+            if (a.length != b.length) {
+                throw new IllegalArgumentException();
+            }
+            byte[] r = a.clone();
+            for (int i = 0; i < a.length; i++) {
+                r[i] |= ~b[i];
+            }
+            return r;
+        }
+    }
+}
diff --git a/sshd-common/src/test/java/org/apache/sshd/common/net/InetAddressRangeTest.java b/sshd-common/src/test/java/org/apache/sshd/common/net/InetAddressRangeTest.java
new file mode 100644
index 000000000..fe8b771de
--- /dev/null
+++ b/sshd-common/src/test/java/org/apache/sshd/common/net/InetAddressRangeTest.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache
Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sshd.common.net;
+
+import java.net.InetAddress;
+import java.util.Arrays;
+import java.util.stream.Stream;
+
+import org.apache.sshd.util.test.JUnitTestSupport;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+/**
+ * Tests for {@link InetAddressRange}: parsing, string form, and address containment.
+ */
+@Tag("NoIoTestCase")
+class InetAddressRangeTest extends JUnitTestSupport {
+
+    static Stream<Arguments> ip4Cidrs() {
+        // input CIDR, expected toString()
+        return Stream.of(
+                Arguments.of("192.168.17.17/16", "192.168/16"),
+                Arguments.of("10.0.2.5/24", "10.0.2/24"),
+                Arguments.of("192.168.17.43/32", "192.168.17.43/32"),
+                Arguments.of("192.168.17.43/0", "/0"));
+    }
+
+    @ParameterizedTest(name = "{0} - {1}")
+    @MethodSource("ip4Cidrs")
+    void toStringIp4(String cidr, String expected) {
+        InetAddressRange range = InetAddressRange.fromCIDR(cidr);
+        assertTrue(range.isIpV4());
+        assertFalse(range.isIpV6());
+        assertRoundTrip(range, expected);
+    }
+
+    static Stream<Arguments> ip6Cidrs() {
+        // input CIDR, expected toString()
+        return Stream.of(
+                Arguments.of("2001:0df8:23f2:0000:0000:66ee:1336:1774/96", "2001:df8:23f2:0:0:66ee::/96"),
+                Arguments.of("2001:0df8:0000:0000:0000:66ee:1336:1774/96", "2001:df8::66ee:0:0/96"),
+                Arguments.of("2001:df8::/32", "2001:df8::/32"),
+                Arguments.of("0:0::/66", "::/66"));
+    }
+
+    @ParameterizedTest(name = "{0} - {1}")
+    @MethodSource("ip6Cidrs")
+    void toStringIp6(String cidr, String expected) {
+        InetAddressRange range = InetAddressRange.fromCIDR(cidr);
+        assertTrue(range.isIpV6());
+        assertFalse(range.isIpV4());
+        assertRoundTrip(range, expected);
+    }
+
+    static Stream<Arguments> ip4Contains() {
+        // input CIDR, expected toString(), number of addresses, first address, last address
+        return Stream.of(
+                Arguments.of("192.168.17.17/24", "192.168.17/24", 256, "192.168.17.0", "192.168.17.255"),
+                Arguments.of("10.0.5.5/22", "10.0.4/22", 1024, "10.0.4.0", "10.0.7.255"));
+    }
+
+    @ParameterizedTest(name = "{0} - {1}")
+    @MethodSource("ip4Contains")
+    void containsIp4(String cidr, String expected, int size, String first, String last) throws Exception {
+        InetAddressRange range = InetAddressRange.fromCIDR(cidr);
+        assertTrue(range.isIpV4());
+        assertEquals(size, range.numberOfAddresses(true));
+        assertEquals(size - 2, range.numberOfAddresses(false));
+        byte[] current = range.first(true);
+        byte[] end = range.last(true);
+        assertArrayEquals(InetAddress.getByName(first).getAddress(), current);
+        assertArrayEquals(InetAddress.getByName(last).getAddress(), end);
+        // Walk the whole range; every address up to and including the last one must be contained
+        int visited = 1; // accounts for the last address, which is checked after the loop
+        for (; !Arrays.equals(current, end); increment(current)) {
+            assertTrue(range.contains(current));
+            visited++;
+        }
+        assertTrue(range.contains(current));
+        assertEquals(size, visited);
+        // The address right after the range must not be contained
+        increment(current);
+        assertFalse(range.contains(current));
+        assertRoundTrip(range, expected);
+    }
+
+    /**
+     * Checks that the string form of {@code range} parses back to an equal range and equals {@code expected}.
+     */
+    private void assertRoundTrip(InetAddressRange range, String expected) {
+        String text = range.toString();
+        assertEquals(range, InetAddressRange.fromCIDR(text));
+        assertEquals(expected, text);
+    }
+
+    /**
+     * Adds one to an MSB-first address in place; all-ones wraps around to all-zeroes.
+     */
+    private static void increment(byte[] address) {
+        int i = address.length - 1;
+        // ++ turns 0xFF into 0x00; only then does the carry move on to the next more significant byte
+        while (i >= 0 && ++address[i] == 0) {
+            i--;
+        }
+    }
+
+    static Stream<Arguments> ip4Bounds() {
+        // input CIDR, network bits, subnet bits, number of addresses, first address, last address
+        return Stream.of(
+                Arguments.of("192.168.17.17/30", 30, 2, 4, "192.168.17.16", "192.168.17.19"),
+                Arguments.of("192.168.17.17/31", 31, 1, 2, "192.168.17.16", "192.168.17.17"),
+                Arguments.of("192.168.17.17/32", 32, 0, 1, "192.168.17.17", "192.168.17.17"));
+    }
+
+    @ParameterizedTest(name = "{0} - {1}")
+    @MethodSource("ip4Bounds")
+    void boundsIp4(String cidr, int net, int subnet, int size, String first, String last) throws Exception {
+        InetAddressRange range = InetAddressRange.fromCIDR(cidr);
+        assertTrue(range.isIpV4());
+        assertEquals(net, range.networkZoneBits());
+        assertEquals(subnet, range.subnetBits());
+        assertEquals(size, range.numberOfAddresses(true));
+        // For ranges of at most two addresses the exclusive count equals the inclusive one
+        int exclusiveSize = (size <= 2) ? size : size - 2;
+        assertEquals(exclusiveSize, range.numberOfAddresses(false));
+        assertArrayEquals(InetAddress.getByName(first).getAddress(), range.first(true));
+        assertArrayEquals(InetAddress.getByName(last).getAddress(), range.last(true));
+    }
+
+    @Test
+    void leadingZeroesIp6() {
+        InetAddressRange range = InetAddressRange.fromCIDR("0000:0000:0000:1234:5678:1234:5678::/96");
+        String text = range.toString();
+        assertEquals("::1234:5678:1234:0:0/96", text);
+        assertEquals(range, InetAddressRange.fromCIDR(text));
+    }
+}
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java b/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
index 04c0e8d28..3280170de 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
@@ -18,6 +18,9 @@
  */
 package org.apache.sshd.server.auth.pubkey;
 
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
 import java.security.PublicKey;
 import java.security.SignatureException;
 import java.security.cert.
java.security.cert.CertificateException; @@ -30,6 +33,8 @@ import org.apache.sshd.common.RuntimeSshException; import org.apache.sshd.common.SshConstants; import org.apache.sshd.common.config.keys.KeyUtils; import org.apache.sshd.common.config.keys.OpenSshCertificate; +import org.apache.sshd.common.config.keys.OpenSshCertificate.CertificateOption; +import org.apache.sshd.common.net.InetAddressRange; import org.apache.sshd.common.signature.Signature; import org.apache.sshd.common.signature.SignatureFactoriesManager; import org.apache.sshd.common.util.GenericUtils; @@ -103,6 +108,7 @@ public class UserAuthPublicKey extends AbstractUserAuth implements SignatureFact throw new CertificateException("expired"); } verifyCertificateSignature(session, cert); + verifyCertificateSources(session, cert); } catch (Exception e) { warn("doAuth({}@{}): public key certificate (id={}) is not valid: {}", username, session, cert.getId(), e.getMessage(), e); @@ -197,6 +203,32 @@ public class UserAuthPublicKey extends AbstractUserAuth implements SignatureFact } } + protected void verifyCertificateSources(ServerSession session, OpenSshCertificate cert) throws CertificateException { + String allowedSources = cert.getCriticalOptions().stream().filter(c -> "source-address".equals(c.getName())) + .map(CertificateOption::getData).findAny().orElse(null); + if (allowedSources == null) { + return; + } + SocketAddress remote = session.getRemoteAddress(); + if (remote instanceof InetSocketAddress) { + InetAddress remoteAddress = ((InetSocketAddress) remote).getAddress(); + for (String allowed : allowedSources.split(",")) { + String cidr = allowed.trim(); + if (GenericUtils.isEmpty(cidr)) { + continue; + } + try { + if (InetAddressRange.fromCIDR(cidr).contains(remoteAddress)) { + return; + } + } catch (IllegalArgumentException e) { + throw new CertificateException("Invalid CIDR range '" + cidr + "' in source-address critical option"); + } + } + throw new CertificateException("Rejected by 
source-address critical option; not allowed from " + remoteAddress); + } + } + protected boolean verifySignature( ServerSession session, String username, String alg, PublicKey key, Buffer buffer, Signature verifier, byte[] sig) throws Exception { diff --git a/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java b/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java index be0359fe0..67a7d5949 100644 --- a/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java +++ b/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java @@ -49,6 +49,7 @@ import org.apache.sshd.common.SshException; import org.apache.sshd.common.config.keys.FilePasswordProvider; import org.apache.sshd.common.config.keys.KeyUtils; import org.apache.sshd.common.config.keys.OpenSshCertificate; +import org.apache.sshd.common.config.keys.OpenSshCertificate.CertificateOption; import org.apache.sshd.common.config.keys.OpenSshCertificateImpl; import org.apache.sshd.common.keyprovider.KeyIdentityProvider; import org.apache.sshd.common.keyprovider.KeyPairProvider; @@ -557,15 +558,7 @@ public class PublicKeyAuthenticationTest extends AuthenticationTestSupport { sshd.setUserAuthFactories(Collections.singletonList(new org.apache.sshd.server.auth.pubkey.UserAuthPublicKeyFactory())); - AtomicInteger authAttempts = new AtomicInteger(0); - sshd.setPublickeyAuthenticator((username, key, session) -> { - authAttempts.incrementAndGet(); - if (key instanceof OpenSshCertificate) { - OpenSshCertificate cert = (OpenSshCertificate) key; - return KeyUtils.compareKeys(cert.getCaPubKey(), caKeypair.getPublic()); - } - return false; - }); + sshd.setPublickeyAuthenticator((username, key, session) -> true); // Client authentication should fail try (SshClient client = setupTestClient()) { @@ -588,4 +581,63 @@ public class PublicKeyAuthenticationTest extends AuthenticationTestSupport { } } + 
@ParameterizedTest(name = "''{0}''")
+    @MethodSource("certificateSources")
+    void certificateSources(String sources, boolean expectSuccess) throws Exception {
+        KeyPair userKeyPair = CommonTestSupportUtils.generateKeyPair(KeyUtils.EC_ALGORITHM, 256);
+        KeyPair caKeyPair = CommonTestSupportUtils.generateKeyPair(KeyUtils.EC_ALGORITHM, 256);
+        CertificateOption sourceRestriction = new CertificateOption("source-address", sources);
+        // User certificate for "user01", signed by the CA and restricted to the given source addresses
+        OpenSshCertificate userCert = OpenSshCertificateBuilder.userCertificate() //
+                .serial(System.currentTimeMillis()) //
+                .publicKey(userKeyPair.getPublic()) //
+                .id("test-cert-ecdsa-sha2-nistp256") //
+                .validBefore(System.currentTimeMillis() + TimeUnit.HOURS.toMillis(1)) //
+                .principals(Collections.singletonList("user01")) //
+                .criticalOptions(Collections.singletonList(sourceRestriction)) //
+                .sign(caKeyPair, "ecdsa-sha2-nistp256");
+
+        // Server: public key authentication only, accepting any key, so only the certificate checks can reject
+        sshd.setPasswordAuthenticator(RejectAllPasswordAuthenticator.INSTANCE);
+        sshd.setKeyboardInteractiveAuthenticator(KeyboardInteractiveAuthenticator.NONE);
+        CoreTestSupportUtils.setupFullSignaturesSupport(sshd);
+        sshd.setUserAuthFactories(Collections.singletonList(new org.apache.sshd.server.auth.pubkey.UserAuthPublicKeyFactory()));
+        sshd.setPublickeyAuthenticator((username, key, session) -> true);
+
+        try (SshClient client = setupTestClient()) {
+            CoreTestSupportUtils.setupFullSignaturesSupport(client);
+            client.setUserAuthFactories(
+                    Collections.singletonList(new org.apache.sshd.client.auth.pubkey.UserAuthPublicKeyFactory()));
+            client.start();
+
+            try (ClientSession session = client.connect("user01", TEST_LOCALHOST, port) //
+                    .verify(CONNECT_TIMEOUT) //
+                    .getSession()) {
+
+                session.addPublicKeyIdentity(new KeyPair(userCert, userKeyPair.getPrivate()));
+                AuthFuture auth = session.auth();
+                if (!expectSuccess) {
+                    // The certificate must be refused because of its source-address option
+                    assertThrows(SshException.class, () -> auth.verify(AUTH_TIMEOUT));
+                } else {
+                    auth.verify(AUTH_TIMEOUT);
+                    assertTrue(session.isAuthenticated());
+                }
+            } finally {
+                client.stop();
+            }
+        }
+    }
+
+    private static Stream<Arguments> certificateSources() {
+        // source-address option value, and whether authentication from the local host must succeed
+        return Stream.of( //
+                Arguments.of("127.0/24", true), //
+                Arguments.of("8.8.8.8/24", false), //
+                Arguments.of("127.0.0.1/32", true), //
+                Arguments.of("8.8.8.8/8,127.0.0.1/32", true), //
+                Arguments.of("bogus", false), //
+                Arguments.of("", false));
+    }
+
 }