This is an automated email from the ASF dual-hosted git repository.

joaoreis pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 48bb2bc  Change Batch API to be consistent with Query()
48bb2bc is described below

commit 48bb2bc2b499d9f96764d64944be769f3d8068bc
Author: tengu-alt <olexandr.luzh...@gmail.com>
AuthorDate: Wed Nov 20 13:37:03 2024 +0200

    Change Batch API to be consistent with Query()
    
    An Exec() method was added to Batch, and Batch.Query() now returns the
    *Batch so calls can be chained. Batches can now be built and executed
    the same way as queries.
    
    patch by Oleksandr Luzhniy; reviewed by João Reis, Danylo Savchenko, Bohdan Siryk, Jackson Fleming, for CASSGO-7
---
 CHANGELOG.md              |  2 ++
 batch_test.go             | 16 +++++++++-------
 cassandra_test.go         | 34 +++++++++++++++++-----------------
 doc.go                    |  2 +-
 example_batch_test.go     | 14 ++++++++++++--
 example_lwt_batch_test.go |  4 ++--
 integration_test.go       |  2 +-
 session.go                | 17 ++++++++++++++++-
 session_test.go           |  6 +++---
 9 files changed, 63 insertions(+), 34 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6bc24d4..67c88a1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Detailed description for NumConns (CASSGO-3)
 
+- Change Batch API to be consistent with Query() (CASSGO-7)
+
 ### Fixed
 
 - Retry policy now takes into account query idempotency (CASSGO-27)
diff --git a/batch_test.go b/batch_test.go
index 25f8c83..44b5266 100644
--- a/batch_test.go
+++ b/batch_test.go
@@ -47,9 +47,9 @@ func TestBatch_Errors(t *testing.T) {
                t.Fatal(err)
        }
 
-       b := session.NewBatch(LoggedBatch)
-       b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil)
-       if err := session.ExecuteBatch(b); err == nil {
+       b := session.Batch(LoggedBatch)
+       b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil)
+       if err := b.Exec(); err == nil {
                t.Fatal("expected to get error for invalid query in batch")
        }
 }
@@ -68,15 +68,17 @@ func TestBatch_WithTimestamp(t *testing.T) {
 
        micros := time.Now().UnixNano()/1e3 - 1000
 
-       b := session.NewBatch(LoggedBatch)
+       b := session.Batch(LoggedBatch)
        b.WithTimestamp(micros)
-       b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val")
-       if err := session.ExecuteBatch(b); err != nil {
+       b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val")
+       b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val")
+
+       if err := b.Exec(); err != nil {
                t.Fatal(err)
        }
 
        var storedTs int64
-       if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
+       if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
                t.Fatal(err)
        }
 
diff --git a/cassandra_test.go b/cassandra_test.go
index 3b0c610..ec69691 100644
--- a/cassandra_test.go
+++ b/cassandra_test.go
@@ -45,7 +45,7 @@ import (
        "time"
        "unicode"
 
-       inf "gopkg.in/inf.v0"
+       "gopkg.in/inf.v0"
 )
 
 func TestEmptyHosts(t *testing.T) {
@@ -454,7 +454,7 @@ func TestCAS(t *testing.T) {
                t.Fatal("truncate:", err)
        }
 
-       successBatch := session.NewBatch(LoggedBatch)
+       successBatch := session.Batch(LoggedBatch)
        successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
        if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
                t.Fatal("insert:", err)
@@ -462,7 +462,7 @@ func TestCAS(t *testing.T) {
                t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
        }
 
-       successBatch = session.NewBatch(LoggedBatch)
+       successBatch = session.Batch(LoggedBatch)
        successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified)
        casMap := make(map[string]interface{})
        if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil {
@@ -471,7 +471,7 @@ func TestCAS(t *testing.T) {
                t.Fatal("insert should have been applied")
        }
 
-       failBatch := session.NewBatch(LoggedBatch)
+       failBatch := session.Batch(LoggedBatch)
        failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
        if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
                t.Fatal("insert:", err)
@@ -479,14 +479,14 @@ func TestCAS(t *testing.T) {
                t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
        }
 
-       insertBatch := session.NewBatch(LoggedBatch)
+       insertBatch := session.Batch(LoggedBatch)
        insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
        insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
        if err := session.ExecuteBatch(insertBatch); err != nil {
                t.Fatal("insert:", err)
        }
 
-       failBatch = session.NewBatch(LoggedBatch)
+       failBatch = session.Batch(LoggedBatch)
        failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
        failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
        if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
@@ -611,7 +611,7 @@ func TestBatch(t *testing.T) {
                t.Fatal("create table:", err)
        }
 
-       batch := session.NewBatch(LoggedBatch)
+       batch := session.Batch(LoggedBatch)
        for i := 0; i < 100; i++ {
                batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i)
        }
@@ -643,9 +643,9 @@ func TestUnpreparedBatch(t *testing.T) {
 
        var batch *Batch
        if session.cfg.ProtoVersion == 2 {
-               batch = session.NewBatch(CounterBatch)
+               batch = session.Batch(CounterBatch)
        } else {
-               batch = session.NewBatch(UnloggedBatch)
+               batch = session.Batch(UnloggedBatch)
        }
 
        for i := 0; i < 100; i++ {
@@ -684,7 +684,7 @@ func TestBatchLimit(t *testing.T) {
                t.Fatal("create table:", err)
        }
 
-       batch := session.NewBatch(LoggedBatch)
+       batch := session.Batch(LoggedBatch)
        for i := 0; i < 65537; i++ {
                batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i)
        }
@@ -738,7 +738,7 @@ func TestTooManyQueryArgs(t *testing.T) {
                t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error")
        }
 
-       batch := session.NewBatch(UnloggedBatch)
+       batch := session.Batch(UnloggedBatch)
        batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3)
        err = session.ExecuteBatch(batch)
 
@@ -770,7 +770,7 @@ func TestNotEnoughQueryArgs(t *testing.T) {
                t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error")
        }
 
-       batch := session.NewBatch(UnloggedBatch)
+       batch := session.Batch(UnloggedBatch)
        batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2)
        err = session.ExecuteBatch(batch)
 
@@ -1392,7 +1392,7 @@ func TestBatchQueryInfo(t *testing.T) {
                return values, nil
        }
 
-       batch := session.NewBatch(LoggedBatch)
+       batch := session.Batch(LoggedBatch)
        batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write)
 
        if err := session.ExecuteBatch(batch); err != nil {
@@ -1520,7 +1520,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) {
        }
 
        stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
-       batch := session.NewBatch(UnloggedBatch)
+       batch := session.Batch(UnloggedBatch)
        batch.Query(stmt, "bar")
        if err := conn.executeBatch(ctx, batch).Close(); err != nil {
                t.Fatalf("Failed to execute query for reprepare statement: %v", err)
@@ -1904,7 +1904,7 @@ func TestBatchStats(t *testing.T) {
                t.Fatalf("failed to create table with error '%v'", err)
        }
 
-       b := session.NewBatch(LoggedBatch)
+       b := session.Batch(LoggedBatch)
        b.Query("INSERT INTO batchStats (id) VALUES (?)", 1)
        b.Query("INSERT INTO batchStats (id) VALUES (?)", 2)
 
@@ -1947,7 +1947,7 @@ func TestBatchObserve(t *testing.T) {
 
        var observedBatch *observation
 
-       batch := session.NewBatch(LoggedBatch)
+       batch := session.Batch(LoggedBatch)
        batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) {
                if observedBatch != nil {
                        t.Fatal("batch observe called more than once")
@@ -3286,7 +3286,7 @@ func TestUnsetColBatch(t *testing.T) {
                t.Fatalf("failed to create table with error '%v'", err)
        }
 
-       b := session.NewBatch(LoggedBatch)
+       b := session.Batch(LoggedBatch)
        b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue)
        b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "")
        b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue)
diff --git a/doc.go b/doc.go
index 236b55e..109a850 100644
--- a/doc.go
+++ b/doc.go
@@ -310,7 +310,7 @@
 // # Batches
 //
 // The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
-// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
+// Use Session.Batch to create a new batch and then fill-in details of individual queries.
 // Then execute the batch with Session.ExecuteBatch.
 //
 // Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
diff --git a/example_batch_test.go b/example_batch_test.go
index 2695e48..b27085c 100644
--- a/example_batch_test.go
+++ b/example_batch_test.go
@@ -29,7 +29,7 @@ import (
        "fmt"
        "log"
 
-       gocql "github.com/gocql/gocql"
+       "github.com/gocql/gocql"
 )
 
 // Example_batch demonstrates how to execute a batch of statements.
@@ -49,7 +49,7 @@ func Example_batch() {
 
        ctx := context.Background()
 
-       b := session.NewBatch(gocql.UnloggedBatch).WithContext(ctx)
+       b := session.Batch(gocql.UnloggedBatch).WithContext(ctx)
        b.Entries = append(b.Entries, gocql.BatchEntry{
                Stmt:       "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)",
                Args:       []interface{}{1, 2, "1.2"},
@@ -60,11 +60,19 @@ func Example_batch() {
                Args:       []interface{}{1, 3, "1.3"},
                Idempotent: true,
        })
+
        err = session.ExecuteBatch(b)
        if err != nil {
                log.Fatal(err)
        }
 
+       err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4").
+               Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5").
+               Exec()
+       if err != nil {
+               log.Fatal(err)
+       }
+
        scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner()
        for scanner.Next() {
                var pk, ck int32
@@ -77,4 +85,6 @@ func Example_batch() {
        }
        // 1 2 1.2
        // 1 3 1.3
+       // 1 4 1.4
+       // 1 5 1.5
 }
diff --git a/example_lwt_batch_test.go b/example_lwt_batch_test.go
index 916367e..c3cc838 100644
--- a/example_lwt_batch_test.go
+++ b/example_lwt_batch_test.go
@@ -29,7 +29,7 @@ import (
        "fmt"
        "log"
 
-       gocql "github.com/gocql/gocql"
+       "github.com/gocql/gocql"
 )
 
 // ExampleSession_MapExecuteBatchCAS demonstrates how to execute a batch lightweight transaction.
@@ -62,7 +62,7 @@ func ExampleSession_MapExecuteBatchCAS() {
        }
 
        executeBatch := func(ck2Version int) {
-               b := session.NewBatch(gocql.LoggedBatch)
+               b := session.Batch(gocql.LoggedBatch)
                b.Entries = append(b.Entries, gocql.BatchEntry{
                        Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? IF version=?",
                        Args: []interface{}{"b", "pk1", "ck1", 1},
diff --git a/integration_test.go b/integration_test.go
index 3622dfb..61ffbf5 100644
--- a/integration_test.go
+++ b/integration_test.go
@@ -218,7 +218,7 @@ func TestCustomPayloadMessages(t *testing.T) {
        iter.Close()
 
        // Batch Message
-       b := session.NewBatch(LoggedBatch)
+       b := session.Batch(LoggedBatch)
        b.CustomPayload = customPayload
        b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)")
        if err := session.ExecuteBatch(b); err != nil {
diff --git a/session.go b/session.go
index b884735..d04a136 100644
--- a/session.go
+++ b/session.go
@@ -731,6 +731,13 @@ func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter {
        return conn.executeBatch(ctx, b)
 }
 
+// Exec executes a batch operation and returns nil if successful
+// otherwise an error is returned describing the failure.
+func (b *Batch) Exec() error {
+       iter := b.session.executeBatch(b)
+       return iter.Close()
+}
+
 func (s *Session) executeBatch(batch *Batch) *Iter {
        // fail fast
        if s.Closed() {
@@ -1748,7 +1755,14 @@ type Batch struct {
 }
 
 // NewBatch creates a new batch operation using defaults defined in the cluster
+//
+// Deprecated: use session.Batch instead
 func (s *Session) NewBatch(typ BatchType) *Batch {
+       return s.Batch(typ)
+}
+
+// Batch creates a new batch operation using defaults defined in the cluster
+func (s *Session) Batch(typ BatchType) *Batch {
        s.mu.RLock()
        batch := &Batch{
                Type:             typ,
@@ -1848,8 +1862,9 @@ func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch
 }
 
 // Query adds the query to the batch operation
-func (b *Batch) Query(stmt string, args ...interface{}) {
+func (b *Batch) Query(stmt string, args ...interface{}) *Batch {
        b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
+       return b
 }
 
 // Bind adds the query to the batch operation and correlates it with a binding callback
diff --git a/session_test.go b/session_test.go
index c7bafbb..8633f99 100644
--- a/session_test.go
+++ b/session_test.go
@@ -96,7 +96,7 @@ func TestSessionAPI(t *testing.T) {
                t.Fatalf("expected itr.err to be '%v', got '%v'", ErrNoConnections, itr.err)
        }
 
-       testBatch := s.NewBatch(LoggedBatch)
+       testBatch := s.Batch(LoggedBatch)
        testBatch.Query("test")
        err := s.ExecuteBatch(testBatch)
 
@@ -219,7 +219,7 @@ func TestBatchBasicAPI(t *testing.T) {
        s.pool = cfg.PoolConfig.buildPool(s)
 
        // Test UnloggedBatch
-       b := s.NewBatch(UnloggedBatch)
+       b := s.Batch(UnloggedBatch)
        if b.Type != UnloggedBatch {
                t.Fatalf("expceted batch.Type to be '%v', got '%v'", UnloggedBatch, b.Type)
        } else if b.rt != cfg.RetryPolicy {
@@ -227,7 +227,7 @@ func TestBatchBasicAPI(t *testing.T) {
        }
 
        // Test LoggedBatch
-       b = s.NewBatch(LoggedBatch)
+       b = s.Batch(LoggedBatch)
        if b.Type != LoggedBatch {
                t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type)
        }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

Reply via email to