This is an automated email from the ASF dual-hosted git repository.
xiazcy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
The following commit(s) were added to refs/heads/master by this push:
new da1619db75 Update order of operations for request interceptors and
serialization in gremlin-go (#3358)
da1619db75 is described below
commit da1619db7556d769c72b8c26f1cd5f7d36bf221d
Author: Yang Xia <[email protected]>
AuthorDate: Thu Mar 26 14:13:47 2026 -0700
Update order of operations for request interceptors and serialization in
gremlin-go (#3358)
* add proper error handling for HTTP error responses that are not GraphBinary
* run request interceptors before the request is serialized, instead of after
* export `request` and rename it to `RequestMessage` for clarity and public
access, consistent with other GLVs
* added a WaitGroup so close() waits for in-flight request goroutines to
complete, and drain the response body before closing it to prevent TCP RST errors
---
CHANGELOG.asciidoc | 4 +-
gremlin-go/driver/auth.go | 25 +-
gremlin-go/driver/auth_test.go | 26 +-
gremlin-go/driver/connection.go | 178 ++++++++------
gremlin-go/driver/connection_test.go | 188 +++++++++++++-
gremlin-go/driver/interceptor.go | 113 +++++++++
gremlin-go/driver/interceptor_test.go | 450 ++++++++++++++++++++++++++++++++++
gremlin-go/driver/request.go | 16 +-
gremlin-go/driver/request_test.go | 14 +-
gremlin-go/driver/serializer.go | 6 +-
gremlin-go/driver/serializer_test.go | 10 +-
11 files changed, 912 insertions(+), 118 deletions(-)
diff --git a/CHANGELOG.asciidoc b/CHANGELOG.asciidoc
index a3ea259df7..99ab01d415 100644
--- a/CHANGELOG.asciidoc
+++ b/CHANGELOG.asciidoc
@@ -41,8 +41,10 @@
image::https://raw.githubusercontent.com/apache/tinkerpop/master/docs/static/ima
* Removed deprecated `Graph.traversal()` method in JS in favor of the
anonymous `traversal()` function.
* Replace `Bytecode` with `GremlinLang` & update serialization to GraphBinary
4 for `gremlin-go`.
* Added `RequestInterceptor` to `gremlin-go` with `auth` reference
implementations to replace `authInfo`.
-* Refactored GraphBinary serializers to use `io.Writer` and `io.Reader`
instead of `*bytes.Buffer` for streaming capacities.
+* Refactored GraphBinary serializers in `gremlin-go` to use `io.Writer` and
`io.Reader` instead of `*bytes.Buffer` for streaming capacities.
* Refactored `httpProtocol` and `httpTransport` in `gremlin-go` into single
`connection.go` that handles HTTP request and response.
+* Reordered interceptor chain in `gremlin-go` so interceptors access raw
request before serialization.
+* Exported `request` in `gremlin-go` as `RequestMessage` with public
`Gremlin`/`Fields` for clarity, access and consistency.
* Refactored result handling in `gremlin-driver` by merging `ResultQueue` into
`ResultSet`.
* Replace `Bytecode` with `GremlinLang` in `gremlin-dotnet`.
* Replace `WebSocket` with `HTTP` (non-streaming) in `gremlin-dotnet`.
diff --git a/gremlin-go/driver/auth.go b/gremlin-go/driver/auth.go
index 74ca43a444..e0f279110a 100644
--- a/gremlin-go/driver/auth.go
+++ b/gremlin-go/driver/auth.go
@@ -22,6 +22,7 @@ package gremlingo
import (
"context"
"encoding/base64"
+ "fmt"
"sync"
"time"
@@ -39,19 +40,22 @@ func BasicAuth(username, password string)
RequestInterceptor {
}
}
-// Sigv4Auth returns a RequestInterceptor that signs requests using AWS SigV4.
+// SigV4Auth returns a RequestInterceptor that signs requests using AWS SigV4.
// It uses the default AWS credential chain (env vars, shared config, IAM
role, etc.)
-func Sigv4Auth(region, service string) RequestInterceptor {
- return Sigv4AuthWithCredentials(region, service, nil)
+func SigV4Auth(region, service string) RequestInterceptor {
+ return SigV4AuthWithCredentials(region, service, nil)
}
-// Sigv4AuthWithCredentials returns a RequestInterceptor that signs requests
using AWS SigV4
+// SigV4AuthWithCredentials returns a RequestInterceptor that signs requests
using AWS SigV4
// with the provided credentials provider. If provider is nil, uses default
credential chain.
+// If the request body has not been serialized yet (*RequestMessage), it is
automatically
+// serialized to GraphBinary before signing.
//
// Caches the signer and credentials provider for efficiency.
-func Sigv4AuthWithCredentials(region, service string, credentialsProvider
aws.CredentialsProvider) RequestInterceptor {
+func SigV4AuthWithCredentials(region, service string, credentialsProvider
aws.CredentialsProvider) RequestInterceptor {
// Create signer once - it's stateless and safe to reuse
signer := v4.NewSigner()
+ serialize := SerializeRequest()
// Cache for resolved credentials provider (lazy initialization)
var cachedProvider aws.CredentialsProvider
@@ -59,6 +63,17 @@ func Sigv4AuthWithCredentials(region, service string,
credentialsProvider aws.Cr
var providerErr error
return func(req *HttpRequest) error {
+ // If Body is still *RequestMessage, serialize it to
GraphBinary before signing.
+ if _, ok := req.Body.(*RequestMessage); ok {
+ if err := serialize(req); err != nil {
+ return fmt.Errorf("SigV4 auto-serialization
failed: %w", err)
+ }
+ }
+
+ if _, ok := req.Body.([]byte); !ok {
+ return fmt.Errorf("SigV4 signing requires body to be
[]byte; got %T", req.Body)
+ }
+
ctx := context.Background()
// Resolve credentials provider once if not provided
diff --git a/gremlin-go/driver/auth_test.go b/gremlin-go/driver/auth_test.go
index ba60f6e6c5..7ec4079b6e 100644
--- a/gremlin-go/driver/auth_test.go
+++ b/gremlin-go/driver/auth_test.go
@@ -30,7 +30,7 @@ import (
)
func createMockRequest() *HttpRequest {
- req, _ := NewHttpRequest("POST", "https://localhost:8182/gremlin")
+ req, _ := NewHttpRequest("POST", "https://test_url:8182/gremlin")
req.Headers.Set("Content-Type", graphBinaryMimeType)
req.Headers.Set("Accept", graphBinaryMimeType)
req.Body = []byte(`{"gremlin":"g.V()"}`)
@@ -72,24 +72,24 @@ func (m *mockCredentialsProvider) Retrieve(ctx
context.Context) (aws.Credentials
}, nil
}
-func TestSigv4Auth(t *testing.T) {
+func TestSigV4Auth(t *testing.T) {
t.Run("adds signed headers", func(t *testing.T) {
req := createMockRequest()
assert.Empty(t, req.Headers.Get("Authorization"))
assert.Empty(t, req.Headers.Get("X-Amz-Date"))
provider := &mockCredentialsProvider{
- accessKey: "MOCK_ACCESS_KEY",
- secretKey: "MOCK_SECRET_KEY",
+ accessKey: "MOCK_ID",
+ secretKey: "MOCK_KEY",
}
- interceptor := Sigv4AuthWithCredentials("us-west-2",
"neptune-db", provider)
+ interceptor := SigV4AuthWithCredentials("gremlin-east-1",
"tinkerpop-sigv4", provider)
err := interceptor(req)
assert.NoError(t, err)
assert.NotEmpty(t, req.Headers.Get("X-Amz-Date"))
authHeader := req.Headers.Get("Authorization")
- assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256
Credential=MOCK_ACCESS_KEY"))
- assert.Contains(t, authHeader,
"us-west-2/neptune-db/aws4_request")
+ assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256
Credential=MOCK_ID"))
+ assert.Contains(t, authHeader,
"gremlin-east-1/tinkerpop-sigv4/aws4_request")
assert.Contains(t, authHeader, "Signature=")
})
@@ -98,17 +98,17 @@ func TestSigv4Auth(t *testing.T) {
assert.Empty(t, req.Headers.Get("X-Amz-Security-Token"))
provider := &mockCredentialsProvider{
- accessKey: "MOCK_ACCESS_KEY",
- secretKey: "MOCK_SECRET_KEY",
- sessionToken: "MOCK_SESSION_TOKEN",
+ accessKey: "MOCK_ID",
+ secretKey: "MOCK_KEY",
+ sessionToken: "MOCK_TOKEN",
}
- interceptor := Sigv4AuthWithCredentials("us-west-2",
"neptune-db", provider)
+ interceptor := SigV4AuthWithCredentials("gremlin-east-1",
"tinkerpop-sigv4", provider)
err := interceptor(req)
assert.NoError(t, err)
- assert.Equal(t, "MOCK_SESSION_TOKEN",
req.Headers.Get("X-Amz-Security-Token"))
+ assert.Equal(t, "MOCK_TOKEN",
req.Headers.Get("X-Amz-Security-Token"))
authHeader := req.Headers.Get("Authorization")
assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256
Credential="))
- assert.Contains(t, authHeader, "Signature=")
+ assert.Contains(t, authHeader,
"gremlin-east-1/tinkerpop-sigv4/aws4_request")
})
}
diff --git a/gremlin-go/driver/connection.go b/gremlin-go/driver/connection.go
index 184efb681b..54def5821a 100644
--- a/gremlin-go/driver/connection.go
+++ b/gremlin-go/driver/connection.go
@@ -22,69 +22,17 @@ package gremlingo
import (
"bytes"
"compress/zlib"
- "crypto/sha256"
"crypto/tls"
- "encoding/hex"
+ "encoding/json"
+ "fmt"
"io"
"net"
"net/http"
- "net/url"
+ "strings"
+ "sync"
"time"
)
-// Common HTTP header keys
-const (
- HeaderContentType = "Content-Type"
- HeaderAccept = "Accept"
- HeaderUserAgent = "User-Agent"
- HeaderAcceptEncoding = "Accept-Encoding"
- HeaderAuthorization = "Authorization"
-)
-
-// HttpRequest represents an HTTP request that can be modified by interceptors.
-type HttpRequest struct {
- Method string
- URL *url.URL
- Headers http.Header
- Body []byte
-}
-
-// NewHttpRequest creates a new HttpRequest with the given method and URL.
-func NewHttpRequest(method, rawURL string) (*HttpRequest, error) {
- u, err := url.Parse(rawURL)
- if err != nil {
- return nil, err
- }
- return &HttpRequest{
- Method: method,
- URL: u,
- Headers: make(http.Header),
- }, nil
-}
-
-// ToStdRequest converts HttpRequest to a standard http.Request for signing.
-// Returns nil if the request cannot be created (invalid method or URL).
-func (r *HttpRequest) ToStdRequest() (*http.Request, error) {
- req, err := http.NewRequest(r.Method, r.URL.String(),
bytes.NewReader(r.Body))
- if err != nil {
- return nil, err
- }
- req.Header = r.Headers
- return req, nil
-}
-
-// PayloadHash returns the SHA256 hash of the request body for SigV4 signing.
-func (r *HttpRequest) PayloadHash() string {
- if len(r.Body) == 0 {
- return
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of
empty string
- }
- h := sha256.Sum256(r.Body)
- return hex.EncodeToString(h[:])
-}
-
-// RequestInterceptor is a function that modifies an HTTP request before it is
sent.
-type RequestInterceptor func(*HttpRequest) error
-
// connectionSettings holds configuration for the connection.
type connectionSettings struct {
tlsConfig *tls.Config
@@ -106,6 +54,7 @@ type connection struct {
logHandler *logHandler
serializer *GraphBinarySerializer
interceptors []RequestInterceptor
+ wg sync.WaitGroup
}
// Connection pool defaults aligned with Java driver
@@ -171,21 +120,19 @@ func (c *connection) AddInterceptor(interceptor
RequestInterceptor) {
}
// submit sends request and streams results directly to ResultSet
-func (c *connection) submit(req *request) (ResultSet, error) {
+func (c *connection) submit(req *RequestMessage) (ResultSet, error) {
rs := newChannelResultSet()
- data, err := c.serializer.SerializeMessage(req)
- if err != nil {
- rs.Close()
- return rs, err
- }
-
- go c.executeAndStream(data, rs)
+ c.wg.Add(1)
+ go func() {
+ defer c.wg.Done()
+ c.executeAndStream(req, rs)
+ }()
return rs, nil
}
-func (c *connection) executeAndStream(data []byte, rs ResultSet) {
+func (c *connection) executeAndStream(req *RequestMessage, rs ResultSet) {
defer rs.Close()
// Create HttpRequest for interceptors
@@ -195,12 +142,15 @@ func (c *connection) executeAndStream(data []byte, rs
ResultSet) {
rs.setError(err)
return
}
- httpReq.Body = data
// Set default headers before interceptors
c.setHttpRequestHeaders(httpReq)
- // Apply interceptors
+ // Set Body to the raw *RequestMessage so interceptors can
inspect/modify it
+ httpReq.Body = req
+
+ // Apply interceptors — they see *RequestMessage in Body
(pre-serialization).
+ // Interceptors may replace Body with []byte, io.Reader, or
*http.Request.
for _, interceptor := range c.interceptors {
if err := interceptor(httpReq); err != nil {
c.logHandler.logf(Error, failedToSendRequest,
err.Error())
@@ -209,27 +159,90 @@ func (c *connection) executeAndStream(data []byte, rs
ResultSet) {
}
}
- // Create actual http.Request from HttpRequest
- req, err := http.NewRequest(httpReq.Method, httpReq.URL.String(),
bytes.NewReader(httpReq.Body))
- if err != nil {
- c.logHandler.logf(Error, failedToSendRequest, err.Error())
- rs.setError(err)
+ // After interceptors, serialize if Body is still *RequestMessage
+ if r, ok := httpReq.Body.(*RequestMessage); ok {
+ if c.serializer != nil {
+ data, err := c.serializer.SerializeMessage(r)
+ if err != nil {
+ c.logHandler.logf(Error, failedToSendRequest,
err.Error())
+ rs.setError(err)
+ return
+ }
+ httpReq.Body = data
+ } else {
+ errMsg := "request body was not serialized; either
provide a serializer or add an interceptor that serializes the request"
+ c.logHandler.logf(Error, failedToSendRequest, errMsg)
+ rs.setError(fmt.Errorf("%s", errMsg))
+ return
+ }
+ }
+
+ // Create actual http.Request from HttpRequest based on Body type
+ var httpGoReq *http.Request
+ switch body := httpReq.Body.(type) {
+ case []byte:
+ httpGoReq, err = http.NewRequest(httpReq.Method,
httpReq.URL.String(), bytes.NewReader(body))
+ if err != nil {
+ c.logHandler.logf(Error, failedToSendRequest,
err.Error())
+ rs.setError(err)
+ return
+ }
+ httpGoReq.Header = httpReq.Headers
+ case io.Reader:
+ httpGoReq, err = http.NewRequest(httpReq.Method,
httpReq.URL.String(), body)
+ if err != nil {
+ c.logHandler.logf(Error, failedToSendRequest,
err.Error())
+ rs.setError(err)
+ return
+ }
+ httpGoReq.Header = httpReq.Headers
+ case *http.Request:
+ httpGoReq = body
+ default:
+ errMsg := fmt.Sprintf("unsupported body type after
interceptors: %T", body)
+ c.logHandler.logf(Error, failedToSendRequest, errMsg)
+ rs.setError(fmt.Errorf("%s", errMsg))
return
}
- req.Header = httpReq.Headers
- resp, err := c.httpClient.Do(req)
+ resp, err := c.httpClient.Do(httpGoReq)
if err != nil {
c.logHandler.logf(Error, failedToSendRequest, err.Error())
rs.setError(err)
return
}
defer func() {
+ // Drain any unread bytes so the connection can be reused
gracefully.
+ // Without this, Go's HTTP client sends a TCP RST instead of
FIN,
+ // causing "Connection reset by peer" errors on the server.
+ io.Copy(io.Discard, resp.Body)
if err := resp.Body.Close(); err != nil {
c.logHandler.logf(Debug, failedToCloseResponseBody,
err.Error())
}
}()
+ // If the HTTP status indicates an error and the response is not
GraphBinary,
+ // read the body as a text/JSON error message instead of attempting
binary
+ // deserialization which would produce cryptic errors.
+ contentType := resp.Header.Get(HeaderContentType)
+ if resp.StatusCode >= 400 && !strings.Contains(contentType,
graphBinaryMimeType) {
+ bodyBytes, readErr := io.ReadAll(resp.Body)
+ if readErr != nil {
+ c.logHandler.logf(Error, failedToReceiveResponse,
readErr.Error())
+ rs.setError(fmt.Errorf("Gremlin Server returned HTTP %d
and failed to read body: %w",
+ resp.StatusCode, readErr))
+ return
+ }
+ errorBody := string(bodyBytes)
+ errorMsg := tryExtractJSONError(errorBody)
+ if errorMsg == "" {
+ errorMsg = fmt.Sprintf("Gremlin Server returned HTTP
%d: %s", resp.StatusCode, errorBody)
+ }
+ c.logHandler.logf(Error, failedToReceiveResponse, errorMsg)
+ rs.setError(fmt.Errorf("%s", errorMsg))
+ return
+ }
+
reader, zlibReader, err := c.getReader(resp)
if err != nil {
c.logHandler.logf(Error, failedToReceiveResponse, err.Error())
@@ -308,6 +321,23 @@ func (c *connection) streamToResultSet(reader io.Reader,
rs ResultSet) {
}
}
+// tryExtractJSONError attempts to extract an error message from a JSON
response body.
+// The server sometimes responds with a JSON object containing a "message"
field
+// even when it cannot produce a GraphBinary response.
+func tryExtractJSONError(body string) string {
+ var obj map[string]interface{}
+ if err := json.Unmarshal([]byte(body), &obj); err != nil {
+ return ""
+ }
+ if msg, ok := obj["message"]; ok {
+ if s, ok := msg.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
func (c *connection) close() {
+ c.wg.Wait()
c.httpClient.CloseIdleConnections()
}
diff --git a/gremlin-go/driver/connection_test.go
b/gremlin-go/driver/connection_test.go
index a222448f2c..7d95192211 100644
--- a/gremlin-go/driver/connection_test.go
+++ b/gremlin-go/driver/connection_test.go
@@ -958,7 +958,7 @@ func TestConnectionWithMockServer(t *testing.T) {
connectionTimeout: 100 * time.Millisecond,
})
- rs, err := conn.submit(&request{gremlin: "g.V()", fields:
map[string]interface{}{}})
+ rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()",
Fields: map[string]interface{}{}})
assert.NoError(t, err) // submit returns nil, error goes to
ResultSet
// All() blocks until stream closes, then we can check error
@@ -979,7 +979,7 @@ func TestConnectionWithMockServer(t *testing.T) {
enableCompression: true,
})
- rs, err := conn.submit(&request{gremlin: "g.V()", fields:
map[string]interface{}{}})
+ rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()",
Fields: map[string]interface{}{}})
require.NoError(t, err)
select {
@@ -993,10 +993,161 @@ func TestConnectionWithMockServer(t *testing.T) {
_, _ = rs.All() // drain
})
+
+ t.Run("returns plain text error for non-GraphBinary 500 response",
func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w
http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ w.WriteHeader(http.StatusInternalServerError)
+ w.Write([]byte("Internal Server Error"))
+ }))
+ defer server.Close()
+
+ conn := newConnection(newTestLogHandler(), server.URL,
&connectionSettings{})
+ rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()",
Fields: map[string]interface{}{}})
+ require.NoError(t, err)
+
+ _, _ = rs.All()
+ rsErr := rs.GetError()
+ require.Error(t, rsErr)
+ assert.Contains(t, rsErr.Error(), "HTTP 500")
+ assert.Contains(t, rsErr.Error(), "Internal Server Error")
+ })
+
+ t.Run("extracts message from JSON error response", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w
http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusUnauthorized)
+ w.Write([]byte(`{"message":"Authentication required"}`))
+ }))
+ defer server.Close()
+
+ conn := newConnection(newTestLogHandler(), server.URL,
&connectionSettings{})
+ rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()",
Fields: map[string]interface{}{}})
+ require.NoError(t, err)
+
+ _, _ = rs.All()
+ rsErr := rs.GetError()
+ require.Error(t, rsErr)
+ assert.Equal(t, "Authentication required", rsErr.Error())
+ })
+
+ t.Run("falls back to raw body for non-JSON error response", func(t
*testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w
http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/html")
+ w.WriteHeader(http.StatusBadGateway)
+ w.Write([]byte("<html>Bad Gateway</html>"))
+ }))
+ defer server.Close()
+
+ conn := newConnection(newTestLogHandler(), server.URL,
&connectionSettings{})
+ rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()",
Fields: map[string]interface{}{}})
+ require.NoError(t, err)
+
+ _, _ = rs.All()
+ rsErr := rs.GetError()
+ require.Error(t, rsErr)
+ assert.Contains(t, rsErr.Error(), "HTTP 502")
+ assert.Contains(t, rsErr.Error(), "<html>Bad Gateway</html>")
+ })
+
+ t.Run("falls through to GraphBinary deserialization for GraphBinary
error responses", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w
http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", graphBinaryMimeType)
+ w.WriteHeader(http.StatusInternalServerError)
+ // Write invalid GraphBinary — the point is that we
don't short-circuit
+ // to the text error path when Content-Type is
GraphBinary
+ w.Write([]byte{0x00})
+ }))
+ defer server.Close()
+
+ conn := newConnection(newTestLogHandler(), server.URL,
&connectionSettings{})
+ rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()",
Fields: map[string]interface{}{}})
+ require.NoError(t, err)
+
+ _, _ = rs.All()
+ rsErr := rs.GetError()
+ // Should get a deserialization error, NOT an "HTTP 500" text
error
+ if rsErr != nil {
+ assert.NotContains(t, rsErr.Error(), "HTTP 500")
+ }
+ })
+
+ t.Run("interceptors run before serialization is checked", func(t
*testing.T) {
+ var interceptorHeaders http.Header
+ server := httptest.NewServer(http.HandlerFunc(func(w
http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer server.Close()
+
+ conn := newConnection(newTestLogHandler(), server.URL,
&connectionSettings{})
+ conn.AddInterceptor(func(req *HttpRequest) error {
+ interceptorHeaders = req.Headers.Clone()
+ return nil
+ })
+
+ rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()",
Fields: map[string]interface{}{}})
+ require.NoError(t, err)
+ _, _ = rs.All()
+
+ // Interceptor should see the default headers
+ assert.Equal(t, graphBinaryMimeType,
interceptorHeaders.Get("Content-Type"))
+ assert.Equal(t, graphBinaryMimeType,
interceptorHeaders.Get("Accept"))
+ })
+
+ t.Run("close waits for in-flight requests", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w
http.ResponseWriter, r *http.Request) {
+ time.Sleep(200 * time.Millisecond)
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer server.Close()
+
+ conn := newConnection(newTestLogHandler(), server.URL,
&connectionSettings{})
+
+ rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()",
Fields: map[string]interface{}{}})
+ require.NoError(t, err)
+
+ start := time.Now()
+ conn.close()
+ elapsed := time.Since(start)
+
+ // close() should have waited for the in-flight goroutine
+ assert.GreaterOrEqual(t, elapsed.Milliseconds(), int64(150),
+ "close() should wait for in-flight requests to
complete")
+
+ // ResultSet should be closed (goroutine finished)
+ _, _ = rs.All()
+ })
}
// Tests for connection pool configuration settings
+func TestTryExtractJSONError(t *testing.T) {
+ t.Run("extracts message from valid JSON", func(t *testing.T) {
+ result := tryExtractJSONError(`{"message":"auth
failed","code":401}`)
+ assert.Equal(t, "auth failed", result)
+ })
+
+ t.Run("returns empty for JSON without message field", func(t
*testing.T) {
+ result := tryExtractJSONError(`{"error":"something went
wrong"}`)
+ assert.Equal(t, "", result)
+ })
+
+ t.Run("returns empty for invalid JSON", func(t *testing.T) {
+ result := tryExtractJSONError("not json at all")
+ assert.Equal(t, "", result)
+ })
+
+ t.Run("returns empty for HTML content", func(t *testing.T) {
+ result := tryExtractJSONError("<html><body>Error</body></html>")
+ assert.Equal(t, "", result)
+ })
+
+ t.Run("returns empty for empty string", func(t *testing.T) {
+ result := tryExtractJSONError("")
+ assert.Equal(t, "", result)
+ })
+}
+
func TestConnectionPoolSettings(t *testing.T) {
t.Run("default values are applied when settings are 0", func(t
*testing.T) {
// Create connection with empty settings (all zeros)
@@ -1134,3 +1285,36 @@ func TestDriverRemoteConnectionSettingsWiring(t
*testing.T) {
assert.Equal(t, 180*time.Second, transport.IdleConnTimeout)
})
}
+
+// TestConnectionWithMockServer_BasicAuth verifies that BasicAuth interceptor
sets the correct
+// Authorization header and the body is still valid serialized bytes.
+func TestConnectionWithMockServer_BasicAuth(t *testing.T) {
+ var capturedAuthHeader string
+ var capturedBody []byte
+
+ server := httptest.NewServer(http.HandlerFunc(func(w
http.ResponseWriter, r *http.Request) {
+ capturedAuthHeader = r.Header.Get("Authorization")
+ body, err := io.ReadAll(r.Body)
+ if err == nil {
+ capturedBody = body
+ }
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer server.Close()
+
+ conn := newConnection(newTestLogHandler(), server.URL,
&connectionSettings{})
+ conn.AddInterceptor(BasicAuth("testuser", "testpass"))
+
+ rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields:
map[string]interface{}{}})
+ require.NoError(t, err)
+ _, _ = rs.All() // drain
+
+ // BasicAuth should set Authorization header with
base64("testuser:testpass") = "dGVzdHVzZXI6dGVzdHBhc3M="
+ assert.Equal(t, "Basic dGVzdHVzZXI6dGVzdHBhc3M=", capturedAuthHeader,
+ "Authorization header should be Basic
base64(testuser:testpass)")
+
+ // Body should still be valid serialized bytes
+ assert.NotEmpty(t, capturedBody, "serialized body should be non-empty
with BasicAuth")
+ assert.Equal(t, byte(0x81), capturedBody[0],
+ "body should start with GraphBinary version byte 0x81")
+}
diff --git a/gremlin-go/driver/interceptor.go b/gremlin-go/driver/interceptor.go
new file mode 100644
index 0000000000..e7e0c8e087
--- /dev/null
+++ b/gremlin-go/driver/interceptor.go
@@ -0,0 +1,113 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package gremlingo
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "encoding/hex"
+ "io"
+ "net/http"
+ "net/url"
+)
+
+// Common HTTP header keys
+const (
+ HeaderContentType = "Content-Type"
+ HeaderAccept = "Accept"
+ HeaderUserAgent = "User-Agent"
+ HeaderAcceptEncoding = "Accept-Encoding"
+ HeaderAuthorization = "Authorization"
+)
+
+// HttpRequest represents an HTTP request that can be modified by interceptors.
+type HttpRequest struct {
+ Method string
+ URL *url.URL
+ Headers http.Header
+ Body any
+}
+
+// NewHttpRequest creates a new HttpRequest with the given method and URL.
+func NewHttpRequest(method, rawURL string) (*HttpRequest, error) {
+ u, err := url.Parse(rawURL)
+ if err != nil {
+ return nil, err
+ }
+ return &HttpRequest{
+ Method: method,
+ URL: u,
+ Headers: make(http.Header),
+ }, nil
+}
+
+// ToStdRequest converts HttpRequest to a standard http.Request for signing.
+// Returns nil if the request cannot be created (invalid method or URL).
+func (r *HttpRequest) ToStdRequest() (*http.Request, error) {
+ var body io.Reader
+ switch b := r.Body.(type) {
+ case []byte:
+ body = bytes.NewReader(b)
+ default:
+ body = http.NoBody
+ }
+ req, err := http.NewRequest(r.Method, r.URL.String(), body)
+ if err != nil {
+ return nil, err
+ }
+ req.Header = r.Headers
+ return req, nil
+}
+
+// PayloadHash returns the SHA256 hash of the request body for SigV4 signing.
+func (r *HttpRequest) PayloadHash() string {
+ switch b := r.Body.(type) {
+ case []byte:
+ if len(b) == 0 {
+ return
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of
empty string
+ }
+ h := sha256.Sum256(b)
+ return hex.EncodeToString(h[:])
+ default:
+ return
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of
empty string
+ }
+}
+
+// RequestInterceptor is a function that modifies an HTTP request before it is
sent.
+type RequestInterceptor func(*HttpRequest) error
+
+// SerializeRequest returns a RequestInterceptor that serializes the raw
*RequestMessage body
+// to GraphBinary []byte. Place this before auth interceptors (e.g.,
SigV4Auth) that
+// need the serialized body bytes.
+func SerializeRequest() RequestInterceptor {
+ serializer := newGraphBinarySerializer(nil)
+ return func(req *HttpRequest) error {
+ r, ok := req.Body.(*RequestMessage)
+ if !ok {
+ return nil // already serialized or not a
*RequestMessage
+ }
+ data, err := serializer.SerializeMessage(r)
+ if err != nil {
+ return err
+ }
+ req.Body = data
+ return nil
+ }
+}
diff --git a/gremlin-go/driver/interceptor_test.go
b/gremlin-go/driver/interceptor_test.go
new file mode 100644
index 0000000000..44c0c61988
--- /dev/null
+++ b/gremlin-go/driver/interceptor_test.go
@@ -0,0 +1,450 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package gremlingo
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestInterceptorReceivesRawRequest checks that, when the interceptor chain
+// runs, HttpRequest.Body still holds the caller's *RequestMessage rather than
+// serialized []byte, so an interceptor can type-assert and read it.
+func TestInterceptorReceivesRawRequest(t *testing.T) {
+	// The response content is irrelevant to this test; a bare 200 suffices.
+	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	})
+	srv := httptest.NewServer(okHandler)
+	defer srv.Close()
+
+	// newConnection installs its default, non-nil serializer.
+	c := newConnection(newTestLogHandler(), srv.URL, &connectionSettings{})
+
+	var (
+		seenBody interface{}
+		seenType reflect.Type
+	)
+	c.AddInterceptor(func(req *HttpRequest) error {
+		seenBody, seenType = req.Body, reflect.TypeOf(req.Body)
+		return nil
+	})
+
+	msg := &RequestMessage{Gremlin: "g.V().count()", Fields: map[string]interface{}{}}
+	rs, err := c.submit(msg)
+	require.NoError(t, err)
+	_, _ = rs.All() // drain the result set before inspecting what the interceptor saw
+
+	wantType := reflect.TypeOf((*RequestMessage)(nil))
+	assert.Equal(t, wantType, seenType,
+		"interceptor should receive *RequestMessage in Body, got %v", seenType)
+
+	got, isMsg := seenBody.(*RequestMessage)
+	assert.True(t, isMsg, "interceptor should be able to type-assert Body to *RequestMessage")
+	if !isMsg {
+		return
+	}
+	assert.Equal(t, "g.V().count()", got.Gremlin,
+		"interceptor should be able to read the Gremlin field from the raw request")
+}
+
+// TestSigV4AuthWithSerializeInterceptor verifies that SerializeRequest() + SigV4Auth
+// works in a chain. SerializeRequest converts *RequestMessage to []byte, then SigV4Auth
+// can sign the serialized body.
+func TestSigV4AuthWithSerializeInterceptor(t *testing.T) {
+	var capturedHeaders http.Header
+	var capturedBody []byte
+
+	// Record what actually reaches the server, so the assertions check the
+	// final wire request rather than an interceptor's view of it.
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		capturedHeaders = r.Header.Clone()
+		body, err := io.ReadAll(r.Body)
+		if err == nil {
+			capturedBody = body
+		}
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer server.Close()
+
+	conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{})
+
+	mockProvider := &mockCredentialsProvider{
+		accessKey: "MOCK_ID",
+		secretKey: "MOCK_KEY",
+	}
+
+	// Interceptors run in the order they are added, so the body is already
+	// []byte when SigV4 signs it.
+	conn.AddInterceptor(SerializeRequest())
+	conn.AddInterceptor(SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", mockProvider))
+
+	rs, err := conn.submit(&RequestMessage{Gremlin: "g.V().count()", Fields: map[string]interface{}{}})
+	require.NoError(t, err)
+	_, _ = rs.All() // drain
+
+	// SigV4 should have added Authorization and X-Amz-Date headers
+	assert.NotEmpty(t, capturedHeaders.Get("Authorization"),
+		"SigV4Auth should set Authorization header after SerializeRequest")
+	assert.NotEmpty(t, capturedHeaders.Get("X-Amz-Date"),
+		"SigV4Auth should set X-Amz-Date header")
+	assert.Contains(t, capturedHeaders.Get("Authorization"), "AWS4-HMAC-SHA256",
+		"Authorization header should use AWS4-HMAC-SHA256 signing algorithm")
+
+	// Body should be valid serialized bytes. Use require, not assert: an empty
+	// or unreadable body would otherwise make capturedBody[0] panic with an
+	// index out of range instead of failing cleanly.
+	require.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes")
+	assert.Equal(t, byte(0x81), capturedBody[0],
+		"body should start with GraphBinary version byte 0x81")
+}
+
+// TestSigV4Auth_AutoSerializesInChain verifies that SigV4Auth works as the only
+// interceptor — it auto-serializes *RequestMessage before signing.
+func TestSigV4Auth_AutoSerializesInChain(t *testing.T) {
+	var capturedHeaders http.Header
+	var capturedBody []byte
+
+	// Record the headers and body the server actually receives.
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		capturedHeaders = r.Header.Clone()
+		body, err := io.ReadAll(r.Body)
+		if err == nil {
+			capturedBody = body
+		}
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer server.Close()
+
+	conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{})
+
+	mockProvider := &mockCredentialsProvider{
+		accessKey: "MOCK_ID",
+		secretKey: "MOCK_KEY",
+	}
+
+	// Only SigV4Auth — no SerializeRequest() needed
+	conn.AddInterceptor(SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", mockProvider))
+
+	rs, err := conn.submit(&RequestMessage{Gremlin: "g.V().count()", Fields: map[string]interface{}{}})
+	require.NoError(t, err)
+	_, _ = rs.All() // drain
+
+	assert.NotEmpty(t, capturedHeaders.Get("Authorization"),
+		"SigV4Auth should set Authorization header")
+	assert.Contains(t, capturedHeaders.Get("Authorization"), "AWS4-HMAC-SHA256")
+	// require, not assert: capturedBody[0] below would panic with an index
+	// out of range if the server received an empty or unreadable body.
+	require.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes")
+	assert.Equal(t, byte(0x81), capturedBody[0],
+		"body should start with GraphBinary version byte 0x81")
+}
+
+// TestMultipleInterceptors_SerializeThenAuth verifies that a custom interceptor can
+// modify the raw request, then SerializeRequest serializes it, then BasicAuth adds headers.
+func TestMultipleInterceptors_SerializeThenAuth(t *testing.T) {
+	var capturedAuthHeader string
+	var capturedBody []byte
+
+	// Record the Authorization header and body the server actually receives.
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		capturedAuthHeader = r.Header.Get("Authorization")
+		body, err := io.ReadAll(r.Body)
+		if err == nil {
+			capturedBody = body
+		}
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer server.Close()
+
+	conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{})
+
+	// Custom interceptor that modifies the raw request fields
+	conn.AddInterceptor(func(req *HttpRequest) error {
+		r, ok := req.Body.(*RequestMessage)
+		if !ok {
+			return fmt.Errorf("expected *RequestMessage, got %T", req.Body)
+		}
+		// Add a custom field to the request before it is serialized.
+		// NOTE(review): the serialized body is not checked for this field below.
+		r.Fields["customField"] = "customValue"
+		return nil
+	})
+
+	// SerializeRequest converts the modified *RequestMessage to []byte
+	conn.AddInterceptor(SerializeRequest())
+
+	// BasicAuth adds the Authorization header (works on any body type)
+	conn.AddInterceptor(BasicAuth("admin", "secret"))
+
+	rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}})
+	require.NoError(t, err)
+	_, _ = rs.All() // drain
+
+	// BasicAuth should have set the Authorization header
+	assert.Equal(t, "Basic YWRtaW46c2VjcmV0", capturedAuthHeader,
+		"Authorization header should be Basic base64(admin:secret)")
+
+	// Body should be valid serialized bytes (from SerializeRequest). Use
+	// require, not assert: capturedBody[0] below would panic with an index
+	// out of range on an empty or unreadable body.
+	require.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes")
+	assert.Equal(t, byte(0x81), capturedBody[0],
+		"body should start with GraphBinary version byte 0x81")
+}
+
+// TestInterceptor_IoReaderBody confirms that an interceptor may replace Body
+// with an io.Reader, and that the server then receives exactly the bytes that
+// reader yields.
+func TestInterceptor_IoReaderBody(t *testing.T) {
+	var received []byte
+
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if data, readErr := io.ReadAll(r.Body); readErr == nil {
+			received = data
+		}
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv.Close()
+
+	c := newConnection(newTestLogHandler(), srv.URL, &connectionSettings{})
+
+	payload := []byte("custom binary payload")
+	c.AddInterceptor(func(hr *HttpRequest) error {
+		// Swap the *RequestMessage for a plain reader over the payload.
+		hr.Body = bytes.NewReader(payload)
+		return nil
+	})
+
+	rs, err := c.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}})
+	require.NoError(t, err)
+	_, _ = rs.All() // drain
+
+	assert.Equal(t, payload, received,
+		"server should receive the custom payload set via io.Reader")
+}
+
+// TestInterceptor_NilSerializerNoSerialization covers a connection with no
+// serializer and no interceptor that turns the *RequestMessage into bytes. The
+// request must then fail, reporting on the ResultSet that the body was not
+// serialized.
+func TestInterceptor_NilSerializerNoSerialization(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv.Close()
+
+	c := newConnection(newTestLogHandler(), srv.URL, &connectionSettings{})
+	c.serializer = nil // remove the default serializer on purpose
+
+	rs, submitErr := c.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}})
+	// submit itself succeeds; the failure is reported later through the ResultSet.
+	require.NoError(t, submitErr)
+
+	_, _ = rs.All() // drain so the asynchronous execution has run
+	gotErr := rs.GetError()
+	require.Error(t, gotErr, "should get an error when serializer is nil and no interceptor serializes")
+	assert.Contains(t, gotErr.Error(), "request body was not serialized",
+		"error message should indicate the body was not serialized")
+}
+
+// TestInterceptor_HttpRequestBody checks the case where an interceptor stores a
+// fully built *http.Request in Body. The driver must send that request as-is:
+// its own headers and body reach the server, while anything placed on
+// HttpRequest.Headers does not.
+func TestInterceptor_HttpRequestBody(t *testing.T) {
+	var (
+		gotHeaders http.Header
+		gotBody    []byte
+	)
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		gotHeaders = r.Header.Clone()
+		if data, readErr := io.ReadAll(r.Body); readErr == nil {
+			gotBody = data
+		}
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv.Close()
+
+	c := newConnection(newTestLogHandler(), srv.URL, &connectionSettings{})
+	payload := []byte("custom-http-request-body")
+
+	// buildRequest replaces Body with a complete *http.Request.
+	buildRequest := func(hr *HttpRequest) error {
+		prepared, buildErr := http.NewRequest(http.MethodPost, hr.URL.String(), bytes.NewReader(payload))
+		if buildErr != nil {
+			return buildErr
+		}
+		prepared.Header.Set("X-Custom-Header", "custom-value")
+		prepared.Header.Set("Content-Type", "application/octet-stream")
+		hr.Body = prepared
+		return nil
+	}
+	// addIgnoredHeader runs afterwards. Its header must NOT reach the server,
+	// because a *http.Request body bypasses HttpRequest.Headers.
+	addIgnoredHeader := func(hr *HttpRequest) error {
+		hr.Headers.Set("X-Should-Not-Appear", "ignored")
+		return nil
+	}
+	c.AddInterceptor(buildRequest)
+	c.AddInterceptor(addIgnoredHeader)
+
+	rs, err := c.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}})
+	require.NoError(t, err)
+	_, _ = rs.All() // drain
+
+	assert.Equal(t, "custom-value", gotHeaders.Get("X-Custom-Header"),
+		"server should receive custom header from *http.Request")
+	assert.Equal(t, "application/octet-stream", gotHeaders.Get("Content-Type"),
+		"server should receive Content-Type from *http.Request")
+	assert.Empty(t, gotHeaders.Get("X-Should-Not-Appear"),
+		"headers set on HttpRequest.Headers should not appear when Body is *http.Request")
+	assert.Equal(t, payload, gotBody,
+		"server should receive body from the *http.Request")
+}
+
+// TestInterceptor_ErrorPropagation ensures that an error returned by an
+// interceptor is surfaced to the caller through ResultSet.GetError.
+func TestInterceptor_ErrorPropagation(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv.Close()
+
+	c := newConnection(newTestLogHandler(), srv.URL, &connectionSettings{})
+	failing := func(*HttpRequest) error { return fmt.Errorf("interceptor failed") }
+	c.AddInterceptor(failing)
+
+	rs, err := c.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}})
+	require.NoError(t, err)
+
+	_, _ = rs.All() // drain — execution happens asynchronously
+	gotErr := rs.GetError()
+	require.Error(t, gotErr, "interceptor error should propagate to ResultSet")
+	assert.Contains(t, gotErr.Error(), "interceptor failed",
+		"ResultSet error should contain the interceptor's error message")
+}
+
+// TestInterceptor_UnsupportedBodyType checks that a Body value the driver
+// cannot send (an int here) makes the request fail with an
+// "unsupported body type" error on the ResultSet.
+func TestInterceptor_UnsupportedBodyType(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv.Close()
+
+	c := newConnection(newTestLogHandler(), srv.URL, &connectionSettings{})
+	c.AddInterceptor(func(hr *HttpRequest) error {
+		hr.Body = 42 // an int is not a body type the driver can send
+		return nil
+	})
+
+	rs, err := c.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}})
+	require.NoError(t, err)
+
+	_, _ = rs.All() // drain
+	gotErr := rs.GetError()
+	require.Error(t, gotErr, "unsupported body type should produce an error")
+	assert.Contains(t, gotErr.Error(), "unsupported body type",
+		"error message should indicate unsupported body type")
+}
+
+// TestInterceptor_ChainOrder verifies that interceptors run in the same order
+// in which they were registered with AddInterceptor.
+func TestInterceptor_ChainOrder(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer srv.Close()
+
+	c := newConnection(newTestLogHandler(), srv.URL, &connectionSettings{})
+
+	var order []int
+	// tag returns an interceptor that appends id to order when it runs. id is
+	// passed by value, so each interceptor keeps its own number regardless of
+	// the Go version's loop-variable semantics.
+	tag := func(id int) func(*HttpRequest) error {
+		return func(*HttpRequest) error {
+			order = append(order, id)
+			return nil
+		}
+	}
+	for id := 1; id <= 3; id++ {
+		c.AddInterceptor(tag(id))
+	}
+
+	rs, err := c.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}})
+	require.NoError(t, err)
+	_, _ = rs.All() // drain
+
+	assert.Equal(t, []int{1, 2, 3}, order,
+		"interceptors should run in the order they were added")
+}
+
+// TestSigV4Auth_RejectsNonByteBody calls the SigV4 interceptor directly with a
+// Body that is neither []byte nor *RequestMessage (here a *strings.Reader). It
+// expects the interceptor to refuse to sign.
+func TestSigV4Auth_RejectsNonByteBody(t *testing.T) {
+	sign := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", &mockCredentialsProvider{
+		accessKey: "MOCK_ID",
+		secretKey: "MOCK_KEY",
+	})
+
+	httpReq, err := NewHttpRequest("POST", "https://test_url:8182/gremlin")
+	require.NoError(t, err)
+	for _, name := range []string{"Content-Type", "Accept"} {
+		httpReq.Headers.Set(name, graphBinaryMimeType)
+	}
+	httpReq.Body = strings.NewReader("not bytes")
+
+	signErr := sign(httpReq)
+	require.Error(t, signErr, "SigV4Auth should reject non-[]byte, non-*RequestMessage body")
+	assert.Contains(t, signErr.Error(), "SigV4 signing requires body to be []byte",
+		"error message should indicate SigV4 requires []byte body")
+}
+
+// TestSigV4Auth_AutoSerializesRequestMessage drives the SigV4 interceptor
+// directly with a *RequestMessage body. It expects the interceptor to
+// serialize the message into []byte itself and to add the SigV4 headers.
+func TestSigV4Auth_AutoSerializesRequestMessage(t *testing.T) {
+	sign := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", &mockCredentialsProvider{
+		accessKey: "MOCK_ID",
+		secretKey: "MOCK_KEY",
+	})
+
+	httpReq, err := NewHttpRequest("POST", "https://test_url:8182/gremlin")
+	require.NoError(t, err)
+	for _, name := range []string{"Content-Type", "Accept"} {
+		httpReq.Headers.Set(name, graphBinaryMimeType)
+	}
+	httpReq.Body = &RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}
+
+	require.NoError(t, sign(httpReq), "SigV4Auth should auto-serialize *RequestMessage")
+
+	// After signing, the body must have been replaced by its serialized form.
+	serialized, isBytes := httpReq.Body.([]byte)
+	assert.True(t, isBytes, "Body should be []byte after SigV4Auth auto-serialization")
+	assert.NotEmpty(t, serialized, "serialized body should be non-empty")
+
+	authz := httpReq.Headers.Get("Authorization")
+	assert.NotEmpty(t, authz, "Authorization header should be set")
+	assert.NotEmpty(t, httpReq.Headers.Get("X-Amz-Date"), "X-Amz-Date header should be set")
+	assert.Contains(t, authz, "AWS4-HMAC-SHA256")
+}
diff --git a/gremlin-go/driver/request.go b/gremlin-go/driver/request.go
index 282a301108..eafc0a5071 100644
--- a/gremlin-go/driver/request.go
+++ b/gremlin-go/driver/request.go
@@ -19,10 +19,10 @@ under the License.
package gremlingo
-// request represents a request to the server.
-type request struct {
- gremlin string
- fields map[string]interface{}
+// RequestMessage represents a request to the server.
+type RequestMessage struct {
+ Gremlin string
+ Fields map[string]interface{}
}
// MakeStringRequest creates a request from a Gremlin string query for
submission to a Gremlin server.
@@ -45,7 +45,7 @@ type request struct {
// serializer := newGraphBinarySerializer(nil)
// bytes, _ := serializer.(graphBinarySerializer).SerializeMessage(&req)
// // Send bytes over gRPC, HTTP/2, etc.
-func MakeStringRequest(stringGremlin string, traversalSource string,
requestOptions RequestOptions) (req request) {
+func MakeStringRequest(stringGremlin string, traversalSource string,
requestOptions RequestOptions) (req RequestMessage) {
newFields := map[string]interface{}{
"language": "gremlin-lang",
"g": traversalSource,
@@ -71,9 +71,9 @@ func MakeStringRequest(stringGremlin string, traversalSource string, requestOpti
newFields["materializeProperties"] =
requestOptions.materializeProperties
}
- return request{
- gremlin: stringGremlin,
- fields: newFields,
+ return RequestMessage{
+ Gremlin: stringGremlin,
+ Fields: newFields,
}
}
diff --git a/gremlin-go/driver/request_test.go b/gremlin-go/driver/request_test.go
index d37e7420fd..6bd91f5756 100644
--- a/gremlin-go/driver/request_test.go
+++ b/gremlin-go/driver/request_test.go
@@ -28,27 +28,27 @@ import (
func TestRequest(t *testing.T) {
t.Run("Test makeStringRequest() with no bindings", func(t *testing.T) {
r := MakeStringRequest("g.V()", "g", *new(RequestOptions))
- assert.Equal(t, "g.V()", r.gremlin)
- assert.Equal(t, "g", r.fields["g"])
- assert.Equal(t, "gremlin-lang", r.fields["language"])
- assert.Nil(t, r.fields["bindings"])
+ assert.Equal(t, "g.V()", r.Gremlin)
+ assert.Equal(t, "g", r.Fields["g"])
+ assert.Equal(t, "gremlin-lang", r.Fields["language"])
+ assert.Nil(t, r.Fields["bindings"])
})
t.Run("Test makeStringRequest() with custom evaluationTimeout", func(t
*testing.T) {
r := MakeStringRequest("g.V()", "g",
new(RequestOptionsBuilder).SetEvaluationTimeout(1234).Create())
- assert.Equal(t, 1234, r.fields["evaluationTimeout"])
+ assert.Equal(t, 1234, r.Fields["evaluationTimeout"])
})
t.Run("Test makeStringRequest() with custom batchSize", func(t
*testing.T) {
r := MakeStringRequest("g.V()", "g",
new(RequestOptionsBuilder).SetBatchSize(123).Create())
- assert.Equal(t, 123, r.fields["batchSize"])
+ assert.Equal(t, 123, r.Fields["batchSize"])
})
t.Run("Test makeStringRequest() with custom userAgent", func(t
*testing.T) {
r := MakeStringRequest("g.V()", "g",
new(RequestOptionsBuilder).SetUserAgent("TestUserAgent").Create())
- assert.Equal(t, "TestUserAgent", r.fields["userAgent"])
+ assert.Equal(t, "TestUserAgent", r.Fields["userAgent"])
})
}
diff --git a/gremlin-go/driver/serializer.go b/gremlin-go/driver/serializer.go
index f1718a70d3..49030251cc 100644
--- a/gremlin-go/driver/serializer.go
+++ b/gremlin-go/driver/serializer.go
@@ -30,7 +30,7 @@ const graphBinaryMimeType = "application/vnd.graphbinary-v4.0"
// Serializer interface for serializers.
type Serializer interface {
- SerializeMessage(request *request) ([]byte, error)
+ SerializeMessage(request *RequestMessage) ([]byte, error)
DeserializeMessage(message []byte) (Response, error)
}
@@ -90,8 +90,8 @@ const versionByte byte = 0x81
// // Send bytes over custom transport
//
// SerializeMessage serializes a request message into GraphBinary.
-func (gs *GraphBinarySerializer) SerializeMessage(request *request) ([]byte,
error) {
- finalMessage, err := gs.buildMessage(request.gremlin, request.fields)
+func (gs *GraphBinarySerializer) SerializeMessage(request *RequestMessage)
([]byte, error) {
+ finalMessage, err := gs.buildMessage(request.Gremlin, request.Fields)
if err != nil {
return nil, err
}
diff --git a/gremlin-go/driver/serializer_test.go b/gremlin-go/driver/serializer_test.go
index 87341999f9..fd8d660c8e 100644
--- a/gremlin-go/driver/serializer_test.go
+++ b/gremlin-go/driver/serializer_test.go
@@ -32,9 +32,9 @@ const mapDataOrder2 = "[129 0 0 0 2 3 0 0 0 0 1 103 3 0 0 0 0 1 103 3 0 0 0 0 8
func TestSerializer(t *testing.T) {
t.Run("test serialized request message", func(t *testing.T) {
- testRequest := request{
- gremlin: "g.V().count()",
- fields: map[string]interface{}{"g": "g", "language":
"gremlin-lang"},
+ testRequest := RequestMessage{
+ Gremlin: "g.V().count()",
+ Fields: map[string]interface{}{"g": "g", "language":
"gremlin-lang"},
}
serializer :=
newGraphBinarySerializer(newLogHandler(&defaultLogger{}, Error,
language.English))
serialized, _ := serializer.SerializeMessage(&testRequest)
@@ -59,9 +59,9 @@ func TestSerializer(t *testing.T) {
func TestSerializerFailures(t *testing.T) {
t.Run("test serialize request fields failure", func(t *testing.T) {
invalid := "invalid"
- testRequest := request{
+ testRequest := RequestMessage{
// Invalid pointer type in fields, so should fail
- fields: map[string]interface{}{"invalidInput":
&invalid, "g": "g"},
+ Fields: map[string]interface{}{"invalidInput":
&invalid, "g": "g"},
}
serializer :=
newGraphBinarySerializer(newLogHandler(&defaultLogger{}, Error,
language.English))
resp, err := serializer.SerializeMessage(&testRequest)