This is an automated email from the ASF dual-hosted git repository. xiazcy pushed a commit to branch dotnet-http-interceptors in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 5a89d22a641b9aaece234e0198191dc27fb71ebd Author: Yang Xia <[email protected]> AuthorDate: Tue Mar 17 14:43:11 2026 -0700 Add interceptor to .net with new reference auth class # Conflicts: # gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs --- gremlin-dotnet/Examples/Connections/Connections.cs | 2 +- gremlin-dotnet/docker-compose.yml | 1 + gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs | 187 +++++++++++++++++ .../src/Gremlin.Net/Driver/Connection.cs | 59 ++++-- .../src/Gremlin.Net/Driver/ConnectionSettings.cs | 6 + .../src/Gremlin.Net/Driver/GremlinClient.cs | 13 +- .../src/Gremlin.Net/Driver/GremlinServer.cs | 16 +- .../src/Gremlin.Net/Driver/HttpRequestContext.cs | 80 +++++++ .../Driver/Remote/DriverRemoteConnection.cs | 10 +- gremlin-dotnet/src/Gremlin.Net/Gremlin.Net.csproj | 1 + .../Docs/Reference/GremlinVariantsTests.cs | 4 +- .../Driver/AuthIntegrationTests.cs | 104 ++++++++++ .../test/Gremlin.Net.UnitTest/Driver/AuthTests.cs | 230 +++++++++++++++++++++ .../Gremlin.Net.UnitTest/Driver/ConnectionTests.cs | 104 ++++++++++ .../Driver/HttpRequestContextTests.cs | 103 +++++++++ 15 files changed, 883 insertions(+), 37 deletions(-) diff --git a/gremlin-dotnet/Examples/Connections/Connections.cs b/gremlin-dotnet/Examples/Connections/Connections.cs index c09eea3f7b..519eb70363 100644 --- a/gremlin-dotnet/Examples/Connections/Connections.cs +++ b/gremlin-dotnet/Examples/Connections/Connections.cs @@ -52,7 +52,7 @@ public class ConnectionExample static void WithConf() { using var remoteConnection = new DriverRemoteConnection(new GremlinClient( - new GremlinServer(hostname: ServerHost, port: ServerPort, enableSsl: false, username: "", password: "")), "g"); + new GremlinServer(hostname: ServerHost, port: ServerPort, enableSsl: false)), "g"); var g = Traversal().With(remoteConnection); var v = g.AddV(VertexLabel).Iterate(); diff --git a/gremlin-dotnet/docker-compose.yml b/gremlin-dotnet/docker-compose.yml index 697eab2184..b29e9036f6 100644 --- 
a/gremlin-dotnet/docker-compose.yml +++ b/gremlin-dotnet/docker-compose.yml @@ -56,6 +56,7 @@ services: - GREMLIN_SERVER_PORT=45940 - VERTEX_LABEL=dotnet-example working_dir: /gremlin-dotnet + # TODO: fix examples issue after feature-completion command: > bash -c "find . -path '*/TestResults/*.trx' -delete 2>/dev/null || true; dotnet tool update -g dotnet-trx; dotnet test ./Gremlin.Net.sln -c Release --logger trx; diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs new file mode 100644 index 0000000000..d7d0770c1b --- /dev/null +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs @@ -0,0 +1,187 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Text; +using System.Threading.Tasks; +using Amazon.Runtime; +using Amazon.Runtime.Internal; +using Amazon.Runtime.Internal.Auth; +using Amazon.Runtime.Internal.Util; + +namespace Gremlin.Net.Driver +{ + /// <summary> + /// Provides factory methods for built-in request interceptors. + /// </summary> + public static class Auth + { + /// <summary> + /// Returns a request interceptor that adds an HTTP Basic Authentication header. 
+ /// The credentials are pre-computed once and set on every request. + /// </summary> + /// <param name="username">The username.</param> + /// <param name="password">The password.</param> + /// <returns>A request interceptor delegate.</returns> + public static Func<HttpRequestContext, Task> BasicAuth(string username, string password) + { + var encoded = Convert.ToBase64String( + Encoding.UTF8.GetBytes(username + ":" + password)); + var headerValue = "Basic " + encoded; + + return context => + { + context.Headers["Authorization"] = headerValue; + return Task.CompletedTask; + }; + } + + /// <summary> + /// Returns a request interceptor that signs requests using AWS Signature Version 4. + /// If <paramref name="credentials"/> is null, the default AWS credential chain is + /// used and the resolved provider is cached on first use (the provider itself handles + /// credential refresh for expiring credentials like STS). + /// </summary> + /// <param name="region">The AWS region (e.g. "us-east-1").</param> + /// <param name="service">The AWS service name (e.g. "neptune-db").</param> + /// <param name="credentials"> + /// Optional AWS credentials. When null, the default credential chain is used. + /// </param> + /// <returns>A request interceptor delegate.</returns> + public static Func<HttpRequestContext, Task> SigV4Auth( + string region, string service, AWSCredentials? credentials = null) + { + // Cache the credential provider once when using the default chain. + AWSCredentials? 
cachedProvider = credentials; + var cacheLock = new object(); + + // Create signer and config once — both are stateless and thread-safe + var signer = new AWS4Signer(); + var clientConfig = new SigningClientConfig + { + AuthenticationRegion = region, + AuthenticationServiceName = service, + }; + + return context => + { + if (cachedProvider == null) + { + lock (cacheLock) + { + cachedProvider ??= FallbackCredentialsFactory.GetCredentials(); + } + } + var immutableCreds = cachedProvider.GetCredentials(); + SignRequest(context, region, service, immutableCreds, signer, clientConfig); + return Task.CompletedTask; + }; + } + + private static void SignRequest(HttpRequestContext context, string region, string service, + ImmutableCredentials credentials, AWS4Signer signer, SigningClientConfig clientConfig) + { + // Build a DefaultRequest from the HttpRequestContext for the AWS SDK signer. + var endpointUri = new Uri(context.Uri.GetLeftPart(UriPartial.Authority)); + var awsRequest = new DefaultRequest(new NullRequest(), service) + { + HttpMethod = context.Method, + Endpoint = endpointUri, + ResourcePath = context.Uri.AbsolutePath, + Content = context.Body ?? Array.Empty<byte>(), + AuthenticationRegion = region, + OverrideSigningServiceName = service, + }; + + // Copy headers (skip Host — signer adds it) + foreach (var header in context.Headers) + { + if (!string.Equals(header.Key, "Host", StringComparison.OrdinalIgnoreCase)) + { + awsRequest.Headers[header.Key] = header.Value; + } + } + + // Copy query parameters + var query = context.Uri.Query; + if (!string.IsNullOrEmpty(query)) + { + // Remove leading '?' + var queryString = query.StartsWith("?") ? query.Substring(1) : query; + foreach (var param in queryString.Split('&')) + { + if (string.IsNullOrEmpty(param)) continue; + var parts = param.Split(new[] { '=' }, 2); + var key = Uri.UnescapeDataString(parts[0]); + var value = parts.Length > 1 ? 
Uri.UnescapeDataString(parts[1]) : ""; + awsRequest.Parameters[key] = value; + } + } + + // Set content hash header before signing + var payloadHash = context.GetPayloadHash(); + awsRequest.Headers["x-amz-content-sha256"] = payloadHash; + + // Sign the request + signer.Sign(awsRequest, clientConfig, new RequestMetrics(), + credentials.AccessKey, credentials.SecretKey); + + // Copy signed headers back to context. Cherry-pick the known SigV4 headers + // because the .NET Dictionary is case-sensitive and the AWS SDK may use + // different casing than what interceptors expect. + context.Headers["Host"] = endpointUri.Host; + if (awsRequest.Headers.ContainsKey("Authorization")) + { + context.Headers["Authorization"] = awsRequest.Headers["Authorization"]; + } + if (awsRequest.Headers.ContainsKey("X-Amz-Date")) + { + context.Headers["X-Amz-Date"] = awsRequest.Headers["X-Amz-Date"]; + } + context.Headers["x-amz-content-sha256"] = payloadHash; + + // Add session token if temporary credentials + if (!string.IsNullOrEmpty(credentials.Token)) + { + context.Headers["X-Amz-Security-Token"] = credentials.Token; + } + } + + /// <summary> + /// A minimal AmazonWebServiceRequest implementation required by DefaultRequest. + /// </summary> + private class NullRequest : AmazonWebServiceRequest + { + } + + /// <summary> + /// A minimal ClientConfig implementation for SigV4 signing. 
+ /// </summary> + private class SigningClientConfig : ClientConfig + { + public override string RegionEndpointServiceName => "execute-api"; + public override string ServiceVersion => "2024-01-01"; + public override string UserAgent => "gremlin-dotnet"; + } + } +} diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs index 81e1484ba3..f6e6649490 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs @@ -47,8 +47,7 @@ namespace Gremlin.Net.Driver private readonly Uri _uri; private readonly IMessageSerializer _serializer; private readonly ConnectionSettings _settings; - // Interceptor slot reserved for future spec - // private readonly IReadOnlyList<Func<HttpRequestMessage, Task>> _interceptors; + private readonly IReadOnlyList<Func<HttpRequestContext, Task>> _interceptors; /// <summary> /// Creates a new HTTP connection. The <see cref="HttpClient"/> is backed by @@ -56,11 +55,13 @@ namespace Gremlin.Net.Driver /// so a single <see cref="Connection"/> instance handles concurrent requests efficiently. /// </summary> public Connection(Uri uri, IMessageSerializer serializer, - ConnectionSettings settings) + ConnectionSettings settings, + IReadOnlyList<Func<HttpRequestContext, Task>>? interceptors = null) { _uri = uri; _serializer = serializer; _settings = settings; + _interceptors = interceptors ?? 
Array.Empty<Func<HttpRequestContext, Task>>(); #if NET6_0_OR_GREATER var handler = new SocketsHttpHandler @@ -70,6 +71,13 @@ namespace Gremlin.Net.Driver ConnectTimeout = settings.ConnectionTimeout, KeepAlivePingTimeout = settings.KeepAliveInterval, }; + if (settings.SkipCertificateValidation) + { + handler.SslOptions = new System.Net.Security.SslClientAuthenticationOptions + { + RemoteCertificateValidationCallback = (_, _, _, _) => true, + }; + } _httpClient = new HttpClient(handler); #else _httpClient = new HttpClient(); @@ -81,12 +89,14 @@ namespace Gremlin.Net.Driver /// Constructor that accepts a pre-configured HttpClient (for testing). /// </summary> internal Connection(Uri uri, IMessageSerializer serializer, - ConnectionSettings settings, HttpClient httpClient) + ConnectionSettings settings, HttpClient httpClient, + IReadOnlyList<Func<HttpRequestContext, Task>>? interceptors = null) { _uri = uri; _serializer = serializer; _settings = settings; _httpClient = httpClient; + _interceptors = interceptors ?? 
Array.Empty<Func<HttpRequestContext, Task>>(); } public async Task<ResultSet<T>> SubmitAsync<T>(RequestMessage requestMessage, @@ -95,29 +105,50 @@ namespace Gremlin.Net.Driver var requestBytes = await _serializer.SerializeMessageAsync(requestMessage, cancellationToken) .ConfigureAwait(false); - using var content = new ByteArrayContent(requestBytes); - content.Headers.ContentType = new MediaTypeHeaderValue(GraphBinaryMimeType); - - using var httpRequest = new HttpRequestMessage(HttpMethod.Post, _uri); - httpRequest.Content = content; - httpRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue(GraphBinaryMimeType)); + // Build HttpRequestContext with default headers + var headers = new Dictionary<string, string>(); + headers["Content-Type"] = GraphBinaryMimeType; + headers["Accept"] = GraphBinaryMimeType; if (_settings.EnableCompression) { - httpRequest.Headers.AcceptEncoding.Add(new StringWithQualityHeaderValue("deflate")); + headers["Accept-Encoding"] = "deflate"; } if (_settings.EnableUserAgentOnConnect) { - httpRequest.Headers.TryAddWithoutValidation("User-Agent", Utils.UserAgent); + headers["User-Agent"] = Utils.UserAgent; } if (_settings.BulkResults) { - httpRequest.Headers.Add("bulkResults", "true"); + headers["bulkResults"] = "true"; + } + + var context = new HttpRequestContext("POST", _uri, headers, requestBytes); + + // Apply interceptors in order + foreach (var interceptor in _interceptors) + { + await interceptor(context).ConfigureAwait(false); } - // Future: apply interceptors here + // Convert HttpRequestContext to HttpRequestMessage + using var httpRequest = new HttpRequestMessage(new HttpMethod(context.Method), context.Uri); + httpRequest.Content = new ByteArrayContent(context.Body); + + foreach (var header in context.Headers) + { + // Content-Type must be set on the content headers, not the request headers + if (string.Equals(header.Key, "Content-Type", StringComparison.OrdinalIgnoreCase)) + { + httpRequest.Content.Headers.ContentType = new 
MediaTypeHeaderValue(header.Value); + } + else + { + httpRequest.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + } using var response = await _httpClient.SendAsync(httpRequest, cancellationToken) .ConfigureAwait(false); diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/ConnectionSettings.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/ConnectionSettings.cs index 0d8df4a6a7..83faebf7ed 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/ConnectionSettings.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/ConnectionSettings.cs @@ -84,5 +84,11 @@ namespace Gremlin.Net.Driver /// Gets or sets whether to send the bulkResults: true header on all requests. /// </summary> public bool BulkResults { get; set; } = false; + + /// <summary> + /// Gets or sets whether to skip SSL certificate validation. + /// Only use for testing with self-signed certificates. + /// </summary> + public bool SkipCertificateValidation { get; set; } = false; } } diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs index cd9011622d..18d95e54c9 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs @@ -22,6 +22,7 @@ #endregion using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Gremlin.Net.Driver.Messages; @@ -50,11 +51,14 @@ namespace Gremlin.Net.Driver /// </param> /// <param name="connectionSettings">The <see cref="ConnectionSettings" /> for the HTTP connection.</param> /// <param name="loggerFactory">A factory to create loggers. If not provided, then nothing will be logged.</param> - // Interceptor slot reserved for future spec: - // IReadOnlyList<Func<HttpRequestMessage, Task>>? interceptors = null, + /// <param name="interceptors"> + /// An optional list of request interceptors. 
Each interceptor receives a mutable + /// <see cref="HttpRequestContext" /> and can modify headers, body, URI, and method before the request is sent. + /// </param> public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? messageSerializer = null, ConnectionSettings? connectionSettings = null, - ILoggerFactory? loggerFactory = null) + ILoggerFactory? loggerFactory = null, + IReadOnlyList<Func<HttpRequestContext, Task>>? interceptors = null) { messageSerializer ??= new GraphBinary4MessageSerializer(); connectionSettings ??= new ConnectionSettings(); @@ -63,7 +67,8 @@ namespace Gremlin.Net.Driver _connection = new Connection( gremlinServer.Uri, messageSerializer, - connectionSettings); + connectionSettings, + interceptors); var logger = LoggerFactory.CreateLogger<GremlinClient>(); logger.InitializedHttpConnection(gremlinServer.Uri); diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinServer.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinServer.cs index 94c46936ad..19ed2fbb18 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinServer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinServer.cs @@ -36,15 +36,11 @@ namespace Gremlin.Net.Driver /// <param name="hostname">The hostname of the server.</param> /// <param name="port">The port on which Gremlin Server can be reached.</param> /// <param name="enableSsl">Specifies whether SSL should be enabled.</param> - /// <param name="username">The username to submit on requests that require authentication.</param> - /// <param name="password">The password to submit on requests that require authentication.</param> /// <param name="path">The path to the Gremlin endpoint on the server.</param> public GremlinServer(string hostname = "localhost", int port = 8182, bool enableSsl = false, - string? username = null, string? 
password = null, string path = "/gremlin") + string path = "/gremlin") { Uri = CreateUri(hostname, port, enableSsl, path); - Username = username; - Password = password; } /// <summary> @@ -52,16 +48,6 @@ namespace Gremlin.Net.Driver /// </summary> public Uri Uri { get; } - /// <summary> - /// Gets the username to submit on requests that require authentication. - /// </summary> - public string? Username { get; } - - /// <summary> - /// Gets the password to submit on requests that require authentication. - /// </summary> - public string? Password { get; } - private static Uri CreateUri(string hostname, int port, bool enableSsl, string path) { var scheme = enableSsl ? "https" : "http"; diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs new file mode 100644 index 0000000000..f2ce916885 --- /dev/null +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs @@ -0,0 +1,80 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#endregion + +using System; +using System.Collections.Generic; +using System.Security.Cryptography; + +namespace Gremlin.Net.Driver +{ + /// <summary> + /// Mutable HTTP request context passed to request interceptors. + /// </summary> + public class HttpRequestContext + { + /// <summary> + /// Gets or sets the HTTP method (e.g. "POST"). + /// </summary> + public string Method { get; set; } + + /// <summary> + /// Gets or sets the request URI. + /// </summary> + public Uri Uri { get; set; } + + /// <summary> + /// Gets the HTTP headers. Interceptors may add, modify, or remove entries. + /// </summary> + public Dictionary<string, string> Headers { get; } + + /// <summary> + /// Gets or sets the serialized request body. + /// </summary> + public byte[] Body { get; set; } + + /// <summary> + /// Initializes a new instance of the <see cref="HttpRequestContext" /> class. + /// </summary> + /// <param name="method">The HTTP method.</param> + /// <param name="uri">The request URI.</param> + /// <param name="headers">The HTTP headers.</param> + /// <param name="body">The serialized request body.</param> + public HttpRequestContext(string method, Uri uri, Dictionary<string, string> headers, byte[] body) + { + Method = method ?? throw new ArgumentNullException(nameof(method)); + Uri = uri ?? throw new ArgumentNullException(nameof(uri)); + Headers = headers ?? throw new ArgumentNullException(nameof(headers)); + Body = body; + } + + /// <summary> + /// Returns the lowercase hex-encoded SHA-256 digest of the body. + /// </summary> + public string GetPayloadHash() + { + using var sha256 = SHA256.Create(); + var hash = sha256.ComputeHash(Body ?? 
Array.Empty<byte>()); + return BitConverter.ToString(hash).Replace("-", "").ToLowerInvariant(); + } + } +} diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs index 12064984e7..f9fa1d62bb 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs @@ -57,10 +57,16 @@ namespace Gremlin.Net.Driver.Remote /// <param name="port">The port to connect to.</param> /// <param name="traversalSource">The name of the traversal source on the server to bind to.</param> /// <param name="loggerFactory">A factory to create loggers. If not provided, then nothing will be logged.</param> + /// <param name="interceptors"> + /// An optional list of request interceptors forwarded to the underlying + /// <see cref="GremlinClient" />. + /// </param> /// <exception cref="ArgumentNullException">Thrown when client is null.</exception> public DriverRemoteConnection(string host, int port, string traversalSource = "g", - ILoggerFactory? loggerFactory = null) : this( - new GremlinClient(new GremlinServer(host, port), loggerFactory: loggerFactory), traversalSource, + ILoggerFactory? loggerFactory = null, + IReadOnlyList<Func<HttpRequestContext, Task>>? interceptors = null) : this( + new GremlinClient(new GremlinServer(host, port), loggerFactory: loggerFactory, interceptors: interceptors), + traversalSource, logger: loggerFactory?.CreateLogger<DriverRemoteConnection>() ?? NullLogger<DriverRemoteConnection>.Instance) { } diff --git a/gremlin-dotnet/src/Gremlin.Net/Gremlin.Net.csproj b/gremlin-dotnet/src/Gremlin.Net/Gremlin.Net.csproj index df9cac5be7..c1c53ae69d 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Gremlin.Net.csproj +++ b/gremlin-dotnet/src/Gremlin.Net/Gremlin.Net.csproj @@ -72,6 +72,7 @@ NOTE that versions suffixed with "-rc" are considered release candidates (i.e. 
p </PropertyGroup> <ItemGroup> + <PackageReference Include="AWSSDK.Core" Version="3.7.400.2" /> <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0.2" /> <PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" /> <PackageReference Include="Polly" Version="8.5.1" /> diff --git a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Docs/Reference/GremlinVariantsTests.cs b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Docs/Reference/GremlinVariantsTests.cs index a1689bdf0e..5c5103aff0 100644 --- a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Docs/Reference/GremlinVariantsTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Docs/Reference/GremlinVariantsTests.cs @@ -158,7 +158,9 @@ var response = // tag::submittingScriptsWithAuthentication[] var username = "username"; var password = "password"; -var gremlinServer = new GremlinServer("localhost", 8182, true, username, password); +var gremlinServer = new GremlinServer("localhost", 8182, enableSsl: true); +using var gremlinClient = new GremlinClient(gremlinServer, + interceptors: new[] { Auth.BasicAuth(username, password) }); // end::submittingScriptsWithAuthentication[] } diff --git a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/AuthIntegrationTests.cs b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/AuthIntegrationTests.cs new file mode 100644 index 0000000000..802a9fb066 --- /dev/null +++ b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/AuthIntegrationTests.cs @@ -0,0 +1,104 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Net.Http; +using System.Threading.Tasks; +using Gremlin.Net.Driver; +using Gremlin.Net.Driver.Remote; +using Gremlin.Net.Process.Traversal; +using Gremlin.Net.Structure.IO.GraphBinary4; +using Xunit; + +namespace Gremlin.Net.IntegrationTest.Driver +{ + /// <summary> + /// Integration tests for authentication interceptors against the secure Gremlin Server + /// (port 45941, SSL + SimpleAuthenticator). + /// </summary> + public class AuthIntegrationTests + { + private static readonly string TestHost = ConfigProvider.Configuration["TestServerIpAddress"]!; + private static readonly int TestSecurePort = Convert.ToInt32(ConfigProvider.Configuration["TestSecureServerPort"]); + + private GremlinClient CreateSecureClient(Func<HttpRequestContext, Task>[]? 
interceptors = null) + { + var gremlinServer = new GremlinServer(TestHost, TestSecurePort, enableSsl: true); + return new GremlinClient(gremlinServer, + new GraphBinaryMessageSerializer(), + new ConnectionSettings { SkipCertificateValidation = true }, + interceptors: interceptors); + } + + [Fact] + public async Task ShouldAuthenticateWithBasicAuth() + { + // The secure server uses SimpleAuthenticator with credentials: stephen/password + using var gremlinClient = CreateSecureClient( + new[] { Auth.BasicAuth("stephen", "password") }); + + var response = await gremlinClient.SubmitAsync<long>("g.inject(1).count()"); + + Assert.Single(response); + } + + [Fact] + public async Task ShouldAuthenticateWithBasicAuthViaDriverRemoteConnection() + { + // Test through DriverRemoteConnection + traversal + using var client = CreateSecureClient( + new[] { Auth.BasicAuth("stephen", "password") }); + using var remote = new DriverRemoteConnection(client, "gmodern"); + var g = AnonymousTraversalSource.Traversal().With(remote); + + var count = await g.V().Count().Promise(t => t.Next()); + + Assert.True(count > 0); + } + + [Fact] + public async Task ShouldFailWithWrongCredentials() + { + using var gremlinClient = CreateSecureClient( + new[] { Auth.BasicAuth("stephen", "wrongpassword") }); + + var ex = await Assert.ThrowsAsync<HttpRequestException>( + () => gremlinClient.SubmitAsync<long>("g.inject(1).count()")); + + // Verify this is an HTTP 401 error, not a connection refused error + Assert.Contains("401", ex.Message); + } + + [Fact] + public async Task ShouldFailWithNoCredentials() + { + using var gremlinClient = CreateSecureClient(); + + var ex = await Assert.ThrowsAsync<HttpRequestException>( + () => gremlinClient.SubmitAsync<long>("g.inject(1).count()")); + + // Verify this is an HTTP 401 error, not a connection refused error + Assert.Contains("401", ex.Message); + } + } +} diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/AuthTests.cs 
b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/AuthTests.cs new file mode 100644 index 0000000000..6090521187 --- /dev/null +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/AuthTests.cs @@ -0,0 +1,230 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+#endregion
+
+using System;
+using System.Collections.Generic;
+using System.Security.Cryptography;
+using System.Text;
+using System.Threading.Tasks;
+using Amazon.Runtime;
+using Gremlin.Net.Driver;
+using Xunit;
+
+namespace Gremlin.Net.UnitTest.Driver
+{
+    public class AuthTests
+    {
+        private static HttpRequestContext CreateTestContext()
+        {
+            return new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"),
+                new Dictionary<string, string>(), new byte[] { 0x01 });
+        }
+
+        [Fact]
+        public async Task BasicAuthShouldSetCorrectAuthorizationHeader()
+        {
+            var interceptor = Auth.BasicAuth("user", "pass");
+            var context = CreateTestContext();
+
+            await interceptor(context);
+
+            var expected = "Basic " + Convert.ToBase64String(Encoding.UTF8.GetBytes("user:pass"));
+            Assert.Equal(expected, context.Headers["Authorization"]);
+        }
+
+        [Fact]
+        public async Task BasicAuthShouldSetHeaderOnEveryInvocation()
+        {
+            var interceptor = Auth.BasicAuth("user", "pass");
+            var context1 = CreateTestContext();
+            var context2 = CreateTestContext();
+
+            await interceptor(context1);
+            await interceptor(context2);
+
+            var expected = "Basic " + Convert.ToBase64String(Encoding.UTF8.GetBytes("user:pass"));
+            Assert.Equal(expected, context1.Headers["Authorization"]);
+            Assert.Equal(expected, context2.Headers["Authorization"]);
+        }
+
+        [Fact]
+        public async Task BasicAuthShouldHandleColonsInPassword()
+        {
+            var interceptor = Auth.BasicAuth("user", "pass:with:colons");
+            var context = CreateTestContext();
+
+            await interceptor(context);
+
+            var expected = "Basic " + Convert.ToBase64String(
+                Encoding.UTF8.GetBytes("user:pass:with:colons"));
+            Assert.Equal(expected, context.Headers["Authorization"]);
+        }
+
+        [Fact]
+        public async Task BasicAuthShouldHandleUnicodeCharacters()
+        {
+            var interceptor = Auth.BasicAuth("用户", "密码");
+            var context = CreateTestContext();
+
+            await interceptor(context);
+
+            var expected = "Basic " + Convert.ToBase64String(
+                Encoding.UTF8.GetBytes("用户:密码"));
+            Assert.Equal(expected, context.Headers["Authorization"]);
+        }
+
+        [Fact]
+        public async Task BasicAuthShouldOverwriteExistingAuthorizationHeader()
+        {
+            var interceptor = Auth.BasicAuth("user", "pass");
+            var context = CreateTestContext();
+            context.Headers["Authorization"] = "Bearer old-token";
+
+            await interceptor(context);
+
+            var expected = "Basic " + Convert.ToBase64String(Encoding.UTF8.GetBytes("user:pass"));
+            Assert.Equal(expected, context.Headers["Authorization"]);
+        }
+
+        [Fact]
+        public async Task BasicAuthShouldHandleEmptyCredentials()
+        {
+            var interceptor = Auth.BasicAuth("", "");
+            var context = CreateTestContext();
+
+            await interceptor(context);
+
+            var expected = "Basic " + Convert.ToBase64String(Encoding.UTF8.GetBytes(":"));
+            Assert.Equal(expected, context.Headers["Authorization"]);
+        }
+
+        // --- SigV4 tests: signatures embed the request timestamp, so only header presence/shape is pinned ---
+
+        private static readonly BasicAWSCredentials TestBasicCredentials =
+            new BasicAWSCredentials("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY");
+
+        private static readonly SessionAWSCredentials TestSessionCredentials =
+            new SessionAWSCredentials("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
+                "FwoGZXIvYXdzEBYaDHqa0AP");
+
+        private static HttpRequestContext CreateSigV4TestContext(byte[]? body = null)
+        {
+            return new HttpRequestContext("POST", new Uri("https://example.com:8182/gremlin"),
+                new Dictionary<string, string>
+                {
+                    { "Content-Type", "application/vnd.graphbinary-v4.0" },
+                    { "Accept", "application/vnd.graphbinary-v4.0" },
+                },
+                body ?? new byte[] { 0x84, 0x00 });
+        }
+
+        [Fact]
+        public async Task SigV4AuthShouldAddRequiredHeaders()
+        {
+            var interceptor = Auth.SigV4Auth("us-east-1", "neptune-db", TestBasicCredentials);
+            var context = CreateSigV4TestContext();
+
+            await interceptor(context);
+
+            Assert.True(context.Headers.ContainsKey("Authorization"));
+            Assert.True(context.Headers.ContainsKey("X-Amz-Date"));
+            Assert.True(context.Headers.ContainsKey("x-amz-content-sha256"));
+            Assert.True(context.Headers.ContainsKey("Host"));
+        }
+
+        [Fact]
+        public async Task SigV4AuthShouldHaveCorrectAuthorizationPrefix()
+        {
+            var interceptor = Auth.SigV4Auth("us-west-2", "neptune-db", TestBasicCredentials);
+            var context = CreateSigV4TestContext();
+
+            await interceptor(context);
+
+            Assert.StartsWith("AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/", context.Headers["Authorization"]);
+            Assert.Contains("/us-west-2/neptune-db/aws4_request", context.Headers["Authorization"]); // scope: <key>/<date>/<region>/<service>/aws4_request
+            Assert.Contains("Signature=", context.Headers["Authorization"]);
+        }
+
+        [Fact]
+        public async Task SigV4AuthShouldAddSessionTokenForTemporaryCredentials()
+        {
+            var interceptor = Auth.SigV4Auth("us-east-1", "neptune-db", TestSessionCredentials);
+            var context = CreateSigV4TestContext();
+
+            await interceptor(context);
+
+            Assert.True(context.Headers.TryGetValue("X-Amz-Security-Token", out var sessionToken));
+            Assert.Equal("FwoGZXIvYXdzEBYaDHqa0AP", sessionToken);
+        }
+
+        [Fact]
+        public async Task SigV4AuthShouldNotAddSessionTokenForPermanentCredentials()
+        {
+            var interceptor = Auth.SigV4Auth("us-east-1", "neptune-db", TestBasicCredentials);
+            var context = CreateSigV4TestContext();
+
+            await interceptor(context);
+
+            Assert.False(context.Headers.ContainsKey("X-Amz-Security-Token"));
+        }
+
+        [Fact]
+        public async Task SigV4AuthContentHashShouldMatchBodySha256()
+        {
+            var body = new byte[] { 0x84, 0x00, 0xFD, 0x01 };
+            var interceptor = Auth.SigV4Auth("us-east-1", "neptune-db", TestBasicCredentials);
+            var context = CreateSigV4TestContext(body);
+
+            await interceptor(context);
+
+            using var sha256 = SHA256.Create();
+            var expectedHash = BitConverter.ToString(sha256.ComputeHash(body))
+                .Replace("-", "").ToLowerInvariant();
+            Assert.Equal(expectedHash, context.Headers["x-amz-content-sha256"]);
+        }
+
+        [Fact]
+        public async Task SigV4AuthShouldHandleEmptyBody()
+        {
+            var interceptor = Auth.SigV4Auth("us-east-1", "neptune-db", TestBasicCredentials);
+            var context = CreateSigV4TestContext(Array.Empty<byte>());
+
+            await interceptor(context);
+
+            Assert.True(context.Headers.ContainsKey("Authorization"));
+            Assert.Equal("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
+                context.Headers["x-amz-content-sha256"]);
+        }
+
+        [Fact]
+        public async Task SigV4AuthShouldSetCorrectHost()
+        {
+            var interceptor = Auth.SigV4Auth("us-east-1", "neptune-db", TestBasicCredentials);
+            var context = CreateSigV4TestContext();
+
+            await interceptor(context);
+
+            Assert.Equal("example.com", context.Headers["Host"]); // NOTE(review): URI port 8182 is non-default; confirm Host is meant to omit it
+        }
+    }
+}
diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs
index a73b5acc8a..bcb4b35d8b 100644
--- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs
+++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs
@@ -278,6 +278,110 @@ namespace Gremlin.Net.UnitTest.Driver
             return serializer;
         }
 
+        [Fact]
+        public async Task ShouldCallInterceptorsInOrder()
+        {
+            var (httpClient, _) = CreateMockHttpClient();
+            var serializer = CreateMockSerializer();
+            var settings = new ConnectionSettings();
+            var callOrder = new List<int>();
+
+            var interceptors = new List<Func<HttpRequestContext, Task>>
+            {
+                _ => { callOrder.Add(1); return Task.CompletedTask; },
+                _ => { callOrder.Add(2); return Task.CompletedTask; },
+                _ => { callOrder.Add(3); return Task.CompletedTask; },
+            };
+
+            using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors);
+
+            await connection.SubmitAsync<object>(CreateTestRequest());
+
+            Assert.Equal(new List<int> { 1, 2, 3 }, callOrder);
+        }
+
+        [Fact]
+        public async Task ShouldPropagateInterceptorException()
+        {
+            var (httpClient, handler) = CreateMockHttpClient();
+            var serializer = CreateMockSerializer();
+            var settings = new ConnectionSettings();
+            var expectedException = new InvalidOperationException("interceptor failed");
+
+            var interceptors = new List<Func<HttpRequestContext, Task>>
+            {
+                _ => throw expectedException,
+            };
+
+            using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors);
+
+            var ex = await Assert.ThrowsAsync<InvalidOperationException>(
+                () => connection.SubmitAsync<object>(CreateTestRequest()));
+
+            Assert.Same(expectedException, ex);
+            // A failing interceptor must abort the call before any HTTP request reaches the handler.
+            Assert.Null(handler.CapturedRequest);
+        }
+
+        [Fact]
+        public async Task ShouldAllowInterceptorToModifyHeaders()
+        {
+            var (httpClient, handler) = CreateMockHttpClient();
+            var serializer = CreateMockSerializer();
+            var settings = new ConnectionSettings();
+
+            var interceptors = new List<Func<HttpRequestContext, Task>>
+            {
+                ctx =>
+                {
+                    ctx.Headers["Authorization"] = "Basic dGVzdDp0ZXN0";
+                    return Task.CompletedTask;
+                },
+            };
+
+            using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors);
+
+            await connection.SubmitAsync<object>(CreateTestRequest());
+
+            Assert.NotNull(handler.CapturedRequest);
+            Assert.True(handler.CapturedRequest!.Headers.Contains("Authorization"));
+            Assert.Equal("Basic dGVzdDp0ZXN0",
+                handler.CapturedRequest.Headers.GetValues("Authorization").First());
+        }
+
+        [Fact]
+        public async Task ShouldSeeEarlierInterceptorModifications()
+        {
+            var (httpClient, handler) = CreateMockHttpClient();
+            var serializer = CreateMockSerializer();
+            var settings = new ConnectionSettings();
+            string? observedHeader = null;
+
+            var interceptors = new List<Func<HttpRequestContext, Task>>
+            {
+                ctx =>
+                {
+                    ctx.Headers["X-Custom"] = "first";
+                    return Task.CompletedTask;
+                },
+                ctx =>
+                {
+                    observedHeader = ctx.Headers.TryGetValue("X-Custom", out var earlier) ? earlier : null;
+                    ctx.Headers["X-Custom"] = "second";
+                    return Task.CompletedTask;
+                },
+            };
+
+            using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors);
+
+            await connection.SubmitAsync<object>(CreateTestRequest());
+
+            Assert.Equal("first", observedHeader);
+            Assert.NotNull(handler.CapturedRequest);
+            Assert.Equal("second",
+                handler.CapturedRequest!.Headers.GetValues("X-Custom").First());
+        }
+
         /// <summary>
         /// A test HttpMessageHandler that captures the request and returns a canned response.
         /// </summary>
diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs
new file mode 100644
index 0000000000..d5fbb2bc90
--- /dev/null
+++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs
@@ -0,0 +1,103 @@
+#region License
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#endregion
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Gremlin.Net.Driver;
+using Xunit;
+
+namespace Gremlin.Net.UnitTest.Driver
+{
+    public class HttpRequestContextTests
+    {
+        [Fact]
+        public void ShouldConstructWithAllProperties()
+        {
+            string expectedMethod = "POST";
+            Uri expectedUri = new Uri("http://localhost:8182/gremlin");
+            var expectedHeaders = new Dictionary<string, string> { { "Content-Type", "application/vnd.graphbinary-v4.0" } };
+            byte[] expectedBody = { 0x01, 0x02, 0x03 };
+
+            var requestContext = new HttpRequestContext(expectedMethod, expectedUri, expectedHeaders, expectedBody);
+
+            Assert.Equal(expectedMethod, requestContext.Method);
+            Assert.Equal(expectedUri, requestContext.Uri);
+            Assert.Same(expectedHeaders, requestContext.Headers); // the constructor must keep the caller's dictionary instance
+            Assert.Same(expectedBody, requestContext.Body); // ...and the caller's byte[] instance, without copying
+        }
+
+        [Fact]
+        public void ShouldAllowMutatingProperties()
+        {
+            var requestContext = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"),
+                new Dictionary<string, string>(), new byte[] { 0x01 });
+
+            var replacementUri = new Uri("https://example.com/gremlin");
+            requestContext.Method = "PUT";
+            requestContext.Uri = replacementUri;
+            requestContext.Body = new byte[] { 0x02, 0x03 };
+            requestContext.Headers["Authorization"] = "Basic dGVzdA==";
+
+            Assert.Equal("PUT", requestContext.Method);
+            Assert.Equal(replacementUri, requestContext.Uri);
+            Assert.Equal(new byte[] { 0x02, 0x03 }, requestContext.Body);
+            Assert.Equal("Basic dGVzdA==", requestContext.Headers["Authorization"]);
+        }
+
+        [Fact]
+        public void ShouldComputePayloadHashForKnownBody()
+        {
+            // Reference value: lowercase hex SHA-256 of the UTF-8 bytes of "hello".
+            byte[] helloBytes = Encoding.UTF8.GetBytes("hello");
+            var requestContext = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"),
+                new Dictionary<string, string>(), helloBytes);
+
+            var actualHash = requestContext.GetPayloadHash();
+
+            Assert.Equal("2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", actualHash);
+        }
+
+        [Fact]
+        public void ShouldComputePayloadHashForEmptyBody()
+        {
+            var requestContext = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"),
+                new Dictionary<string, string>(), Array.Empty<byte>());
+
+            var actualHash = requestContext.GetPayloadHash();
+
+            Assert.Equal("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", actualHash); // SHA-256 of zero bytes
+        }
+
+        [Fact]
+        public void ShouldComputePayloadHashForNullBody()
+        {
+            var requestContext = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"),
+                new Dictionary<string, string>(), null!); // a null body is expected to hash exactly like an empty one
+
+            var actualHash = requestContext.GetPayloadHash();
+
+            Assert.Equal("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", actualHash);
+        }
+    }
+}
