This is an automated email from the ASF dual-hosted git repository.

jameshartig pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git


The following commit(s) were added to refs/heads/trunk by this push:
     new bf16ec3  CASSGO-4 Support of sending queries to the specific node
bf16ec3 is described below

commit bf16ec371974a1c2082d525b6d78fac108f8c49b
Author: Bohdan Siryk <bohdan.siry...@gmail.com>
AuthorDate: Thu Nov 21 13:12:53 2024 +0200

    CASSGO-4 Support of sending queries to the specific node
    
    Query.SetHostID() allows users to specify the node on which the Query will be
    executed. This is not a typical use case, but it is useful with virtual tables
    (available since Cassandra 4.0), whose contents are local to each node.
    
    Patch by Bohdan Siryk; Reviewed by João Reis, James Hartig for CASSGO-4
---
 CHANGELOG.md      |  2 ++
 cassandra_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++++-
 connectionpool.go |  7 +++++++
 query_executor.go | 25 +++++++++++++++++++++++--
 session.go        | 29 +++++++++++++++++++++++++++++
 5 files changed, 110 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d6f6faa..f5ac2bd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Support of sending queries to the specific node with Query.SetHostID() 
(CASSGO-4)
+
 ### Changed
 
 - Move lz4 compressor to lz4 package within the gocql module (CASSGO-32)
diff --git a/cassandra_test.go b/cassandra_test.go
index 545b307..6f13cff 100644
--- a/cassandra_test.go
+++ b/cassandra_test.go
@@ -3327,7 +3327,6 @@ func TestUnsetColBatch(t *testing.T) {
        }
        var id, mInt, count int
        var mText string
-
        if err := session.Query("SELECT count(*) FROM 
gocql_test.batchUnsetInsert;").Scan(&count); err != nil {
                t.Fatalf("Failed to select with err: %v", err)
        } else if count != 2 {
@@ -3362,3 +3361,52 @@ func TestQuery_NamedValues(t *testing.T) {
                t.Fatal(err)
        }
 }
+
+// This test ensures that queries are sent to the specified host only
+func TestQuery_SetHostID(t *testing.T) {
+       session := createSession(t)
+       defer session.Close()
+
+       hosts := session.GetHosts()
+
+       const iterations = 5
+       for _, expectedHost := range hosts {
+               for i := 0; i < iterations; i++ {
+                       var actualHostID string
+                       err := session.Query("SELECT host_id FROM 
system.local").
+                               SetHostID(expectedHost.HostID()).
+                               Scan(&actualHostID)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+
+                       if expectedHost.HostID() != actualHostID {
+                               t.Fatalf("Expected query to be executed on host 
%s, but it was executed on %s",
+                                       expectedHost.HostID(),
+                                       actualHostID,
+                               )
+                       }
+               }
+       }
+
+       // ensuring properly handled invalid host id
+       err := session.Query("SELECT host_id FROM system.local").
+               SetHostID("[invalid]").
+               Exec()
+       if !errors.Is(err, ErrNoConnections) {
+               t.Fatalf("Expected error to be: %v, but got %v", 
ErrNoConnections, err)
+       }
+
+       // ensuring that the driver properly handles the case
+       // when specified host for the query is down
+       host := hosts[0]
+       pool, _ := session.pool.getPoolByHostID(host.HostID())
+       // simulating specified host is down
+       pool.host.setState(NodeDown)
+       err = session.Query("SELECT host_id FROM system.local").
+               SetHostID(host.HostID()).
+               Exec()
+       if !errors.Is(err, ErrNoConnections) {
+               t.Fatalf("Expected error to be: %v, but got %v", 
ErrNoConnections, err)
+       }
+}
diff --git a/connectionpool.go b/connectionpool.go
index 2ccd3c8..9b8295e 100644
--- a/connectionpool.go
+++ b/connectionpool.go
@@ -243,6 +243,13 @@ func (p *policyConnPool) getPool(host *HostInfo) (pool 
*hostConnPool, ok bool) {
        return
 }
 
+func (p *policyConnPool) getPoolByHostID(hostID string) (pool *hostConnPool, 
ok bool) {
+       p.mu.RLock()
+       pool, ok = p.hostConnPools[hostID]
+       p.mu.RUnlock()
+       return
+}
+
 func (p *policyConnPool) Close() {
        p.mu.Lock()
        defer p.mu.Unlock()
diff --git a/query_executor.go b/query_executor.go
index d6be02e..9eaf19d 100644
--- a/query_executor.go
+++ b/query_executor.go
@@ -41,6 +41,7 @@ type ExecutableQuery interface {
        Keyspace() string
        Table() string
        IsIdempotent() bool
+       GetHostID() string
 
        withContext(context.Context) ExecutableQuery
 
@@ -83,12 +84,32 @@ func (q *queryExecutor) speculate(ctx context.Context, qry 
ExecutableQuery, sp S
 }
 
 func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
-       hostIter := q.policy.Pick(qry)
+       var hostIter NextHost
+
+       // check if the host id is specified for the query,
+       // if it is, the query should be executed at the corresponding host.
+       if hostID := qry.GetHostID(); hostID != "" {
+               hostIter = func() SelectedHost {
+                       pool, ok := q.pool.getPoolByHostID(hostID)
+                       // if the specified host is down
+                       // we return nil to avoid endless query execution in 
queryExecutor.do()
+                       if !ok || !pool.host.IsUp() {
+                               return nil
+                       }
+                       return (*selectedHost)(pool.host)
+               }
+       }
+
+       // if host is not specified for the query,
+       // then a host will be picked by HostSelectionPolicy
+       if hostIter == nil {
+               hostIter = q.policy.Pick(qry)
+       }
 
        // check if the query is not marked as idempotent, if
        // it is, we force the policy to NonSpeculative
        sp := qry.speculativeExecutionPolicy()
-       if !qry.IsIdempotent() || sp.Attempts() == 0 {
+       if qry.GetHostID() != "" || !qry.IsIdempotent() || sp.Attempts() == 0 {
                return q.do(qry.Context(), qry, hostIter), nil
        }
 
diff --git a/session.go b/session.go
index c47e753..8965f0f 100644
--- a/session.go
+++ b/session.go
@@ -456,6 +456,7 @@ func (s *Session) Query(stmt string, values ...interface{}) 
*Query {
        qry.session = s
        qry.stmt = stmt
        qry.values = values
+       qry.hostID = ""
        qry.defaultsFromSession()
        return qry
 }
@@ -949,6 +950,10 @@ type Query struct {
 
        // routingInfo is a pointer because Query can be copied and copyable 
struct can't hold a mutex.
        routingInfo *queryRoutingInfo
+
+       // hostID specifies the host on which the query should be executed.
+       // If it is empty, then the host is picked by HostSelectionPolicy
+       hostID string
 }
 
 type queryRoutingInfo struct {
@@ -1442,6 +1447,20 @@ func (q *Query) releaseAfterExecution() {
        q.decRefCount()
 }
 
+// SetHostID allows to define the host the query should be executed against. 
If the
+// host was filtered or otherwise unavailable, then the query will error. If 
an empty
+// string is sent, the default behavior, using the configured 
HostSelectionPolicy will
+// be used. A hostID can be obtained from HostInfo.HostID() after calling 
GetHosts().
+func (q *Query) SetHostID(hostID string) *Query {
+       q.hostID = hostID
+       return q
+}
+
+// GetHostID returns id of the host on which query should be executed.
+func (q *Query) GetHostID() string {
+       return q.hostID
+}
+
 // Iter represents an iterator that can be used to iterate over all rows that
 // were returned by a query. The iterator might send additional queries to the
 // database during the iteration if paging was enabled.
@@ -2057,6 +2076,11 @@ func (b *Batch) releaseAfterExecution() {
        // that would race with speculative executions.
 }
 
+// GetHostID satisfies ExecutableQuery interface but does noop.
+func (b *Batch) GetHostID() string {
+       return ""
+}
+
 type BatchType byte
 
 const (
@@ -2189,6 +2213,11 @@ func (t *traceWriter) Trace(traceId []byte) {
        }
 }
 
+// GetHosts return a list of hosts in the ring the driver knows of.
+func (s *Session) GetHosts() []*HostInfo {
+       return s.ring.allHosts()
+}
+
 type ObservedQuery struct {
        Keyspace  string
        Statement string


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

Reply via email to