This is an automated email from the ASF dual-hosted git repository.

twolf pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit e2e4e11ec667f9d7536b1d51f401ff47f44eed51
Author: Thomas Wolf <tw...@apache.org>
AuthorDate: Thu Feb 20 23:42:57 2025 +0100

    [SSHD-1161] Pubkey auth: source-address restriction in certificates
    
    OpenSSH certificates may contain a critical option "source-address". If
    present, it contains a comma-separated list of CIDR ranges from which
    the certificate is accepted for authentication [1].
    
    Extract and parse this option, and reject the certificate unless the
    client's IP address lies within at least one of the given CIDR ranges.
    
    [1] https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.certkeys
---
 CHANGES.md                                         |   3 +
 .../apache/sshd/common/net/InetAddressRange.java   | 441 +++++++++++++++++++++
 .../sshd/common/net/InetAddressRangeTest.java      | 148 +++++++
 .../sshd/server/auth/pubkey/UserAuthPublicKey.java |  32 ++
 .../common/auth/PublicKeyAuthenticationTest.java   |  70 +++-
 5 files changed, 685 insertions(+), 9 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 9ec1e4455..f163fdeaf 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -32,6 +32,9 @@
 
 ## New Features
 
+* [SSHD-1161](https://issues.apache.org/jira/projects/SSHD/issues/SSHD-1161) Support pubkey auth with user certificates (server-side)
+    * Client-side support was already introduced in version 2.8.0
+
 ## Potential Compatibility Issues
 
 ## Major Code Re-factoring
diff --git 
a/sshd-common/src/main/java/org/apache/sshd/common/net/InetAddressRange.java 
b/sshd-common/src/main/java/org/apache/sshd/common/net/InetAddressRange.java
new file mode 100644
index 000000000..ca5f46265
--- /dev/null
+++ b/sshd-common/src/main/java/org/apache/sshd/common/net/InetAddressRange.java
@@ -0,0 +1,441 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sshd.common.net;
+
+import java.net.InetAddress;
+import java.util.Arrays;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * Describes a range of IP addresses specified in CIDR notation.
+ */
+public final class InetAddressRange {
+
+    private static final String IP4_BYTE = "(\\d{1,3})";
+
+    private static final String IP4_DOT_BYTE = "(?:\\." + IP4_BYTE + ')';
+
+    private static final String IP4_PREFIX = IP4_BYTE + "(?:" + IP4_DOT_BYTE + 
"(?:" + IP4_DOT_BYTE + "(?:" + IP4_DOT_BYTE
+                                             + ")?)?)?";
+
+    private static final String BITS = "(\\d{1,3})";
+
+    private static final Pattern IP4_CIDR = Pattern.compile('^' + IP4_PREFIX + 
"/" + BITS + '$');
+
+    private static final String IP6_WORD = "(?:[0-9a-fA-F]{1,4})";
+
+    private static final String IP6_PART = "(" + IP6_WORD + "(?::" + IP6_WORD 
+ ")*+)";
+
+    private static final Pattern IP6_CIDR = Pattern.compile('^' + IP6_PART + 
"?(?:::" + IP6_PART + "?)?" + '/' + BITS + '$');
+
+    private final byte[] base;
+
+    private final byte[] mask;
+
+    private final byte[] broadcast;
+
+    private final int networkZoneBits;
+
+    private InetAddressRange(byte[] base, int bits) {
+        byte[] netmask = new byte[base.length];
+        Builder.computeMask(netmask, bits);
+        byte[] net = Builder.and(base, netmask);
+
+        this.broadcast = Builder.invertedOr(net, netmask);
+        this.base = net;
+        this.mask = netmask;
+        this.networkZoneBits = bits;
+    }
+
+    /**
+     * Creates an {@link InetAddressRange} for a CIDR.
+     *
+     * @param  cidr                     the CIDR
+     * @return                          an {@link InetAddressRange}
+     * @throws IllegalArgumentException if the {@code cidr} cannot be parsed 
as a CIDR.
+     */
+    public static InetAddressRange fromCIDR(String cidr) {
+        return Builder.build(cidr);
+    }
+
+    /**
+     * Tests whether a given string is a valid CIDR.
+     *
+     * @param  cidr the string to test
+     * @return      {@code true} if the string can be parsed as a CIDR; {@code 
false} otherwise
+     */
+    public static boolean isCIDR(String cidr) {
+        try {
+            return fromCIDR(cidr) != null;
+        } catch (IllegalArgumentException e) {
+            return false;
+        }
+    }
+
+    /**
+     * Tells whether this is an IPv4 address range.
+     *
+     * @return {@code true} if this is an IPv4 address range, {@code false} 
otherwise
+     */
+    public boolean isIpV4() {
+        return base.length == 4;
+    }
+
+    /**
+     * Tells whether this is an IPv6 address range.
+     *
+     * @return {@code true} if this is an IPv6 address range, {@code false} 
otherwise
+     */
+    public boolean isIpV6() {
+        return base.length == 16;
+    }
+
+    /**
+     * Retrieves the first address of this range as a MSB-first byte array.
+     *
+     * <p>
+     * If {@code subnetBits() <= 1}, the address returned is always the zeroth 
address.
+     * </p>
+     *
+     * @param  inclusive whether to consider the zeroth address the first.
+     * @return           the first address of the range
+     */
+    public byte[] first(boolean inclusive) {
+        if (inclusive || networkZoneBits + 1 >= base.length * 8) {
+            return base.clone();
+        }
+        byte[] result = base.clone();
+        result[result.length - 1] |= 1;
+        return result;
+    }
+
+    /**
+     * Retrieves the last address of this range as a MSB-first byte array.
+     *
+     * <p>
+     * If {@code subnetBits() <= 1}, the address returned is always the {@link 
#broadcastAddress()}.
+     * </p>
+     *
+     * @param  inclusive whether to consider the direct broadcast address the 
last.
+     * @return           the last address of the range
+     */
+    public byte[] last(boolean inclusive) {
+        if (inclusive || networkZoneBits + 1 >= base.length * 8) {
+            return broadcast.clone();
+        }
+        byte[] result = broadcast.clone();
+        result[result.length - 1] &= ~1;
+        return result;
+    }
+
+    /**
+     * Retrieves the broadcast address of this range as a MSB-first byte array.
+     *
+     * @return the broadcast address of the range
+     */
+    public byte[] broadcastAddress() {
+        return broadcast.clone();
+    }
+
+    /**
+     * Tests whether this range contains the given {@link InetAddress}.
+     *
+     * @param  address {@link InetAddress} to test
+     * @return         {@code true} if the address is in the range; {@code 
false} otherwise
+     */
+    public boolean contains(InetAddress address) {
+        return contains(address.getAddress());
+    }
+
+    /**
+     * Tests whether this range contains the given IP address.
+     *
+     * @param  address the IP address to test, as an MSB-first byte array
+     * @return         {@code true} if the address is in the range; {@code 
false} otherwise
+     */
+    public boolean contains(byte[] address) {
+        if (address.length != mask.length) {
+            return false;
+        }
+        return Arrays.equals(base, Builder.and(address, mask));
+    }
+
+    /**
+     * Tests whether this range completely contains a given other range.
+     *
+     * @param  other {@link InetAddressRange} to test
+     * @return       {@code true} if the other range is completely contained 
in this range; {@code false} otherwise
+     */
+    public boolean contains(InetAddressRange other) {
+        return contains(other.first(true)) && contains(other.last(true));
+    }
+
+    /**
+     * Tests whether this range overlaps a given other range.
+     *
+     * @param  other {@link InetAddressRange} to test
+     * @return       {@code true} if this range overlaps with the other range; 
{@code false} otherwise
+     */
+    public boolean overlaps(InetAddressRange other) {
+        return contains(other.first(true)) || contains(other.last(true));
+    }
+
+    /**
+     * Retrieves the number of bits for the network zone.
+     *
+     * @return the number of bits for the network zone
+     */
+    public int networkZoneBits() {
+        return networkZoneBits;
+    }
+
+    /**
+     * Retrieves the number of bits for the subnet.
+     *
+     * @return the number of bits for the subnet
+     */
+    public int subnetBits() {
+        return base.length * 8 - networkZoneBits;
+    }
+
+    /**
+     * Determines the number of IP addresses in the range.
+     *
+     * <p>
+     * If {@code subnetBits() <= 1}, the count always includes the first and 
last address.
+     * </p>
+     *
+     * @param  inclusive whether to include the first and last (broadcast) 
addresses in the count
+     * @return           the number of addresses in the subnet
+     */
+    public long numberOfAddresses(boolean inclusive) {
+        long n = 1L << subnetBits();
+        return (inclusive || n <= 2) ? n : n - 2;
+    }
+
+    @Override
+    public int hashCode() {
+        final int prime = 31;
+        int result = 1;
+        result = prime * result + Arrays.hashCode(base);
+        result = prime * result + Integer.hashCode(networkZoneBits);
+        return result;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+        InetAddressRange other = (InetAddressRange) obj;
+        return Arrays.equals(base, other.base) && networkZoneBits == 
other.networkZoneBits;
+    }
+
+    @Override
+    public String toString() {
+        if (base.length == 4) {
+            return toStringIp4();
+        }
+        return toStringIp6();
+    }
+
+    private String toStringIp4() {
+        StringBuilder b = new StringBuilder();
+        int j = base.length;
+        // Omit trailing zeroes
+        while (j > 0 && base[j - 1] == 0) {
+            j--;
+        }
+        for (int i = 0; i < j; i++) {
+            if (i > 0) {
+                b.append('.');
+            }
+            b.append(base[i] & 0xFF);
+        }
+        return b.append('/').append(networkZoneBits).toString();
+    }
+
+    private String toStringIp6() {
+        StringBuilder b = new StringBuilder();
+        int[] w = new int[base.length / 2];
+        for (int i = 0; i < w.length; i++) {
+            w[i] = ((base[2 * i] & 0xFF) << 8) + (base[2 * i + 1] & 0xFF);
+        }
+        // Find the longest sequence of zeroes; we'll collapse it to a single 
::
+        int longest = -1;
+        int longestStart = -1;
+        int length = -1;
+        int start = -1;
+        for (int i = 0; i < w.length; i++) {
+            if (w[i] != 0) {
+                if (length > longest) {
+                    longest = length;
+                    longestStart = start;
+                }
+                length = -1;
+                start = -1;
+            } else if (start < 0) {
+                start = i;
+                length = 1;
+            } else {
+                length++;
+            }
+        }
+        if (length >= longest) {
+            longest = length;
+            longestStart = start;
+        }
+        if (longestStart < 0) {
+            longestStart = w.length;
+            longest = 1;
+        }
+        for (int i = 0; i < longestStart; i++) {
+            if (i > 0) {
+                b.append(':');
+            }
+            b.append(Integer.toHexString(w[i]));
+        }
+        if (longestStart < w.length) {
+            b.append(':');
+            if (longestStart + longest >= w.length) {
+                b.append(':');
+            } else {
+                for (int i = longestStart + longest; i < w.length; i++) {
+                    b.append(':').append(Integer.toHexString(w[i]));
+                }
+            }
+        }
+        return b.append('/').append(networkZoneBits).toString();
+    }
+
+    private static final class Builder {
+
+        private Builder() {
+            throw new IllegalStateException();
+        }
+
+        static InetAddressRange build(String cidr) {
+            IllegalArgumentException ex = null;
+            if (!cidr.isEmpty()) {
+                try {
+                    if ("/0".equals(cidr)) {
+                        return new InetAddressRange(new byte[4], 0);
+                    }
+                    Matcher m = IP4_CIDR.matcher(cidr);
+                    if (m.matches()) {
+                        return fromIp4(m);
+                    }
+                    m = IP6_CIDR.matcher(cidr);
+                    if (m.matches() && cidr.charAt(0) != '/') {
+                        return fromIp6(m);
+                    }
+                } catch (IllegalArgumentException e) {
+                    ex = e;
+                }
+            }
+            throw new IllegalArgumentException(cidr + " is not a CIDR", ex);
+        }
+
+        private static InetAddressRange fromIp4(Matcher m) {
+            byte[] base = new byte[4];
+            for (int i = 0; i < 4; i++) {
+                String s = m.group(i + 1);
+                if (s != null && !s.isEmpty()) {
+                    base[i] = byteRange(Integer.parseInt(s));
+                }
+            }
+            int bits = Integer.parseInt(m.group(5));
+            return new InetAddressRange(base, bits);
+        }
+
+        private static InetAddressRange fromIp6(Matcher m) {
+            String prefix = m.group(1);
+            String suffix = m.group(2);
+            String[] pre = prefix == null ? new String[0] : prefix.split(":");
+            String[] post = suffix == null ? new String[0] : suffix.split(":");
+            if (pre.length + post.length > 8) {
+                throw new IllegalArgumentException("Too many components");
+            }
+            byte[] base = new byte[16];
+            for (int i = 0; i < pre.length; i++) {
+                int w = wordRange(Integer.parseInt(pre[i], 16));
+                base[2 * i] = (byte) (w >>> 8);
+                base[2 * i + 1] = (byte) w;
+            }
+            for (int i = post.length - 1, j = base.length / 2 - 1; i >= 0; 
i--, j--) {
+                int w = wordRange(Integer.parseInt(post[i], 16));
+                base[2 * j] = (byte) (w >>> 8);
+                base[2 * j + 1] = (byte) w;
+            }
+            int bits = Integer.parseInt(m.group(3));
+            return new InetAddressRange(base, bits);
+        }
+
+        private static byte byteRange(int x) {
+            rangeCheck(x, 0, 255);
+            return (byte) x;
+        }
+
+        private static int wordRange(int x) {
+            rangeCheck(x, 0, 1 << 16 - 1);
+            return x;
+        }
+
+        private static void rangeCheck(int x, int min, int max) {
+            if (x < min || x > max) {
+                throw new IllegalArgumentException(x + " not in range [" + min 
+ ',' + max + ']');
+            }
+        }
+
+        static void computeMask(byte[] mask, int bits) {
+            rangeCheck(bits, 0, mask.length * 8);
+            for (int i = 0; i < bits; i++) {
+                int j = i / 8;
+                int b = 1 << 7 - (i % 8);
+                mask[j] |= b;
+            }
+        }
+
+        static byte[] and(byte[] a, byte[] b) {
+            if (a.length != b.length) {
+                throw new IllegalArgumentException();
+            }
+            byte[] r = a.clone();
+            for (int i = 0; i < a.length; i++) {
+                r[i] &= b[i];
+            }
+            return r;
+        }
+
+        static byte[] invertedOr(byte[] a, byte[] b) {
+            if (a.length != b.length) {
+                throw new IllegalArgumentException();
+            }
+            byte[] r = a.clone();
+            for (int i = 0; i < a.length; i++) {
+                r[i] |= ~b[i];
+            }
+            return r;
+        }
+    }
+}
diff --git 
a/sshd-common/src/test/java/org/apache/sshd/common/net/InetAddressRangeTest.java
 
b/sshd-common/src/test/java/org/apache/sshd/common/net/InetAddressRangeTest.java
new file mode 100644
index 000000000..fe8b771de
--- /dev/null
+++ 
b/sshd-common/src/test/java/org/apache/sshd/common/net/InetAddressRangeTest.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sshd.common.net;
+
+import java.net.InetAddress;
+import java.util.Arrays;
+import java.util.stream.Stream;
+
+import org.apache.sshd.util.test.JUnitTestSupport;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+@Tag("NoIoTestCase")
+class InetAddressRangeTest extends JUnitTestSupport {
+
+    static Stream<Arguments> ip4Cidrs() {
+        return Stream.of(
+                Arguments.of("192.168.17.17/16", "192.168/16"),
+                Arguments.of("10.0.2.5/24", "10.0.2/24"),
+                Arguments.of("192.168.17.43/32", "192.168.17.43/32"),
+                Arguments.of("192.168.17.43/0", "/0"));
+    }
+
+    @ParameterizedTest(name = "{0} - {1}")
+    @MethodSource("ip4Cidrs")
+    void toStringIp4(String cidr, String expected) {
+        InetAddressRange range = InetAddressRange.fromCIDR(cidr);
+        assertTrue(range.isIpV4());
+        assertFalse(range.isIpV6());
+        String str = range.toString();
+        assertEquals(range, InetAddressRange.fromCIDR(str));
+        assertEquals(expected, str);
+    }
+
+    static Stream<Arguments> ip6Cidrs() {
+        return Stream.of(
+                Arguments.of("2001:0df8:23f2:0000:0000:66ee:1336:1774/96", 
"2001:df8:23f2:0:0:66ee::/96"),
+                Arguments.of("2001:0df8:0000:0000:0000:66ee:1336:1774/96", 
"2001:df8::66ee:0:0/96"),
+                Arguments.of("2001:df8::/32", "2001:df8::/32"),
+                Arguments.of("0:0::/66", "::/66"));
+    }
+
+    @ParameterizedTest(name = "{0} - {1}")
+    @MethodSource("ip6Cidrs")
+    void toStringIp6(String cidr, String expected) {
+        InetAddressRange range = InetAddressRange.fromCIDR(cidr);
+        assertTrue(range.isIpV6());
+        assertFalse(range.isIpV4());
+        String str = range.toString();
+        assertEquals(range, InetAddressRange.fromCIDR(str));
+        assertEquals(expected, str);
+    }
+
+    static Stream<Arguments> ip4Contains() {
+        return Stream.of(
+                Arguments.of("192.168.17.17/24", "192.168.17/24", 256, 
"192.168.17.0", "192.168.17.255"),
+                Arguments.of("10.0.5.5/22", "10.0.4/22", 1024, "10.0.4.0", 
"10.0.7.255"));
+    }
+
+    @ParameterizedTest(name = "{0} - {1}")
+    @MethodSource("ip4Contains")
+    void containsIp4(String cidr, String expected, int size, String first, 
String last) throws Exception {
+        InetAddressRange range = InetAddressRange.fromCIDR(cidr);
+        assertTrue(range.isIpV4());
+        assertEquals(size, range.numberOfAddresses(true));
+        assertEquals(size - 2, range.numberOfAddresses(false));
+        byte[] from = range.first(true);
+        byte[] to = range.last(true);
+        assertArrayEquals(InetAddress.getByName(first).getAddress(), from);
+        assertArrayEquals(InetAddress.getByName(last).getAddress(), to);
+        int n = 0;
+        while (!Arrays.equals(from, to)) {
+            assertTrue(range.contains(from));
+            n++;
+            inc(from);
+        }
+        assertTrue(range.contains(from));
+        n++;
+        assertEquals(size, n);
+        inc(from);
+        assertFalse(range.contains(from));
+        String str = range.toString();
+        assertEquals(range, InetAddressRange.fromCIDR(str));
+        assertEquals(expected, str);
+    }
+
+    private void inc(byte[] b) {
+        for (int i = b.length - 1; i >= 0; i--) {
+            int x = (b[i] & 0xFF) + 1;
+            b[i] = (byte) x;
+            if (x < 256) {
+                break;
+            }
+        }
+    }
+
+    static Stream<Arguments> ip4Bounds() {
+        return Stream.of(Arguments.of("192.168.17.17/30", 30, 2, 4, 
"192.168.17.16", "192.168.17.19"),
+                Arguments.of("192.168.17.17/31", 31, 1, 2, "192.168.17.16", 
"192.168.17.17"),
+                Arguments.of("192.168.17.17/32", 32, 0, 1, "192.168.17.17", 
"192.168.17.17"));
+    }
+
+    @ParameterizedTest(name = "{0} - {1}")
+    @MethodSource("ip4Bounds")
+    void boundsIp4(String cidr, int net, int subnet, int size, String first, 
String last) throws Exception {
+        InetAddressRange range = InetAddressRange.fromCIDR(cidr);
+        assertTrue(range.isIpV4());
+        assertEquals(net, range.networkZoneBits());
+        assertEquals(subnet, range.subnetBits());
+        assertEquals(size, range.numberOfAddresses(true));
+        if (size <= 2) {
+            assertEquals(size, range.numberOfAddresses(false));
+        } else {
+            assertEquals(size - 2, range.numberOfAddresses(false));
+        }
+        byte[] from = range.first(true);
+        byte[] to = range.last(true);
+        assertArrayEquals(InetAddress.getByName(first).getAddress(), from);
+        assertArrayEquals(InetAddress.getByName(last).getAddress(), to);
+    }
+
+    @Test
+    void leadingZeroesIp6() {
+        InetAddressRange range = 
InetAddressRange.fromCIDR("0000:0000:0000:1234:5678:1234:5678::/96");
+        String str = range.toString();
+        assertEquals("::1234:5678:1234:0:0/96", str);
+        assertEquals(range, InetAddressRange.fromCIDR(str));
+    }
+}
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
 
b/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
index 04c0e8d28..3280170de 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
@@ -18,6 +18,9 @@
  */
 package org.apache.sshd.server.auth.pubkey;
 
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
 import java.security.PublicKey;
 import java.security.SignatureException;
 import java.security.cert.CertificateException;
@@ -30,6 +33,8 @@ import org.apache.sshd.common.RuntimeSshException;
 import org.apache.sshd.common.SshConstants;
 import org.apache.sshd.common.config.keys.KeyUtils;
 import org.apache.sshd.common.config.keys.OpenSshCertificate;
+import org.apache.sshd.common.config.keys.OpenSshCertificate.CertificateOption;
+import org.apache.sshd.common.net.InetAddressRange;
 import org.apache.sshd.common.signature.Signature;
 import org.apache.sshd.common.signature.SignatureFactoriesManager;
 import org.apache.sshd.common.util.GenericUtils;
@@ -103,6 +108,7 @@ public class UserAuthPublicKey extends AbstractUserAuth 
implements SignatureFact
                     throw new CertificateException("expired");
                 }
                 verifyCertificateSignature(session, cert);
+                verifyCertificateSources(session, cert);
             } catch (Exception e) {
                 warn("doAuth({}@{}): public key certificate (id={}) is not 
valid: {}", username, session, cert.getId(),
                         e.getMessage(), e);
@@ -197,6 +203,32 @@ public class UserAuthPublicKey extends AbstractUserAuth 
implements SignatureFact
         }
     }
 
+    protected void verifyCertificateSources(ServerSession session, 
OpenSshCertificate cert) throws CertificateException {
+        String allowedSources = cert.getCriticalOptions().stream().filter(c -> 
"source-address".equals(c.getName()))
+                .map(CertificateOption::getData).findAny().orElse(null);
+        if (allowedSources == null) {
+            return;
+        }
+        SocketAddress remote = session.getRemoteAddress();
+        if (remote instanceof InetSocketAddress) {
+            InetAddress remoteAddress = ((InetSocketAddress) 
remote).getAddress();
+            for (String allowed : allowedSources.split(",")) {
+                String cidr = allowed.trim();
+                if (GenericUtils.isEmpty(cidr)) {
+                    continue;
+                }
+                try {
+                    if 
(InetAddressRange.fromCIDR(cidr).contains(remoteAddress)) {
+                        return;
+                    }
+                } catch (IllegalArgumentException e) {
+                    throw new CertificateException("Invalid CIDR range '" + 
cidr + "' in source-address critical option");
+                }
+            }
+            throw new CertificateException("Rejected by source-address 
critical option; not allowed from " + remoteAddress);
+        }
+    }
+
     protected boolean verifySignature(
             ServerSession session, String username, String alg, PublicKey key, 
Buffer buffer, Signature verifier, byte[] sig)
             throws Exception {
diff --git 
a/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java
 
b/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java
index be0359fe0..67a7d5949 100644
--- 
a/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java
+++ 
b/sshd-core/src/test/java/org/apache/sshd/common/auth/PublicKeyAuthenticationTest.java
@@ -49,6 +49,7 @@ import org.apache.sshd.common.SshException;
 import org.apache.sshd.common.config.keys.FilePasswordProvider;
 import org.apache.sshd.common.config.keys.KeyUtils;
 import org.apache.sshd.common.config.keys.OpenSshCertificate;
+import org.apache.sshd.common.config.keys.OpenSshCertificate.CertificateOption;
 import org.apache.sshd.common.config.keys.OpenSshCertificateImpl;
 import org.apache.sshd.common.keyprovider.KeyIdentityProvider;
 import org.apache.sshd.common.keyprovider.KeyPairProvider;
@@ -557,15 +558,7 @@ public class PublicKeyAuthenticationTest extends 
AuthenticationTestSupport {
 
         sshd.setUserAuthFactories(Collections.singletonList(new 
org.apache.sshd.server.auth.pubkey.UserAuthPublicKeyFactory()));
 
-        AtomicInteger authAttempts = new AtomicInteger(0);
-        sshd.setPublickeyAuthenticator((username, key, session) -> {
-            authAttempts.incrementAndGet();
-            if (key instanceof OpenSshCertificate) {
-                OpenSshCertificate cert = (OpenSshCertificate) key;
-                return KeyUtils.compareKeys(cert.getCaPubKey(), 
caKeypair.getPublic());
-            }
-            return false;
-        });
+        sshd.setPublickeyAuthenticator((username, key, session) -> true);
 
         // Client authentication should fail
         try (SshClient client = setupTestClient()) {
@@ -588,4 +581,63 @@ public class PublicKeyAuthenticationTest extends 
AuthenticationTestSupport {
         }
     }
 
+    @ParameterizedTest(name = "''{0}''")
+    @MethodSource("certificateSources")
+    void certificateSources(String sources, boolean expectSuccess) throws 
Exception {
+        KeyPair userkey = 
CommonTestSupportUtils.generateKeyPair(KeyUtils.EC_ALGORITHM, 256);
+        KeyPair caKeypair = 
CommonTestSupportUtils.generateKeyPair(KeyUtils.EC_ALGORITHM, 256);
+
+        OpenSshCertificate signedCert = 
OpenSshCertificateBuilder.userCertificate() //
+                .serial(System.currentTimeMillis()) //
+                .publicKey(userkey.getPublic()) //
+                .id("test-cert-ecdsa-sha2-nistp256") //
+                .validBefore(System.currentTimeMillis() + 
TimeUnit.HOURS.toMillis(1)) //
+                .principals(Collections.singletonList("user01")) //
+                .criticalOptions(Collections.singletonList(new 
CertificateOption("source-address", sources)))
+                .sign(caKeypair, "ecdsa-sha2-nistp256");
+
+        // Configure the ssh server
+        sshd.setPasswordAuthenticator(RejectAllPasswordAuthenticator.INSTANCE);
+        
sshd.setKeyboardInteractiveAuthenticator(KeyboardInteractiveAuthenticator.NONE);
+        CoreTestSupportUtils.setupFullSignaturesSupport(sshd);
+
+        sshd.setUserAuthFactories(Collections.singletonList(new 
org.apache.sshd.server.auth.pubkey.UserAuthPublicKeyFactory()));
+
+        sshd.setPublickeyAuthenticator((username, key, session) -> true);
+
+        try (SshClient client = setupTestClient()) {
+            CoreTestSupportUtils.setupFullSignaturesSupport(client);
+            client.setUserAuthFactories(
+                    Collections.singletonList(new 
org.apache.sshd.client.auth.pubkey.UserAuthPublicKeyFactory()));
+
+            client.start();
+
+            try (ClientSession session = client.connect("user01", 
TEST_LOCALHOST, port).verify(CONNECT_TIMEOUT).getSession()) {
+
+                KeyPair certKeyPair = new KeyPair(signedCert, 
userkey.getPrivate());
+                session.addPublicKeyIdentity(certKeyPair);
+
+                AuthFuture auth = session.auth();
+                if (expectSuccess) {
+                    auth.verify(AUTH_TIMEOUT);
+                    assertTrue(session.isAuthenticated());
+                } else {
+                    assertThrows(SshException.class, () -> 
auth.verify(AUTH_TIMEOUT));
+                }
+            } finally {
+                client.stop();
+            }
+        }
+    }
+
+    private static Stream<Arguments> certificateSources() {
+        return Stream.of( //
+                Arguments.of("127.0/24", true), //
+                Arguments.of("8.8.8.8/24", false), //
+                Arguments.of("127.0.0.1/32", true), //
+                Arguments.of("8.8.8.8/8,127.0.0.1/32", true), //
+                Arguments.of("bogus", false), //
+                Arguments.of("", false));
+    }
+
 }

Reply via email to