This is an automated email from the ASF dual-hosted git repository. xiazcy pushed a commit to branch dotnet-http in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 2c0b1e228a0e3ca62c8841f7c65a4cf0a59b5026 Author: Yang Xia <[email protected]> AuthorDate: Mon Mar 23 14:55:24 2026 -0700 Add interceptors to .NET with new reference auth implementations, split request and response serializer to allow user customization and moved MimeType into IMessageSerializer (#3334) --- .../Examples/BasicGremlin/BasicGremlin.cs | 5 +- gremlin-dotnet/Examples/Connections/Connections.cs | 25 +- .../Examples/ModernTraversals/ModernTraversals.cs | 4 +- gremlin-dotnet/docker-compose.yml | 1 + gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs | 194 ++++++ .../src/Gremlin.Net/Driver/Connection.cs | 125 +++- .../src/Gremlin.Net/Driver/ConnectionSettings.cs | 6 + .../src/Gremlin.Net/Driver/GremlinClient.cs | 60 +- .../src/Gremlin.Net/Driver/GremlinServer.cs | 16 +- .../src/Gremlin.Net/Driver/HttpRequestContext.cs | 93 +++ .../src/Gremlin.Net/Driver/IMessageSerializer.cs | 7 + .../Driver/Remote/DriverRemoteConnection.cs | 10 +- gremlin-dotnet/src/Gremlin.Net/Gremlin.Net.csproj | 1 + .../GraphBinary4/GraphBinary4MessageSerializer.cs | 4 + .../IO/GraphSON/GraphSON2MessageSerializer.cs | 4 +- .../IO/GraphSON/GraphSON3MessageSerializer.cs | 4 +- .../IO/GraphSON/GraphSONMessageSerializer.cs | 3 + .../Docs/Reference/GremlinVariantsTests.cs | 4 +- .../Driver/AuthIntegrationTests.cs | 102 +++ .../DriverRemoteConnection/GraphTraversalTests.cs | 10 +- .../test/Gremlin.Net.UnitTest/Driver/AuthTests.cs | 245 ++++++++ .../Gremlin.Net.UnitTest/Driver/ConnectionTests.cs | 690 ++++++++++++++++++++- .../Driver/HttpRequestContextTests.cs | 131 ++++ 23 files changed, 1651 insertions(+), 93 deletions(-) diff --git a/gremlin-dotnet/Examples/BasicGremlin/BasicGremlin.cs b/gremlin-dotnet/Examples/BasicGremlin/BasicGremlin.cs index 5e92f4853c..7b0ef042be 100644 --- a/gremlin-dotnet/Examples/BasicGremlin/BasicGremlin.cs +++ b/gremlin-dotnet/Examples/BasicGremlin/BasicGremlin.cs @@ -20,6 +20,7 @@ under the License. 
using Gremlin.Net.Driver; using Gremlin.Net.Driver.Remote; using static Gremlin.Net.Process.Traversal.AnonymousTraversalSource; +using static Gremlin.Net.Process.Traversal.__; public class BasicGremlinExample { @@ -40,8 +41,8 @@ public class BasicGremlinExample // Be sure to use a terminating step like Next() or Iterate() so that the traversal "executes" // Iterate() does not return any data and is used to just generate side-effects (i.e. write data to the database) - g.V(v1).AddE("knows").To(v2).Property("weight", 0.75).Iterate(); - g.V(v1).AddE("knows").To(v3).Property("weight", 0.75).Iterate(); + g.V(v1).AddE("knows").To(Constant(v2)).Property("weight", 0.75).Iterate(); + g.V(v1).AddE("knows").To(Constant(v3)).Property("weight", 0.75).Iterate(); // Retrieve the data from the "marko" vertex var marko = await g.V().Has(VertexLabel, "name", "marko").Values<string>("name").Promise(t => t.Next()); diff --git a/gremlin-dotnet/Examples/Connections/Connections.cs b/gremlin-dotnet/Examples/Connections/Connections.cs index a553b4d236..bcfda7e4f8 100644 --- a/gremlin-dotnet/Examples/Connections/Connections.cs +++ b/gremlin-dotnet/Examples/Connections/Connections.cs @@ -19,20 +19,20 @@ under the License. using Gremlin.Net.Driver; using Gremlin.Net.Driver.Remote; -using Gremlin.Net.Structure.IO.GraphSON; using static Gremlin.Net.Process.Traversal.AnonymousTraversalSource; public class ConnectionExample { static readonly string ServerHost = Environment.GetEnvironmentVariable("GREMLIN_SERVER_HOST") ?? "localhost"; static readonly int ServerPort = int.Parse(Environment.GetEnvironmentVariable("GREMLIN_SERVER_PORT") ?? "8182"); + static readonly int SecureServerPort = int.Parse(Environment.GetEnvironmentVariable("GREMLIN_SECURE_SERVER_PORT") ?? "8183"); static readonly string VertexLabel = Environment.GetEnvironmentVariable("VERTEX_LABEL") ?? 
"connection"; static void Main() { WithRemoteConnection(); WithConf(); - WithSerializer(); + WithBasicAuth(); } // Connecting to the server @@ -48,11 +48,16 @@ public class ConnectionExample Console.WriteLine("Vertex count: " + count); } - // Connecting to the server with customized configurations + // Connecting to the server with customized connection settings static void WithConf() { - using var remoteConnection = new DriverRemoteConnection(new GremlinClient( - new GremlinServer(hostname: ServerHost, port: ServerPort, enableSsl: false, username: "", password: "")), "g"); + var server = new GremlinServer(ServerHost, ServerPort); + var settings = new ConnectionSettings + { + ConnectionTimeout = TimeSpan.FromSeconds(30), + }; + using var remoteConnection = new DriverRemoteConnection( + new GremlinClient(server, connectionSettings: settings), "g"); var g = Traversal().With(remoteConnection); var v = g.AddV(VertexLabel).Iterate(); @@ -60,11 +65,13 @@ public class ConnectionExample Console.WriteLine("Vertex count: " + count); } - // Specifying a serializer - static void WithSerializer() + // Connecting with basic authentication using a request interceptor + static void WithBasicAuth() { - var server = new GremlinServer(ServerHost, ServerPort); - var client = new GremlinClient(server, new GraphSON3MessageSerializer()); + var server = new GremlinServer(ServerHost, SecureServerPort, enableSsl: true); + var client = new GremlinClient(server, + connectionSettings: new ConnectionSettings { SkipCertificateValidation = true }, + interceptors: new[] { Auth.BasicAuth("stephen", "password") }); using var remoteConnection = new DriverRemoteConnection(client, "g"); var g = Traversal().With(remoteConnection); diff --git a/gremlin-dotnet/Examples/ModernTraversals/ModernTraversals.cs b/gremlin-dotnet/Examples/ModernTraversals/ModernTraversals.cs index 77f20c6d11..19a1d0b00b 100644 --- a/gremlin-dotnet/Examples/ModernTraversals/ModernTraversals.cs +++ 
b/gremlin-dotnet/Examples/ModernTraversals/ModernTraversals.cs @@ -47,8 +47,8 @@ public class ModernTraversalExample var e2 = g.V(1).BothE().Where(OtherV().HasId(2)).ToList(); // (2) var v1 = g.V(1).Next(); var v2 = g.V(2).Next(); - var e3 = g.V(v1).BothE().Where(OtherV().Is(v2)).ToList(); // (3) - var e4 = g.V(v1).OutE().Where(InV().Is(v2)).ToList(); // (4) + var e3 = g.V(v1).BothE().Where(OtherV().Id().Is(v2)).ToList(); // (3) + var e4 = g.V(v1).OutE().Where(InV().Id().Is(v2)).ToList(); // (4) var e5 = g.V(1).OutE().Where(InV().Has(T.Id, Within(2, 3))).ToList(); // (5) var e6 = g.V(1).Out().Where(__.In().HasId(6)).ToList(); // (6) diff --git a/gremlin-dotnet/docker-compose.yml b/gremlin-dotnet/docker-compose.yml index a2a9343dc1..a32dae8f30 100644 --- a/gremlin-dotnet/docker-compose.yml +++ b/gremlin-dotnet/docker-compose.yml @@ -54,6 +54,7 @@ services: - DOCKER_ENVIRONMENT=true - GREMLIN_SERVER_HOST=gremlin-server-test-dotnet - GREMLIN_SERVER_PORT=45940 + - GREMLIN_SECURE_SERVER_PORT=45941 - VERTEX_LABEL=dotnet-example working_dir: /gremlin-dotnet command: > diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs new file mode 100644 index 0000000000..e10aafe55f --- /dev/null +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs @@ -0,0 +1,194 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Text; +using System.Threading.Tasks; +using Amazon.Runtime; +using Amazon.Runtime.Internal; +using Amazon.Runtime.Internal.Auth; +using Amazon.Runtime.Internal.Util; + +namespace Gremlin.Net.Driver +{ + /// <summary> + /// Provides factory methods for built-in request interceptors. + /// </summary> + public static class Auth + { + /// <summary> + /// Returns a request interceptor that adds an HTTP Basic Authentication header. + /// The credentials are pre-computed once and set on every request. + /// </summary> + /// <param name="username">The username.</param> + /// <param name="password">The password.</param> + /// <returns>A request interceptor delegate.</returns> + public static Func<HttpRequestContext, Task> BasicAuth(string username, string password) + { + var encoded = Convert.ToBase64String( + Encoding.UTF8.GetBytes(username + ":" + password)); + var headerValue = "Basic " + encoded; + + return context => + { + context.Headers["Authorization"] = headerValue; + return Task.CompletedTask; + }; + } + + /// <summary> + /// Returns a request interceptor that signs requests using AWS Signature Version 4. + /// If <paramref name="credentials"/> is null, the default AWS credential chain is + /// used and the resolved provider is cached on first use (the provider itself handles + /// credential refresh for expiring credentials like STS). + /// </summary> + /// <param name="region">The AWS region (e.g. "us-east-1").</param> + /// <param name="service">The AWS service name (e.g. 
"neptune-db").</param> + /// <param name="credentials"> + /// Optional AWS credentials. When null, the default credential chain is used. + /// </param> + /// <returns>A request interceptor delegate.</returns> + public static Func<HttpRequestContext, Task> SigV4Auth( + string region, string service, AWSCredentials? credentials = null) + { + // Cache the credential provider once when using the default chain. + AWSCredentials? cachedProvider = credentials; + var cacheLock = new object(); + + // Create signer and config once — both are stateless and thread-safe + var signer = new AWS4Signer(); + var clientConfig = new SigningClientConfig + { + AuthenticationRegion = region, + AuthenticationServiceName = service, + }; + + return async context => + { + if (cachedProvider == null) + { + lock (cacheLock) + { + // FallbackCredentialsFactory only has a sync API, but this runs once. + cachedProvider ??= FallbackCredentialsFactory.GetCredentials(); + } + } + + // Use the async path — important for credential providers that perform + // network I/O (e.g. IMDS on EC2, ECS task role endpoint). + var immutableCreds = await cachedProvider.GetCredentialsAsync() + .ConfigureAwait(false); + SignRequest(context, immutableCreds, signer, clientConfig); + }; + } + + private static void SignRequest(HttpRequestContext context, + ImmutableCredentials credentials, AWS4Signer signer, SigningClientConfig clientConfig) + { + // Build a DefaultRequest from the HttpRequestContext for the AWS SDK signer. + var endpointUri = new Uri(context.Uri.GetLeftPart(UriPartial.Authority)); + var awsRequest = new DefaultRequest(new NullRequest(), clientConfig.AuthenticationServiceName) + { + HttpMethod = context.Method, + Endpoint = endpointUri, + ResourcePath = context.Uri.AbsolutePath, + Content = context.Body is byte[] bytes + ? bytes + : throw new InvalidOperationException( + "SigV4 signing requires Body to be byte[]. 
" + + "Ensure serialization occurs before the SigV4 interceptor."), + AuthenticationRegion = clientConfig.AuthenticationRegion, + OverrideSigningServiceName = clientConfig.AuthenticationServiceName, + }; + + // Copy headers (skip Host — signer adds it) + foreach (var header in context.Headers) + { + if (!string.Equals(header.Key, "Host", StringComparison.OrdinalIgnoreCase)) + { + awsRequest.Headers[header.Key] = header.Value; + } + } + + // Copy query parameters + var query = context.Uri.Query; + if (!string.IsNullOrEmpty(query)) + { + // Remove leading '?' + var queryString = query.StartsWith("?") ? query.Substring(1) : query; + foreach (var param in queryString.Split('&')) + { + if (string.IsNullOrEmpty(param)) continue; + var parts = param.Split(new[] { '=' }, 2); + var key = Uri.UnescapeDataString(parts[0]); + var value = parts.Length > 1 ? Uri.UnescapeDataString(parts[1]) : ""; + awsRequest.Parameters[key] = value; + } + } + + // Set content hash header before signing + var payloadHash = context.GetPayloadHash(); + awsRequest.Headers["x-amz-content-sha256"] = payloadHash; + + // Sign the request + signer.Sign(awsRequest, clientConfig, new RequestMetrics(), credentials); + + // Copy signed headers back to context. Cherry-pick the known SigV4 headers + // because the .NET Dictionary is case-sensitive and the AWS SDK may use + // different casing than what interceptors expect. 
+ context.Headers["Host"] = endpointUri.Authority; // Authority keeps a non-default port (e.g. :8182), matching the host value the AWS4 signer signed + if (awsRequest.Headers.ContainsKey("Authorization")) + { + context.Headers["Authorization"] = awsRequest.Headers["Authorization"]; + } + if (awsRequest.Headers.ContainsKey("X-Amz-Date")) + { + context.Headers["X-Amz-Date"] = awsRequest.Headers["X-Amz-Date"]; + } + context.Headers["x-amz-content-sha256"] = payloadHash; + + // Add session token if temporary credentials + if (!string.IsNullOrEmpty(credentials.Token)) + { + context.Headers["X-Amz-Security-Token"] = credentials.Token; + } + } + + /// <summary> + /// A minimal AmazonWebServiceRequest implementation required by DefaultRequest. + /// </summary> + private class NullRequest : AmazonWebServiceRequest + { + } + + /// <summary> + /// A minimal ClientConfig implementation for SigV4 signing. + /// </summary> + private class SigningClientConfig : ClientConfig + { + public override string RegionEndpointServiceName => "execute-api"; + public override string ServiceVersion => "2024-01-01"; + public override string UserAgent => "gremlin-dotnet"; + } + } +} diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs index 883d1c2bab..84ec963d40 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs @@ -41,26 +41,40 @@ namespace Gremlin.Net.Driver /// </summary> internal class Connection : IDisposable { - private const string GraphBinaryMimeType = SerializationTokens.GraphBinary4MimeType; - private readonly HttpClient _httpClient; private readonly Uri _uri; - private readonly IMessageSerializer _serializer; + private readonly IMessageSerializer?
_requestSerializer; + private readonly IMessageSerializer _responseSerializer; private readonly ConnectionSettings _settings; - // Interceptor slot reserved for future spec - // private readonly IReadOnlyList<Func<HttpRequestMessage, Task>> _interceptors; + private readonly IReadOnlyList<Func<HttpRequestContext, Task>> _interceptors; /// <summary> /// Creates a new HTTP connection. The <see cref="HttpClient"/> is backed by /// SocketsHttpHandler which manages its own TCP connection pool internally, /// so a single <see cref="Connection"/> instance handles concurrent requests efficiently. /// </summary> - public Connection(Uri uri, IMessageSerializer serializer, - ConnectionSettings settings) + /// <param name="uri">The Gremlin Server URI.</param> + /// <param name="requestSerializer"> + /// The serializer for outgoing requests. When non-null, the request body is serialized + /// to <c>byte[]</c> before interceptors run and the <c>Content-Type</c> header is set + /// automatically. When <c>null</c>, the body is passed as a <see cref="RequestMessage"/> + /// and an interceptor is responsible for serializing it to <c>byte[]</c> and setting + /// <c>Content-Type</c>. This follows the Python driver's <c>request_serializer=None</c> + /// pattern. + /// </param> + /// <param name="responseSerializer">The serializer for incoming responses (always required).</param> + /// <param name="settings">Connection settings.</param> + /// <param name="interceptors">Optional request interceptors.</param> + public Connection(Uri uri, IMessageSerializer? requestSerializer, + IMessageSerializer responseSerializer, + ConnectionSettings settings, + IReadOnlyList<Func<HttpRequestContext, Task>>? interceptors = null) { _uri = uri; - _serializer = serializer; + _requestSerializer = requestSerializer; + _responseSerializer = responseSerializer; _settings = settings; + _interceptors = interceptors ?? 
Array.Empty<Func<HttpRequestContext, Task>>(); #if NET6_0_OR_GREATER var handler = new SocketsHttpHandler @@ -70,6 +84,13 @@ namespace Gremlin.Net.Driver ConnectTimeout = settings.ConnectionTimeout, KeepAlivePingTimeout = settings.KeepAliveInterval, }; + if (settings.SkipCertificateValidation) + { + handler.SslOptions = new System.Net.Security.SslClientAuthenticationOptions + { + RemoteCertificateValidationCallback = (_, _, _, _) => true, + }; + } _httpClient = new HttpClient(handler); #else _httpClient = new HttpClient(); @@ -80,44 +101,96 @@ namespace Gremlin.Net.Driver /// <summary> /// Constructor that accepts a pre-configured HttpClient (for testing). /// </summary> - internal Connection(Uri uri, IMessageSerializer serializer, - ConnectionSettings settings, HttpClient httpClient) + internal Connection(Uri uri, IMessageSerializer? requestSerializer, + IMessageSerializer responseSerializer, + ConnectionSettings settings, HttpClient httpClient, + IReadOnlyList<Func<HttpRequestContext, Task>>? interceptors = null) { _uri = uri; - _serializer = serializer; + _requestSerializer = requestSerializer; + _responseSerializer = responseSerializer; _settings = settings; _httpClient = httpClient; + _interceptors = interceptors ?? 
Array.Empty<Func<HttpRequestContext, Task>>(); } public async Task<ResultSet<T>> SubmitAsync<T>(RequestMessage requestMessage, CancellationToken cancellationToken = default) { - var requestBytes = await _serializer.SerializeMessageAsync(requestMessage, cancellationToken) - .ConfigureAwait(false); - - using var content = new ByteArrayContent(requestBytes); - content.Headers.ContentType = new MediaTypeHeaderValue(GraphBinaryMimeType); - - using var httpRequest = new HttpRequestMessage(HttpMethod.Post, _uri); - httpRequest.Content = content; - httpRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue(GraphBinaryMimeType)); + // Build HttpRequestContext with default headers + var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase); // HTTP header names are case-insensitive; lets interceptors override defaults regardless of casing + headers["Accept"] = _responseSerializer.MimeType; if (_settings.EnableCompression) { - httpRequest.Headers.AcceptEncoding.Add(new StringWithQualityHeaderValue("deflate")); + headers["Accept-Encoding"] = "deflate"; } if (_settings.EnableUserAgentOnConnect) { - httpRequest.Headers.TryAddWithoutValidation("User-Agent", Utils.UserAgent); + headers["User-Agent"] = Utils.UserAgent; } if (_settings.BulkResults) { - httpRequest.Headers.Add("bulkResults", "true"); + headers["bulkResults"] = "true"; + } + + object body; + if (_requestSerializer != null) + { + // Default path: serialize before interceptors + var requestBytes = await _requestSerializer.SerializeMessageAsync(requestMessage, cancellationToken) + .ConfigureAwait(false); + body = requestBytes; + headers["Content-Type"] = _requestSerializer.MimeType; + } + else + { + // Interceptor-managed path: body is RequestMessage, interceptor must serialize + body = requestMessage; + } + + var context = new HttpRequestContext("POST", _uri, headers, body); + + // Apply interceptors in order + foreach (var interceptor in _interceptors) + { + await interceptor(context).ConfigureAwait(false); } - // Future: apply interceptors here + // Convert HttpRequestContext to HttpRequestMessage + using var
httpRequest = new HttpRequestMessage(new HttpMethod(context.Method), context.Uri); + + if (context.Body is byte[] bodyBytes) + { + httpRequest.Content = new ByteArrayContent(bodyBytes); + } + else if (context.Body is HttpContent httpContent) + { + httpRequest.Content = httpContent; + } + else + { + throw new InvalidOperationException( + "Request body must be byte[] or HttpContent after all interceptors complete, " + + "but found " + (context.Body?.GetType().Name ?? "null") + + ". Either provide a requestSerializer or add an interceptor " + + "that serializes the RequestMessage."); + } + + foreach (var header in context.Headers) + { + // Content-Type lives on the content headers; Parse (unlike the ctor) accepts parameters such as "; charset=utf-8" + if (string.Equals(header.Key, "Content-Type", StringComparison.OrdinalIgnoreCase)) + { + httpRequest.Content.Headers.ContentType = MediaTypeHeaderValue.Parse(header.Value); + } + else + { + if (!httpRequest.Headers.TryAddWithoutValidation(header.Key, header.Value)) httpRequest.Content.Headers.TryAddWithoutValidation(header.Key, header.Value); // content headers (e.g. Content-Encoding) are rejected by request headers; don't drop them silently + } + } using var response = await _httpClient.SendAsync(httpRequest, cancellationToken) .ConfigureAwait(false); @@ -128,7 +201,7 @@ namespace Gremlin.Net.Driver // is GraphBinary do we fall through to normal deserialization so the status footer // in the GB4 response can surface the application-level error.
if (!response.IsSuccessStatusCode && - response.Content.Headers.ContentType?.MediaType != GraphBinaryMimeType) + response.Content.Headers.ContentType?.MediaType != _responseSerializer.MimeType) { var errorBody = await response.Content.ReadAsStringAsync().ConfigureAwait(false); @@ -141,7 +214,7 @@ namespace Gremlin.Net.Driver var responseBytes = await ReadResponseBytesAsync(response).ConfigureAwait(false); - var responseMessage = await _serializer.DeserializeMessageAsync(responseBytes, cancellationToken) + var responseMessage = await _responseSerializer.DeserializeMessageAsync(responseBytes, cancellationToken) .ConfigureAwait(false); return BuildResultSet<T>(responseMessage); diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/ConnectionSettings.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/ConnectionSettings.cs index 0d8df4a6a7..83faebf7ed 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/ConnectionSettings.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/ConnectionSettings.cs @@ -84,5 +84,11 @@ namespace Gremlin.Net.Driver /// Gets or sets whether to send the bulkResults: true header on all requests. /// </summary> public bool BulkResults { get; set; } = false; + + /// <summary> + /// Gets or sets whether to skip SSL certificate validation. + /// Only use for testing with self-signed certificates. 
+ /// </summary> + public bool SkipCertificateValidation { get; set; } = false; } } diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs index b5b43faccf..50dcbcbd27 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs @@ -22,6 +22,7 @@ #endregion using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Gremlin.Net.Driver.Messages; @@ -44,31 +45,70 @@ namespace Gremlin.Net.Driver /// Initializes a new instance of the <see cref="GremlinClient" /> class for the specified Gremlin Server. /// </summary> /// <param name="gremlinServer">The <see cref="GremlinServer" /> the requests should be sent to.</param> - /// <param name="messageSerializer"> - /// A <see cref="IMessageSerializer" /> instance to serialize messages sent to and received - /// from the server. + /// <param name="requestSerializer"> + /// A <see cref="IMessageSerializer" /> instance to serialize outgoing request messages. + /// When <c>null</c>, the request body is passed as a <see cref="RequestMessage"/> to + /// interceptors, and an interceptor must serialize it to <c>byte[]</c> and set the + /// <c>Content-Type</c> header. This follows the Python driver's + /// <c>request_serializer=None</c> pattern. + /// </param> + /// <param name="responseSerializer"> + /// A <see cref="IMessageSerializer" /> instance to deserialize incoming response messages. + /// Always required. /// </param> /// <param name="connectionSettings">The <see cref="ConnectionSettings" /> for the HTTP connection.</param> /// <param name="loggerFactory">A factory to create loggers. If not provided, then nothing will be logged.</param> - // Interceptor slot reserved for future spec: - // IReadOnlyList<Func<HttpRequestMessage, Task>>? interceptors = null, - public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? 
messageSerializer = null, + /// <param name="interceptors"> + /// An optional list of request interceptors. Each interceptor receives a mutable + /// <see cref="HttpRequestContext" /> and can modify headers, body, URI, and method + /// before the request is sent. + /// </param> + public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? requestSerializer, + IMessageSerializer responseSerializer, ConnectionSettings? connectionSettings = null, - ILoggerFactory? loggerFactory = null) + ILoggerFactory? loggerFactory = null, + IReadOnlyList<Func<HttpRequestContext, Task>>? interceptors = null) { - messageSerializer ??= new GraphBinary4MessageSerializer(); connectionSettings ??= new ConnectionSettings(); LoggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _connection = new Connection( gremlinServer.Uri, - messageSerializer, - connectionSettings); + requestSerializer, + responseSerializer, + connectionSettings, + interceptors); var logger = LoggerFactory.CreateLogger<GremlinClient>(); logger.InitializedHttpConnection(gremlinServer.Uri); } + /// <summary> + /// Initializes a new instance of the <see cref="GremlinClient" /> class with a single + /// serializer used for both request serialization and response deserialization. + /// This is the backward-compatible convenience constructor. + /// </summary> + /// <param name="gremlinServer">The <see cref="GremlinServer" /> the requests should be sent to.</param> + /// <param name="messageSerializer"> + /// A <see cref="IMessageSerializer" /> instance used for both request serialization and + /// response deserialization. Defaults to <see cref="GraphBinary4MessageSerializer"/>. + /// </param> + /// <param name="connectionSettings">The <see cref="ConnectionSettings" /> for the HTTP connection.</param> + /// <param name="loggerFactory">A factory to create loggers. If not provided, then nothing will be logged.</param> + /// <param name="interceptors"> + /// An optional list of request interceptors. 
+ /// </param> + public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? messageSerializer = null, + ConnectionSettings? connectionSettings = null, + ILoggerFactory? loggerFactory = null, + IReadOnlyList<Func<HttpRequestContext, Task>>? interceptors = null) + : this(gremlinServer, + messageSerializer ?? new GraphBinary4MessageSerializer(), + messageSerializer ?? new GraphBinary4MessageSerializer(), + connectionSettings, loggerFactory, interceptors) + { + } + /// <inheritdoc /> public async Task<ResultSet<T>> SubmitAsync<T>(RequestMessage requestMessage, CancellationToken cancellationToken = default) diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinServer.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinServer.cs index b17b8836ad..0d8a41dfd1 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinServer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinServer.cs @@ -36,15 +36,11 @@ namespace Gremlin.Net.Driver /// <param name="hostname">The hostname of the server.</param> /// <param name="port">The port on which Gremlin Server can be reached.</param> /// <param name="enableSsl">Specifies whether SSL should be enabled.</param> - /// <param name="username">The username to submit on requests that require authentication.</param> - /// <param name="password">The password to submit on requests that require authentication.</param> /// <param name="path">The path to the Gremlin endpoint on the server.</param> public GremlinServer(string hostname = "localhost", int port = 8182, bool enableSsl = false, - string? username = null, string? password = null, string path = "/gremlin") + string path = "/gremlin") { Uri = CreateUri(hostname, port, enableSsl, path); - Username = username; - Password = password; } /// <summary> @@ -52,16 +48,6 @@ namespace Gremlin.Net.Driver /// </summary> public Uri Uri { get; } - /// <summary> - /// Gets the username to submit on requests that require authentication. - /// </summary> - public string? 
Username { get; } - - /// <summary> - /// Gets the password to submit on requests that require authentication. - /// </summary> - public string? Password { get; } - private static Uri CreateUri(string hostname, int port, bool enableSsl, string path) { var scheme = enableSsl ? "https" : "http"; diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs new file mode 100644 index 0000000000..9fc170b45f --- /dev/null +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs @@ -0,0 +1,93 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Collections.Generic; +using System.Security.Cryptography; + +namespace Gremlin.Net.Driver +{ + /// <summary> + /// Mutable HTTP request context passed to request interceptors. + /// </summary> + public class HttpRequestContext + { + /// <summary> + /// Gets or sets the HTTP method (e.g. "POST"). + /// </summary> + public string Method { get; set; } + + /// <summary> + /// Gets or sets the request URI. + /// </summary> + public Uri Uri { get; set; } + + /// <summary> + /// Gets the HTTP headers. 
Interceptors may add, modify, or remove entries. + /// </summary> + public Dictionary<string, string> Headers { get; } + + /// <summary> + /// Gets or sets the request body. This is <c>byte[]</c> when serialization has occurred + /// (default path), or <c>RequestMessage</c> when serialization is deferred to interceptors + /// (<c>requestSerializer = null</c>). Interceptors may also set this to an + /// <see cref="System.Net.Http.HttpContent"/> instance for full control over the wire format. + /// </summary> + public object Body { get; set; } + + /// <summary> + /// Initializes a new instance of the <see cref="HttpRequestContext" /> class. + /// </summary> + /// <param name="method">The HTTP method.</param> + /// <param name="uri">The request URI.</param> + /// <param name="headers">The HTTP headers.</param> + /// <param name="body">The request body. Typically <c>byte[]</c> (post-serialization) or + /// <c>RequestMessage</c> (pre-serialization).</param> + public HttpRequestContext(string method, Uri uri, Dictionary<string, string> headers, object body) + { + Method = method ?? throw new ArgumentNullException(nameof(method)); + Uri = uri ?? throw new ArgumentNullException(nameof(uri)); + Headers = headers ?? throw new ArgumentNullException(nameof(headers)); + Body = body; + } + + /// <summary> + /// Returns the lowercase hex-encoded SHA-256 digest of the body. + /// Throws <see cref="InvalidOperationException"/> if <see cref="Body"/> is not <c>byte[]</c>, + /// which indicates that serialization has not yet occurred. + /// </summary> + public string GetPayloadHash() + { + if (Body is not byte[] bytes) + { + throw new InvalidOperationException( + "Cannot compute payload hash before serialization. " + + "Body must be byte[] but is " + + (Body?.GetType().Name ?? 
"null") + "."); + } + using var sha256 = SHA256.Create(); + var hash = sha256.ComputeHash(bytes); + return BitConverter.ToString(hash).Replace("-", "").ToLowerInvariant(); + } + } +} diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs index afb65b05fb..81fb5fd6b1 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs @@ -33,6 +33,13 @@ namespace Gremlin.Net.Driver /// </summary> public interface IMessageSerializer { + /// <summary> + /// Gets the MIME type produced by this serializer (e.g. + /// <c>"application/vnd.graphbinary-v4.0"</c>). Used by the driver to set + /// <c>Content-Type</c> and <c>Accept</c> headers automatically. + /// </summary> + string MimeType { get; } + /// <summary> /// Serializes a <see cref="RequestMessage"/>. /// </summary> diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs index 4ddb326db8..81dd3433e8 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs @@ -57,10 +57,16 @@ namespace Gremlin.Net.Driver.Remote /// <param name="port">The port to connect to.</param> /// <param name="traversalSource">The name of the traversal source on the server to bind to.</param> /// <param name="loggerFactory">A factory to create loggers. If not provided, then nothing will be logged.</param> + /// <param name="interceptors"> + /// An optional list of request interceptors forwarded to the underlying + /// <see cref="GremlinClient" />. + /// </param> /// <exception cref="ArgumentNullException">Thrown when client is null.</exception> public DriverRemoteConnection(string host, int port, string traversalSource = "g", - ILoggerFactory? 
loggerFactory = null) : this( - new GremlinClient(new GremlinServer(host, port), loggerFactory: loggerFactory), traversalSource, + ILoggerFactory? loggerFactory = null, + IReadOnlyList<Func<HttpRequestContext, Task>>? interceptors = null) : this( + new GremlinClient(new GremlinServer(host, port), loggerFactory: loggerFactory, interceptors: interceptors), + traversalSource, logger: loggerFactory?.CreateLogger<DriverRemoteConnection>() ?? NullLogger<DriverRemoteConnection>.Instance) { } diff --git a/gremlin-dotnet/src/Gremlin.Net/Gremlin.Net.csproj b/gremlin-dotnet/src/Gremlin.Net/Gremlin.Net.csproj index df9cac5be7..c1c53ae69d 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Gremlin.Net.csproj +++ b/gremlin-dotnet/src/Gremlin.Net/Gremlin.Net.csproj @@ -72,6 +72,7 @@ NOTE that versions suffixed with "-rc" are considered release candidates (i.e. p </PropertyGroup> <ItemGroup> + <PackageReference Include="AWSSDK.Core" Version="3.7.400.2" /> <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0.2" /> <PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" /> <PackageReference Include="Polly" Version="8.5.1" /> diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinary4MessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinary4MessageSerializer.cs index 6f8bbbd4eb..214f32be6e 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinary4MessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/GraphBinary4MessageSerializer.cs @@ -27,6 +27,7 @@ using System.Threading; using System.Threading.Tasks; using Gremlin.Net.Driver; using Gremlin.Net.Driver.Messages; +using Gremlin.Net.Structure.IO; namespace Gremlin.Net.Structure.IO.GraphBinary4 { @@ -40,6 +41,9 @@ namespace Gremlin.Net.Structure.IO.GraphBinary4 private readonly RequestMessageSerializer _requestSerializer = new RequestMessageSerializer(); private readonly 
ResponseMessageSerializer _responseSerializer = new ResponseMessageSerializer(); + /// <inheritdoc /> + public string MimeType => SerializationTokens.GraphBinary4MimeType; + /// <summary> /// Initializes a new instance of the <see cref="GraphBinary4MessageSerializer" /> class /// with the default type serializer registry. diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSON2MessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSON2MessageSerializer.cs index e1fcb20884..2ed32e5ef9 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSON2MessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSON2MessageSerializer.cs @@ -28,15 +28,13 @@ namespace Gremlin.Net.Structure.IO.GraphSON /// </summary> public class GraphSON2MessageSerializer : GraphSONMessageSerializer { - private const string MimeType = SerializationTokens.GraphSON2MimeType; - /// <summary> /// Initializes a new instance of the <see cref="GraphSON2MessageSerializer" /> class with custom serializers. /// </summary> /// <param name="graphSONReader">The <see cref="GraphSON2Reader"/> used to deserialize from GraphSON.</param> /// <param name="graphSONWriter">The <see cref="GraphSON2Writer"/> used to serialize to GraphSON.</param> public GraphSON2MessageSerializer(GraphSON2Reader? graphSONReader = null, GraphSON2Writer? graphSONWriter = null) - : base(MimeType, graphSONReader ?? new GraphSON2Reader(), graphSONWriter ?? new GraphSON2Writer()) + : base(SerializationTokens.GraphSON2MimeType, graphSONReader ?? new GraphSON2Reader(), graphSONWriter ?? 
new GraphSON2Writer()) { } } diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSON3MessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSON3MessageSerializer.cs index 73a2eccb7c..a9cd69c2ec 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSON3MessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSON3MessageSerializer.cs @@ -28,15 +28,13 @@ namespace Gremlin.Net.Structure.IO.GraphSON /// </summary> public class GraphSON3MessageSerializer : GraphSONMessageSerializer { - private const string MimeType = SerializationTokens.GraphSON3MimeType; - /// <summary> /// Initializes a new instance of the <see cref="GraphSON3MessageSerializer" /> class with custom serializers. /// </summary> /// <param name="graphSONReader">The <see cref="GraphSON3Reader"/> used to deserialize from GraphSON.</param> /// <param name="graphSONWriter">The <see cref="GraphSON3Writer"/> used to serialize to GraphSON.</param> public GraphSON3MessageSerializer(GraphSON3Reader? graphSONReader = null, GraphSON3Writer? graphSONWriter = null) - : base(MimeType, graphSONReader ?? new GraphSON3Reader(), graphSONWriter ?? new GraphSON3Writer()) + : base(SerializationTokens.GraphSON3MimeType, graphSONReader ?? new GraphSON3Reader(), graphSONWriter ?? 
new GraphSON3Writer()) { } } diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSONMessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSONMessageSerializer.cs index 49b2013ca1..1dc16215ec 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSONMessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphSON/GraphSONMessageSerializer.cs @@ -56,6 +56,9 @@ namespace Gremlin.Net.Structure.IO.GraphSON _graphSONWriter = graphSonWriter; } + /// <inheritdoc /> + public string MimeType => _mimeType; + /// <inheritdoc /> public virtual Task<byte[]> SerializeMessageAsync(RequestMessage requestMessage, CancellationToken cancellationToken = default) diff --git a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Docs/Reference/GremlinVariantsTests.cs b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Docs/Reference/GremlinVariantsTests.cs index 706556e91c..77c3f554dd 100644 --- a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Docs/Reference/GremlinVariantsTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Docs/Reference/GremlinVariantsTests.cs @@ -158,7 +158,9 @@ var response = // tag::submittingScriptsWithAuthentication[] var username = "username"; var password = "password"; -var gremlinServer = new GremlinServer("localhost", 8182, true, username, password); +var gremlinServer = new GremlinServer("localhost", 8182, enableSsl: true); +using var gremlinClient = new GremlinClient(gremlinServer, + interceptors: new[] { Auth.BasicAuth(username, password) }); // end::submittingScriptsWithAuthentication[] } diff --git a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/AuthIntegrationTests.cs b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/AuthIntegrationTests.cs new file mode 100644 index 0000000000..3b5c9eed07 --- /dev/null +++ b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/AuthIntegrationTests.cs @@ -0,0 +1,102 @@ +#region License + +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Net.Http; +using System.Threading.Tasks; +using Gremlin.Net.Driver; +using Gremlin.Net.Driver.Remote; +using Gremlin.Net.Process.Traversal; +using Xunit; + +namespace Gremlin.Net.IntegrationTest.Driver +{ + /// <summary> + /// Integration tests for authentication interceptors against the secure Gremlin Server + /// (port 45941, SSL + SimpleAuthenticator). + /// </summary> + public class AuthIntegrationTests + { + private static readonly string TestHost = ConfigProvider.Configuration["TestServerIpAddress"]!; + private static readonly int TestSecurePort = Convert.ToInt32(ConfigProvider.Configuration["TestSecureServerPort"]); + + private GremlinClient CreateSecureClient(Func<HttpRequestContext, Task>[]? 
interceptors = null) + { + var gremlinServer = new GremlinServer(TestHost, TestSecurePort, enableSsl: true); + return new GremlinClient(gremlinServer, + connectionSettings: new ConnectionSettings { SkipCertificateValidation = true }, + interceptors: interceptors); + } + + [Fact] + public async Task ShouldAuthenticateWithBasicAuth() + { + // The secure server uses SimpleAuthenticator with credentials: stephen/password + using var gremlinClient = CreateSecureClient( + new[] { Auth.BasicAuth("stephen", "password") }); + + var response = await gremlinClient.SubmitAsync<long>("g.inject(1).count()"); + + Assert.Single(response); + } + + [Fact] + public async Task ShouldAuthenticateWithBasicAuthViaDriverRemoteConnection() + { + // Test through DriverRemoteConnection + traversal + using var client = CreateSecureClient( + new[] { Auth.BasicAuth("stephen", "password") }); + using var remote = new DriverRemoteConnection(client, "gmodern"); + var g = AnonymousTraversalSource.Traversal().With(remote); + + var count = await g.V().Count().Promise(t => t.Next()); + + Assert.True(count > 0); + } + + [Fact] + public async Task ShouldFailWithWrongCredentials() + { + using var gremlinClient = CreateSecureClient( + new[] { Auth.BasicAuth("stephen", "wrongpassword") }); + + // The server returns auth errors as JSON (not GraphBinary), so Connection + // extracts the message and throws HttpRequestException. 
+ var ex = await Assert.ThrowsAsync<HttpRequestException>( + () => gremlinClient.SubmitAsync<long>("g.inject(1).count()")); + + Assert.Contains("incorrect", ex.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ShouldFailWithNoCredentials() + { + using var gremlinClient = CreateSecureClient(); + + var ex = await Assert.ThrowsAsync<HttpRequestException>( + () => gremlinClient.SubmitAsync<long>("g.inject(1).count()")); + + Assert.Contains("credentials", ex.Message, StringComparison.OrdinalIgnoreCase); + } + } +} diff --git a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Process/Traversal/DriverRemoteConnection/GraphTraversalTests.cs b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Process/Traversal/DriverRemoteConnection/GraphTraversalTests.cs index a3862de510..718e0fbe84 100644 --- a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Process/Traversal/DriverRemoteConnection/GraphTraversalTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Process/Traversal/DriverRemoteConnection/GraphTraversalTests.cs @@ -152,13 +152,9 @@ namespace Gremlin.Net.IntegrationTest.Process.Traversal.DriverRemoteConnection var g = AnonymousTraversalSource.Traversal().With(connection); var result = g.V(1).ValueMap<string, IList<object>>().Next(); - Assert.Equal( - new Dictionary<string, IList<object>> - { - { "age", new List<object> { 29 } }, - { "name", new List<object> { "marko" } } - }, - result); + Assert.True(result.Count >= 2); // .NET may receive an extra haltedTraversers key from the server which we currently just ignore + Assert.Equal(new List<object> { 29 }, result["age"]); + Assert.Equal(new List<object> { "marko" }, result["name"]); } [Fact] diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/AuthTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/AuthTests.cs new file mode 100644 index 0000000000..40f5239e44 --- /dev/null +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/AuthTests.cs @@ -0,0 +1,245 @@ +#region License 
+ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Collections.Generic; +using System.Security.Cryptography; +using System.Text; +using System.Threading.Tasks; +using Amazon.Runtime; +using Gremlin.Net.Driver; +using Xunit; + +namespace Gremlin.Net.UnitTest.Driver +{ + public class AuthTests + { + private static HttpRequestContext CreateTestContext() + { + return new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary<string, string>(), new byte[] { 0x01 }); + } + + [Fact] + public async Task BasicAuthShouldSetCorrectAuthorizationHeader() + { + var interceptor = Auth.BasicAuth("user", "pass"); + var context = CreateTestContext(); + + await interceptor(context); + + var expected = "Basic " + Convert.ToBase64String(Encoding.UTF8.GetBytes("user:pass")); + Assert.Equal(expected, context.Headers["Authorization"]); + } + + [Fact] + public async Task BasicAuthShouldSetHeaderOnEveryInvocation() + { + var interceptor = Auth.BasicAuth("user", "pass"); + var context1 = CreateTestContext(); + var context2 = CreateTestContext(); + + await interceptor(context1); + await interceptor(context2); + + var expected = "Basic " + 
Convert.ToBase64String(Encoding.UTF8.GetBytes("user:pass")); + Assert.Equal(expected, context1.Headers["Authorization"]); + Assert.Equal(expected, context2.Headers["Authorization"]); + } + + [Fact] + public async Task BasicAuthShouldHandleColonsInPassword() + { + var interceptor = Auth.BasicAuth("user", "pass:with:colons"); + var context = CreateTestContext(); + + await interceptor(context); + + var expected = "Basic " + Convert.ToBase64String( + Encoding.UTF8.GetBytes("user:pass:with:colons")); + Assert.Equal(expected, context.Headers["Authorization"]); + } + + [Fact] + public async Task BasicAuthShouldHandleUnicodeCharacters() + { + var interceptor = Auth.BasicAuth("用户", "密码"); + var context = CreateTestContext(); + + await interceptor(context); + + var expected = "Basic " + Convert.ToBase64String( + Encoding.UTF8.GetBytes("用户:密码")); + Assert.Equal(expected, context.Headers["Authorization"]); + } + + [Fact] + public async Task BasicAuthShouldOverwriteExistingAuthorizationHeader() + { + var interceptor = Auth.BasicAuth("user", "pass"); + var context = CreateTestContext(); + context.Headers["Authorization"] = "Bearer old-token"; + + await interceptor(context); + + var expected = "Basic " + Convert.ToBase64String(Encoding.UTF8.GetBytes("user:pass")); + Assert.Equal(expected, context.Headers["Authorization"]); + } + + [Fact] + public async Task BasicAuthShouldHandleEmptyCredentials() + { + var interceptor = Auth.BasicAuth("", ""); + var context = CreateTestContext(); + + await interceptor(context); + + var expected = "Basic " + Convert.ToBase64String(Encoding.UTF8.GetBytes(":")); + Assert.Equal(expected, context.Headers["Authorization"]); + } + + // --- SigV4 Tests --- + + private static readonly BasicAWSCredentials TestBasicCredentials = + new BasicAWSCredentials("MOCK_ID", "MOCK_KEY"); + + private static readonly SessionAWSCredentials TestSessionCredentials = + new SessionAWSCredentials("MOCK_ID", "MOCK_KEY", "MOCK_TOKEN"); + + private static HttpRequestContext 
CreateSigv4TestContext(byte[]? body = null) + { + return new HttpRequestContext("POST", new Uri("https://example.com:8182/gremlin"), + new Dictionary<string, string> + { + { "Content-Type", "application/vnd.graphbinary-v4.0" }, + { "Accept", "application/vnd.graphbinary-v4.0" }, + }, + body ?? new byte[] { 0x84, 0x00 }); + } + + [Fact] + public async Task SigV4AuthShouldAddRequiredHeaders() + { + var interceptor = Auth.SigV4Auth("gremlin-east-1", "tinkerpop-sigv4", TestBasicCredentials); + var context = CreateSigv4TestContext(); + + await interceptor(context); + + Assert.True(context.Headers.ContainsKey("Authorization")); + Assert.True(context.Headers.ContainsKey("X-Amz-Date")); + Assert.True(context.Headers.ContainsKey("x-amz-content-sha256")); + Assert.True(context.Headers.ContainsKey("Host")); + } + + [Fact] + public async Task SigV4AuthShouldHaveCorrectAuthorizationPrefix() + { + var interceptor = Auth.SigV4Auth("gremlin-west-2", "tinkerpop-sigv4", TestBasicCredentials); + var context = CreateSigv4TestContext(); + + await interceptor(context); + + Assert.StartsWith("AWS4-HMAC-SHA256 Credential=MOCK_ID", context.Headers["Authorization"]); + Assert.Contains("gremlin-west-2/tinkerpop-sigv4/aws4_request", context.Headers["Authorization"]); + } + + [Fact] + public async Task SigV4AuthShouldAddSessionTokenForTemporaryCredentials() + { + var interceptor = Auth.SigV4Auth("gremlin-east-1", "tinkerpop-sigv4", TestSessionCredentials); + var context = CreateSigv4TestContext(); + + await interceptor(context); + + Assert.True(context.Headers.ContainsKey("X-Amz-Security-Token")); + Assert.Equal("MOCK_TOKEN", context.Headers["X-Amz-Security-Token"]); + } + + [Fact] + public async Task SigV4AuthShouldNotAddSessionTokenForPermanentCredentials() + { + var interceptor = Auth.SigV4Auth("gremlin-east-1", "tinkerpop-sigv4", TestBasicCredentials); + var context = CreateSigv4TestContext(); + + await interceptor(context); + + 
Assert.False(context.Headers.ContainsKey("X-Amz-Security-Token")); + } + + [Fact] + public async Task SigV4AuthContentHashShouldMatchBodySha256() + { + var body = new byte[] { 0x84, 0x00, 0xFD, 0x01 }; + var interceptor = Auth.SigV4Auth("gremlin-east-1", "tinkerpop-sigv4", TestBasicCredentials); + var context = CreateSigv4TestContext(body); + + await interceptor(context); + + using var sha256 = SHA256.Create(); + var expectedHash = BitConverter.ToString(sha256.ComputeHash(body)) + .Replace("-", "").ToLowerInvariant(); + Assert.Equal(expectedHash, context.Headers["x-amz-content-sha256"]); + } + + [Fact] + public async Task SigV4AuthShouldHandleEmptyBody() + { + var interceptor = Auth.SigV4Auth("gremlin-east-1", "tinkerpop-sigv4", TestBasicCredentials); + var context = CreateSigv4TestContext(Array.Empty<byte>()); + + await interceptor(context); + + Assert.True(context.Headers.ContainsKey("Authorization")); + Assert.Equal("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + context.Headers["x-amz-content-sha256"]); + } + + [Fact] + public async Task SigV4AuthShouldSetCorrectHost() + { + var interceptor = Auth.SigV4Auth("gremlin-east-1", "tinkerpop-sigv4", TestBasicCredentials); + var context = CreateSigv4TestContext(); + + await interceptor(context); + + Assert.Equal("example.com", context.Headers["Host"]); + } + + [Fact] + public async Task SigV4AuthShouldThrowWhenBodyIsNotByteArray() + { + var interceptor = Auth.SigV4Auth("gremlin-east-1", "tinkerpop-sigv4", TestBasicCredentials); + var context = new HttpRequestContext("POST", new Uri("https://example.com:8182/gremlin"), + new Dictionary<string, string> + { + { "Content-Type", "application/vnd.graphbinary-v4.0" }, + }, + "not-bytes"); + + var ex = await Assert.ThrowsAsync<InvalidOperationException>( + () => interceptor(context)); + + Assert.Contains("byte[]", ex.Message); + } + } +} diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs 
b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs index 37308c1c5f..9857552a07 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs @@ -84,7 +84,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); await connection.SubmitAsync<object>(CreateTestRequest()); @@ -99,7 +99,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); await connection.SubmitAsync<object>(CreateTestRequest()); @@ -114,7 +114,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); await connection.SubmitAsync<object>(CreateTestRequest()); @@ -128,7 +128,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); await connection.SubmitAsync<object>(CreateTestRequest()); @@ -142,7 +142,7 @@ 
namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableCompression = true }; - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); await connection.SubmitAsync<object>(CreateTestRequest()); @@ -157,7 +157,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableCompression = false }; - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); await connection.SubmitAsync<object>(CreateTestRequest()); @@ -172,7 +172,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableUserAgentOnConnect = true }; - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); await connection.SubmitAsync<object>(CreateTestRequest()); @@ -186,7 +186,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableUserAgentOnConnect = false }; - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); await connection.SubmitAsync<object>(CreateTestRequest()); @@ -200,7 +200,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new 
ConnectionSettings { BulkResults = true }; - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); await connection.SubmitAsync<object>(CreateTestRequest()); @@ -215,7 +215,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { BulkResults = false }; - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); await connection.SubmitAsync<object>(CreateTestRequest()); @@ -241,7 +241,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, handler) = CreateMockHttpClient(compressedBytes, "deflate"); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableCompression = true }; - using var connection = new Connection(TestUri, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); // Should not throw — decompression should work var result = await connection.SubmitAsync<object>(CreateTestRequest()); @@ -255,7 +255,7 @@ namespace Gremlin.Net.UnitTest.Driver var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - var connection = new Connection(TestUri, serializer, settings, httpClient); + var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); connection.Dispose(); // Double dispose should not throw @@ -267,9 +267,11 @@ namespace Gremlin.Net.UnitTest.Driver return RequestMessage.Build("g.V()").AddG("g").Create(); } - private static IMessageSerializer CreateMockSerializer() + private static IMessageSerializer CreateMockSerializer( + string mimeType = SerializationTokens.GraphBinary4MimeType) { var 
serializer = Substitute.For<IMessageSerializer>(); + serializer.MimeType.Returns(mimeType); serializer.SerializeMessageAsync(Arg.Any<RequestMessage>(), Arg.Any<CancellationToken>()) .Returns(Task.FromResult(new byte[] { 0x84 })); serializer.DeserializeMessageAsync(Arg.Any<byte[]>(), Arg.Any<CancellationToken>()) @@ -278,6 +280,668 @@ namespace Gremlin.Net.UnitTest.Driver return serializer; } + [Fact] + public async Task ShouldCallInterceptorsInOrder() + { + var (httpClient, _) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + var callOrder = new List<int>(); + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => { callOrder.Add(1); return Task.CompletedTask; }, + ctx => { callOrder.Add(2); return Task.CompletedTask; }, + ctx => { callOrder.Add(3); return Task.CompletedTask; }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.Equal(new List<int> { 1, 2, 3 }, callOrder); + } + + [Fact] + public async Task ShouldPropagateInterceptorException() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + var expectedException = new InvalidOperationException("interceptor failed"); + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + _ => throw expectedException, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + + var ex = await Assert.ThrowsAsync<InvalidOperationException>( + () => connection.SubmitAsync<object>(CreateTestRequest())); + + Assert.Same(expectedException, ex); + // HTTP request should not have been sent + Assert.Null(handler.CapturedRequest); + } + + [Fact] + public async Task ShouldAllowInterceptorToModifyHeaders() + { + var (httpClient, handler) = 
CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + ctx.Headers["Authorization"] = "Basic dGVzdDp0ZXN0"; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(handler.CapturedRequest); + Assert.True(handler.CapturedRequest!.Headers.Contains("Authorization")); + Assert.Equal("Basic dGVzdDp0ZXN0", + handler.CapturedRequest.Headers.GetValues("Authorization").First()); + } + + [Fact] + public async Task ShouldSeeEarlierInterceptorModifications() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + string? observedHeader = null; + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + ctx.Headers["X-Custom"] = "first"; + return Task.CompletedTask; + }, + ctx => + { + observedHeader = ctx.Headers.ContainsKey("X-Custom") ? ctx.Headers["X-Custom"] : null; + ctx.Headers["X-Custom"] = "second"; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.Equal("first", observedHeader); + Assert.NotNull(handler.CapturedRequest); + Assert.Equal("second", + handler.CapturedRequest!.Headers.GetValues("X-Custom").First()); + } + + [Fact] + public async Task ShouldSerializeBeforeInterceptorsWhenRequestSerializerProvided() + { + var (httpClient, _) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + object? 
observedBody = null; + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + observedBody = ctx.Body; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.IsType<byte[]>(observedBody); + } + + [Fact] + public async Task ShouldPassRequestMessageWhenRequestSerializerIsNull() + { + var (httpClient, _) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + object? observedBody = null; + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + observedBody = ctx.Body; + // Serialize the body so the request can proceed + ctx.Body = new byte[] { 0x84 }; + ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.IsType<RequestMessage>(observedBody); + } + + [Fact] + public async Task ShouldThrowWhenBodyIsNotByteArrayAfterInterceptors() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + + // No interceptor serializes the body + var interceptors = new List<Func<HttpRequestContext, Task>>(); + + using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + interceptors); + + var ex = await Assert.ThrowsAsync<InvalidOperationException>( + () => connection.SubmitAsync<object>(CreateTestRequest())); + + Assert.Contains("byte[] or HttpContent", ex.Message); + Assert.Contains("RequestMessage", ex.Message); + // HTTP request should not have been sent + Assert.Null(handler.CapturedRequest); + } + + [Fact] + public async Task 
ShouldSucceedWhenInterceptorSerializesBodyWithNullRequestSerializer() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + async ctx => + { + if (ctx.Body is RequestMessage msg) + { + ctx.Body = await serializer.SerializeMessageAsync(msg); + ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; + } + }, + }; + + using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + interceptors); + + var result = await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(result); + Assert.NotNull(handler.CapturedRequest); + } + + [Fact] + public async Task ShouldNotSetContentTypeWhenRequestSerializerIsNull() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + bool? hadContentType = null; + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + hadContentType = ctx.Headers.ContainsKey("Content-Type"); + // Serialize so the request can proceed + ctx.Body = new byte[] { 0x84 }; + ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.False(hadContentType, "Content-Type should not be set before interceptors when requestSerializer is null"); + Assert.NotNull(handler.CapturedRequest); + Assert.Equal("application/vnd.graphbinary-v4.0", + handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); + } + + [Fact] + public async Task ShouldWorkWithEmptyInterceptorList() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + var 
interceptors = new List<Func<HttpRequestContext, Task>>(); + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + interceptors); + + var result = await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(result); + Assert.NotNull(handler.CapturedRequest); + } + + [Fact] + public async Task ShouldWorkWithNoInterceptorsParameter() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + + var result = await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(result); + Assert.NotNull(handler.CapturedRequest); + } + + [Fact] + public async Task ShouldAllowInterceptorToModifyUri() + { + var altUri = new Uri("http://other-host:9999/gremlin"); + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + ctx.Uri = altUri; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(handler.CapturedRequest); + Assert.Equal(altUri, handler.CapturedRequest!.RequestUri); + } + + [Fact] + public async Task ShouldAllowInterceptorToReplaceBody() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + var replacementBody = new byte[] { 0x01, 0x02, 0x03 }; + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + ctx.Body = replacementBody; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, 
httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(handler.CapturedRequest); + var sentBytes = await handler.CapturedRequest!.Content!.ReadAsByteArrayAsync(); + Assert.Equal(replacementBody, sentBytes); + } + + [Fact] + public async Task ShouldStopInterceptorChainOnException() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + var secondCalled = false; + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + _ => throw new InvalidOperationException("first failed"), + _ => + { + secondCalled = true; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + interceptors); + + await Assert.ThrowsAsync<InvalidOperationException>( + () => connection.SubmitAsync<object>(CreateTestRequest())); + + Assert.False(secondCalled, "Second interceptor should not run when first throws"); + Assert.Null(handler.CapturedRequest); + } + + [Fact] + public async Task ShouldSupportAsyncInterceptors() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + var interceptorCompleted = false; + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + async ctx => + { + // Simulate async work (e.g., fetching a token) + await Task.Delay(1); + ctx.Headers["X-Async-Header"] = "async-value"; + interceptorCompleted = true; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.True(interceptorCompleted); + Assert.NotNull(handler.CapturedRequest); + Assert.Equal("async-value", + handler.CapturedRequest!.Headers.GetValues("X-Async-Header").First()); + } + + [Fact] + public async Task 
ShouldAllowInterceptorToReadSerializedBody() + { + var (httpClient, _) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + byte[]? capturedBody = null; + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + capturedBody = ctx.Body as byte[]; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(capturedBody); + // The mock serializer returns { 0x84 } + Assert.Equal(new byte[] { 0x84 }, capturedBody); + } + + [Fact] + public async Task ShouldWorkWithSingleInterceptor() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + var called = false; + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + called = true; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.True(called); + Assert.NotNull(handler.CapturedRequest); + } + + [Fact] + public async Task ShouldAllowInterceptorToRemoveHeader() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings { EnableUserAgentOnConnect = true }; + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + ctx.Headers.Remove("User-Agent"); + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(handler.CapturedRequest); + Assert.False(handler.CapturedRequest!.Headers.Contains("User-Agent"), + 
"Interceptor should be able to remove headers set by Connection"); + } + + [Fact] + public async Task ShouldThrowWhenBodyIsNullAfterInterceptors() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + ctx.Body = null!; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + interceptors); + + var ex = await Assert.ThrowsAsync<InvalidOperationException>( + () => connection.SubmitAsync<object>(CreateTestRequest())); + + Assert.Contains("null", ex.Message); + Assert.Null(handler.CapturedRequest); + } + + [Fact] + public async Task ShouldPreserveMultipleInterceptorHeaderModifications() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + ctx.Headers["X-First"] = "one"; + return Task.CompletedTask; + }, + ctx => + { + ctx.Headers["X-Second"] = "two"; + return Task.CompletedTask; + }, + ctx => + { + ctx.Headers["X-Third"] = "three"; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(handler.CapturedRequest); + Assert.Equal("one", handler.CapturedRequest!.Headers.GetValues("X-First").First()); + Assert.Equal("two", handler.CapturedRequest.Headers.GetValues("X-Second").First()); + Assert.Equal("three", handler.CapturedRequest.Headers.GetValues("X-Third").First()); + } + + [Fact] + public async Task ShouldAllowCustomContentTypeWhenRequestSerializerIsNull() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = 
new ConnectionSettings(); + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + ctx => + { + ctx.Body = new byte[] { 0x01 }; + ctx.Headers["Content-Type"] = "application/json"; + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(handler.CapturedRequest); + Assert.Equal("application/json", + handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); + } + + [Fact] + public async Task ShouldUseResponseSerializerWhenRequestSerializerIsNull() + { + var (httpClient, _) = CreateMockHttpClient(); + var requestSerializer = CreateMockSerializer(); + var responseSerializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + + var interceptors = new List<Func<HttpRequestContext, Task>> + { + async ctx => + { + if (ctx.Body is RequestMessage msg) + { + ctx.Body = await requestSerializer.SerializeMessageAsync(msg); + ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; + } + }, + }; + + using var connection = new Connection(TestUri, null, responseSerializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + // Verify the response serializer was called for deserialization + await responseSerializer.Received(1) + .DeserializeMessageAsync(Arg.Any<byte[]>(), Arg.Any<CancellationToken>()); + // Verify the request serializer was NOT called by Connection (interceptor called it directly) + await requestSerializer.DidNotReceive() + .DeserializeMessageAsync(Arg.Any<byte[]>(), Arg.Any<CancellationToken>()); + } + + [Fact] + public async Task ShouldAcceptHttpContentBodyFromInterceptor() + { + var (httpClient, handler) = CreateMockHttpClient(); + var serializer = CreateMockSerializer(); + var settings = new ConnectionSettings(); + var contentBytes = new byte[] { 0x01, 0x02, 0x03 }; + + var interceptors = 
new List<Func<HttpRequestContext, Task>> + { + ctx => + { + ctx.Body = new ByteArrayContent(contentBytes); + return Task.CompletedTask; + }, + }; + + using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + interceptors); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(handler.CapturedRequest); + var sentBytes = await handler.CapturedRequest!.Content!.ReadAsByteArrayAsync(); + Assert.Equal(contentBytes, sentBytes); + } + + [Fact] + public async Task ShouldUseResponseSerializerMimeTypeForAcceptHeader() + { + var (httpClient, handler) = CreateMockHttpClient(); + var requestSerializer = CreateMockSerializer("application/custom-request"); + var responseSerializer = CreateMockSerializer("application/custom-response"); + var settings = new ConnectionSettings(); + using var connection = new Connection(TestUri, requestSerializer, responseSerializer, settings, httpClient); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(handler.CapturedRequest); + Assert.Contains(handler.CapturedRequest!.Headers.Accept, + h => h.MediaType == "application/custom-response"); + } + + [Fact] + public async Task ShouldUseRequestSerializerMimeTypeForContentTypeHeader() + { + var (httpClient, handler) = CreateMockHttpClient(); + var requestSerializer = CreateMockSerializer("application/custom-request"); + var responseSerializer = CreateMockSerializer("application/custom-response"); + var settings = new ConnectionSettings(); + using var connection = new Connection(TestUri, requestSerializer, responseSerializer, settings, httpClient); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(handler.CapturedRequest); + Assert.Equal("application/custom-request", + handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); + } + + [Fact] + public async Task ShouldUseDifferentMimeTypesForRequestAndResponseSerializers() + { + var (httpClient, handler) = 
CreateMockHttpClient(); + var requestSerializer = CreateMockSerializer("application/vnd.custom-request-v1.0"); + var responseSerializer = CreateMockSerializer("application/vnd.custom-response-v2.0"); + var settings = new ConnectionSettings(); + using var connection = new Connection(TestUri, requestSerializer, responseSerializer, settings, httpClient); + + await connection.SubmitAsync<object>(CreateTestRequest()); + + Assert.NotNull(handler.CapturedRequest); + // Content-Type comes from request serializer + Assert.Equal("application/vnd.custom-request-v1.0", + handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); + // Accept comes from response serializer + Assert.Contains(handler.CapturedRequest.Headers.Accept, + h => h.MediaType == "application/vnd.custom-response-v2.0"); + } + /// <summary> /// A test HttpMessageHandler that captures the request and returns a canned response. /// </summary> diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs new file mode 100644 index 0000000000..86d2bc678d --- /dev/null +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs @@ -0,0 +1,131 @@ +#region License + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#endregion + +using System; +using System.Collections.Generic; +using System.Text; +using Gremlin.Net.Driver; +using Gremlin.Net.Driver.Messages; +using Xunit; + +namespace Gremlin.Net.UnitTest.Driver +{ + public class HttpRequestContextTests + { + [Fact] + public void ShouldConstructWithByteArrayBody() + { + var method = "POST"; + var uri = new Uri("http://localhost:8182/gremlin"); + var headers = new Dictionary<string, string> { { "Content-Type", "application/vnd.graphbinary-v4.0" } }; + var body = new byte[] { 0x01, 0x02, 0x03 }; + + var context = new HttpRequestContext(method, uri, headers, body); + + Assert.Equal(method, context.Method); + Assert.Equal(uri, context.Uri); + Assert.Same(headers, context.Headers); + Assert.Same(body, context.Body); + } + + [Fact] + public void ShouldConstructWithRequestMessageBody() + { + var method = "POST"; + var uri = new Uri("http://localhost:8182/gremlin"); + var headers = new Dictionary<string, string>(); + var body = RequestMessage.Build("g.V()").AddG("g").Create(); + + var context = new HttpRequestContext(method, uri, headers, body); + + Assert.Same(body, context.Body); + Assert.IsType<RequestMessage>(context.Body); + } + + [Fact] + public void ShouldAllowMutatingProperties() + { + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary<string, string>(), new byte[] { 0x01 }); + + var newUri = new Uri("https://example.com/gremlin"); + context.Method = "PUT"; + context.Uri = newUri; + context.Body = new byte[] { 0x02, 0x03 }; + context.Headers["Authorization"] = "Basic dGVzdA=="; + + Assert.Equal("PUT", context.Method); + Assert.Equal(newUri, context.Uri); + Assert.Equal(new byte[] { 0x02, 0x03 }, context.Body); + Assert.Equal("Basic dGVzdA==", context.Headers["Authorization"]); + } + + [Fact] + public void ShouldComputePayloadHashForKnownBody() + { + // SHA-256 
of "hello" = 2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824 + var body = Encoding.UTF8.GetBytes("hello"); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary<string, string>(), body); + + var hash = context.GetPayloadHash(); + + Assert.Equal("2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hash); + } + + [Fact] + public void ShouldComputePayloadHashForEmptyBody() + { + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary<string, string>(), Array.Empty<byte>()); + + var hash = context.GetPayloadHash(); + + Assert.Equal("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hash); + } + + [Fact] + public void ShouldThrowWhenComputingPayloadHashForNonByteArrayBody() + { + var body = RequestMessage.Build("g.V()").AddG("g").Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary<string, string>(), body); + + var ex = Assert.Throws<InvalidOperationException>(() => context.GetPayloadHash()); + + Assert.Contains("RequestMessage", ex.Message); + Assert.Contains("byte[]", ex.Message); + } + + [Fact] + public void ShouldThrowWhenComputingPayloadHashForNullBody() + { + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary<string, string>(), null!); + + var ex = Assert.Throws<InvalidOperationException>(() => context.GetPayloadHash()); + + Assert.Contains("null", ex.Message); + } + } +}
