This is an automated email from the ASF dual-hosted git repository.

joaoreis pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git


The following commit(s) were added to refs/heads/trunk by this push:
     new d32a392e CASSGO-11 Support vector type. Support marshalling and unmarshalling of vector custom type.
d32a392e is described below

commit d32a392ecef472341baece6192280a2cde80d9f1
Author: Lukasz Antoniak <lukasz.anton...@gmail.com>
AuthorDate: Thu Oct 10 07:50:37 2024 +0200

    CASSGO-11 Support vector type
    Support marshalling and unmarshalling of vector custom type.
    
    patched by Lukasz Antoniak; reviewed by João Reis, Stanislav Bychkov, Oleksandr Luzhniy, Mykyta Oleksiienko and Bohdan Siryk for CASSGO-11
---
 .github/workflows/main.yml |   4 +-
 CHANGELOG.md               |   2 +
 cassandra_test.go          |  57 ++++---
 common_test.go             |  13 ++
 frame.go                   |  17 ++
 helpers.go                 | 172 ++++++++++++++-----
 helpers_test.go            |  61 ++++++-
 marshal.go                 | 179 +++++++++++++++++++-
 marshal_test.go            |  32 ++++
 metadata.go                |  60 ++++---
 metadata_test.go           |  12 +-
 vector_test.go             | 401 +++++++++++++++++++++++++++++++++++++++++++++
 12 files changed, 911 insertions(+), 99 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index ab4c1366..db4228dc 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -8,7 +8,7 @@ on:
     types: [ opened, synchronize, reopened ]
 
 env:
-  CCM_VERSION: "6e71061146f7ae67b84ccd2b1d90d7319b640e4c"
+  CCM_VERSION: "4621dfee5ad73956b831091a8b863d100d25c610"
 
 jobs:
   build:
@@ -35,7 +35,7 @@ jobs:
       fail-fast: false
       matrix:
         go: [ '1.22', '1.23' ]
-        cassandra_version: [ '4.0.13', '4.1.6' ]
+        cassandra_version: [ '4.1.6', '5.0.3' ]
         auth: [ "false" ]
         compressor: [ "no-compression", "snappy", "lz4" ]
         tags: [ "cassandra", "integration", "ccm" ]
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 37ae55e3..ffaa6629 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Support vector type [CASSGO-11](https://issues.apache.org/jira/browse/CASSGO-11)
+
 - Allow SERIAL and LOCAL_SERIAL on SELECT statements [CASSGO-26](https://issues.apache.org/jira/browse/CASSGO-26)
 
 - Support of sending queries to the specific node with Query.SetHostID() (CASSGO-4)
diff --git a/cassandra_test.go b/cassandra_test.go
index 54a54f42..46909902 100644
--- a/cassandra_test.go
+++ b/cassandra_test.go
@@ -481,15 +481,15 @@ func TestCAS(t *testing.T) {
        }
 
        insertBatch := session.Batch(LoggedBatch)
-       insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) 
VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
-       insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) 
VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
+       insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) 
VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, TOTIMESTAMP(NOW()))")
+       insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) 
VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, TOTIMESTAMP(NOW()))")
        if err := session.ExecuteBatch(insertBatch); err != nil {
                t.Fatal("insert:", err)
        }
 
        failBatch = session.Batch(LoggedBatch)
-       failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) 
WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF 
last_modified=DATEOF(NOW());")
-       failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) 
WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF 
last_modified=DATEOF(NOW());")
+       failBatch.Query("UPDATE cas_table SET last_modified = 
TOTIMESTAMP(NOW()) WHERE title='_foo' AND 
revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF 
last_modified=TOTIMESTAMP(NOW());")
+       failBatch.Query("UPDATE cas_table SET last_modified = 
TOTIMESTAMP(NOW()) WHERE title='_foo' AND 
revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF 
last_modified=TOTIMESTAMP(NOW());")
        if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, 
&revidCAS, &modifiedCAS); err != nil {
                t.Fatal("insert:", err)
        } else if applied {
@@ -533,21 +533,21 @@ func TestCAS(t *testing.T) {
        }
 
        failBatch = session.Batch(LoggedBatch)
-       failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) 
WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF 
last_modified = ?", modified)
+       failBatch.Query("UPDATE cas_table SET last_modified = 
TOTIMESTAMP(NOW()) WHERE title='_foo' AND 
revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified)
        if _, _, err := session.ExecuteBatchCAS(failBatch, new(bool)); err == 
nil {
                t.Fatal("update should have errored")
        }
        // make sure MapScanCAS does not panic when MapScan fails
        casMap = make(map[string]interface{})
        casMap["last_modified"] = false
-       if _, err := session.Query(`UPDATE cas_table SET last_modified = 
DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 
IF last_modified = ?`,
+       if _, err := session.Query(`UPDATE cas_table SET last_modified = 
TOTIMESTAMP(NOW()) WHERE title='_foo' AND 
revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?`,
                modified).MapScanCAS(casMap); err == nil {
                t.Fatal("update should hvae errored", err)
        }
 
        // make sure MapExecuteBatchCAS does not panic when MapScan fails
        failBatch = session.Batch(LoggedBatch)
-       failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) 
WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF 
last_modified = ?", modified)
+       failBatch.Query("UPDATE cas_table SET last_modified = 
TOTIMESTAMP(NOW()) WHERE title='_foo' AND 
revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified)
        casMap = make(map[string]interface{})
        casMap["last_modified"] = false
        if _, _, err := session.MapExecuteBatchCAS(failBatch, casMap); err == 
nil {
@@ -2516,18 +2516,19 @@ func TestAggregateMetadata(t *testing.T) {
                t.Fatal("expected two aggregates")
        }
 
+       protoVer := byte(session.cfg.ProtoVersion)
        expectedAggregrate := AggregateMetadata{
                Keyspace:      "gocql_test",
                Name:          "average",
-               ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt}},
+               ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt, proto: 
protoVer}},
                InitCond:      "(0, 0)",
-               ReturnType:    NativeType{typ: TypeDouble},
+               ReturnType:    NativeType{typ: TypeDouble, proto: protoVer},
                StateType: TupleTypeInfo{
-                       NativeType: NativeType{typ: TypeTuple},
+                       NativeType: NativeType{typ: TypeTuple, proto: protoVer},
 
                        Elems: []TypeInfo{
-                               NativeType{typ: TypeInt},
-                               NativeType{typ: TypeBigInt},
+                               NativeType{typ: TypeInt, proto: protoVer},
+                               NativeType{typ: TypeBigInt, proto: protoVer},
                        },
                },
                stateFunc: "avgstate",
@@ -2566,28 +2567,29 @@ func TestFunctionMetadata(t *testing.T) {
        avgState := functions[1]
        avgFinal := functions[0]
 
+       protoVer := byte(session.cfg.ProtoVersion)
        avgStateBody := "if (val !=null) {state.setInt(0, state.getInt(0)+1); 
state.setLong(1, state.getLong(1)+val.intValue());}return state;"
        expectedAvgState := FunctionMetadata{
                Keyspace: "gocql_test",
                Name:     "avgstate",
                ArgumentTypes: []TypeInfo{
                        TupleTypeInfo{
-                               NativeType: NativeType{typ: TypeTuple},
+                               NativeType: NativeType{typ: TypeTuple, proto: 
protoVer},
 
                                Elems: []TypeInfo{
-                                       NativeType{typ: TypeInt},
-                                       NativeType{typ: TypeBigInt},
+                                       NativeType{typ: TypeInt, proto: 
protoVer},
+                                       NativeType{typ: TypeBigInt, proto: 
protoVer},
                                },
                        },
-                       NativeType{typ: TypeInt},
+                       NativeType{typ: TypeInt, proto: protoVer},
                },
                ArgumentNames: []string{"state", "val"},
                ReturnType: TupleTypeInfo{
-                       NativeType: NativeType{typ: TypeTuple},
+                       NativeType: NativeType{typ: TypeTuple, proto: protoVer},
 
                        Elems: []TypeInfo{
-                               NativeType{typ: TypeInt},
-                               NativeType{typ: TypeBigInt},
+                               NativeType{typ: TypeInt, proto: protoVer},
+                               NativeType{typ: TypeBigInt, proto: protoVer},
                        },
                },
                CalledOnNullInput: true,
@@ -2604,16 +2606,16 @@ func TestFunctionMetadata(t *testing.T) {
                Name:     "avgfinal",
                ArgumentTypes: []TypeInfo{
                        TupleTypeInfo{
-                               NativeType: NativeType{typ: TypeTuple},
+                               NativeType: NativeType{typ: TypeTuple, proto: 
protoVer},
 
                                Elems: []TypeInfo{
-                                       NativeType{typ: TypeInt},
-                                       NativeType{typ: TypeBigInt},
+                                       NativeType{typ: TypeInt, proto: 
protoVer},
+                                       NativeType{typ: TypeBigInt, proto: 
protoVer},
                                },
                        },
                },
                ArgumentNames:     []string{"state"},
-               ReturnType:        NativeType{typ: TypeDouble},
+               ReturnType:        NativeType{typ: TypeDouble, proto: protoVer},
                CalledOnNullInput: true,
                Language:          "java",
                Body:              finalStateBody,
@@ -2717,15 +2719,16 @@ func TestKeyspaceMetadata(t *testing.T) {
        if flagCassVersion.Before(3, 0, 0) {
                textType = TypeVarchar
        }
+       protoVer := byte(session.cfg.ProtoVersion)
        expectedType := UserTypeMetadata{
                Keyspace:   "gocql_test",
                Name:       "basicview",
                FieldNames: []string{"birthday", "nationality", "weight", 
"height"},
                FieldTypes: []TypeInfo{
-                       NativeType{typ: TypeTimestamp},
-                       NativeType{typ: textType},
-                       NativeType{typ: textType},
-                       NativeType{typ: textType},
+                       NativeType{typ: TypeTimestamp, proto: protoVer},
+                       NativeType{typ: textType, proto: protoVer},
+                       NativeType{typ: textType, proto: protoVer},
+                       NativeType{typ: textType, proto: protoVer},
                },
        }
        if !reflect.DeepEqual(*keyspaceMetadata.UserTypes["basicview"], 
expectedType) {
diff --git a/common_test.go b/common_test.go
index e420bbf1..5e4f0cc0 100644
--- a/common_test.go
+++ b/common_test.go
@@ -28,6 +28,7 @@ import (
        "flag"
        "fmt"
        "log"
+       "math/rand"
        "net"
        "reflect"
        "strings"
@@ -54,6 +55,10 @@ var (
        flagCassVersion cassVersion
 )
 
+var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))
+
+const randCharset = 
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+
 func init() {
        flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version 
being tested against")
 
@@ -281,6 +286,14 @@ func assertTrue(t *testing.T, description string, value 
bool) {
        }
 }
 
+func randomText(size int) string {
+       result := make([]byte, size)
+       for i := range result {
+               result[i] = randCharset[rand.Intn(len(randCharset))]
+       }
+       return string(result)
+}
+
 func assertEqual(t *testing.T, description string, expected, actual 
interface{}) {
        t.Helper()
        if expected != actual {
diff --git a/frame.go b/frame.go
index 99b07e28..926c54b8 100644
--- a/frame.go
+++ b/frame.go
@@ -34,6 +34,7 @@ import (
        "io/ioutil"
        "net"
        "runtime"
+       "strconv"
        "strings"
        "time"
 )
@@ -905,6 +906,22 @@ func (f *framer) readTypeInfo() TypeInfo {
                collection.Elem = f.readTypeInfo()
 
                return collection
+       case TypeCustom:
+               if strings.HasPrefix(simple.custom, VECTOR_TYPE) {
+                       spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE)
+                       spec = spec[1 : len(spec)-1] // remove parenthesis
+                       idx := strings.LastIndex(spec, ",")
+                       typeStr := spec[:idx]
+                       dimStr := spec[idx+1:]
+                       subType := 
getCassandraLongType(strings.TrimSpace(typeStr), f.proto, nopLogger{})
+                       dim, _ := strconv.Atoi(strings.TrimSpace(dimStr))
+                       vector := VectorType{
+                               NativeType: simple,
+                               SubType:    subType,
+                               Dimensions: dim,
+                       }
+                       return vector
+               }
        }
 
        return simple
diff --git a/helpers.go b/helpers.go
index 823c1068..79842b7f 100644
--- a/helpers.go
+++ b/helpers.go
@@ -25,10 +25,12 @@
 package gocql
 
 import (
+       "encoding/hex"
        "fmt"
        "math/big"
        "net"
        "reflect"
+       "strconv"
        "strings"
        "time"
 
@@ -162,59 +164,165 @@ func getCassandraBaseType(name string) Type {
        }
 }
 
-func getCassandraType(name string, logger StdLogger) TypeInfo {
+// TODO: Cover with unit tests.
+// Parses long Java-style type definition to internal data structures.
+func getCassandraLongType(name string, protoVer byte, logger StdLogger) 
TypeInfo {
+       if strings.HasPrefix(name, SET_TYPE) {
+               return CollectionType{
+                       NativeType: NewNativeType(protoVer, TypeSet),
+                       Elem:       
getCassandraLongType(unwrapCompositeTypeDefinition(name, SET_TYPE, '('), 
protoVer, logger),
+               }
+       } else if strings.HasPrefix(name, LIST_TYPE) {
+               return CollectionType{
+                       NativeType: NewNativeType(protoVer, TypeList),
+                       Elem:       
getCassandraLongType(unwrapCompositeTypeDefinition(name, LIST_TYPE, '('), 
protoVer, logger),
+               }
+       } else if strings.HasPrefix(name, MAP_TYPE) {
+               names := splitJavaCompositeTypes(name, MAP_TYPE)
+               if len(names) != 2 {
+                       logger.Printf("gocql: error parsing map type, it has %d 
subelements, expecting 2\n", len(names))
+                       return NewNativeType(protoVer, TypeCustom)
+               }
+               return CollectionType{
+                       NativeType: NewNativeType(protoVer, TypeMap),
+                       Key:        getCassandraLongType(names[0], protoVer, 
logger),
+                       Elem:       getCassandraLongType(names[1], protoVer, 
logger),
+               }
+       } else if strings.HasPrefix(name, TUPLE_TYPE) {
+               names := splitJavaCompositeTypes(name, TUPLE_TYPE)
+               types := make([]TypeInfo, len(names))
+
+               for i, name := range names {
+                       types[i] = getCassandraLongType(name, protoVer, logger)
+               }
+
+               return TupleTypeInfo{
+                       NativeType: NewNativeType(protoVer, TypeTuple),
+                       Elems:      types,
+               }
+       } else if strings.HasPrefix(name, UDT_TYPE) {
+               names := splitJavaCompositeTypes(name, UDT_TYPE)
+               fields := make([]UDTField, len(names)-2)
+
+               for i := 2; i < len(names); i++ {
+                       spec := strings.Split(names[i], ":")
+                       fieldName, _ := hex.DecodeString(spec[0])
+                       fields[i-2] = UDTField{
+                               Name: string(fieldName),
+                               Type: getCassandraLongType(spec[1], protoVer, 
logger),
+                       }
+               }
+
+               udtName, _ := hex.DecodeString(names[1])
+               return UDTTypeInfo{
+                       NativeType: NewNativeType(protoVer, TypeUDT),
+                       KeySpace:   names[0],
+                       Name:       string(udtName),
+                       Elements:   fields,
+               }
+       } else if strings.HasPrefix(name, VECTOR_TYPE) {
+               names := splitJavaCompositeTypes(name, VECTOR_TYPE)
+               subType := getCassandraLongType(strings.TrimSpace(names[0]), 
protoVer, logger)
+               dim, err := strconv.Atoi(strings.TrimSpace(names[1]))
+               if err != nil {
+                       logger.Printf("gocql: error parsing vector dimensions: 
%v\n", err)
+                       return NewNativeType(protoVer, TypeCustom)
+               }
+
+               return VectorType{
+                       NativeType: NewCustomType(protoVer, TypeCustom, 
VECTOR_TYPE),
+                       SubType:    subType,
+                       Dimensions: dim,
+               }
+       } else {
+               // basic type
+               return NativeType{
+                       proto: protoVer,
+                       typ:   getApacheCassandraType(name),
+               }
+       }
+}
+
+// Parses short CQL type representation (e.g. map<text, text>) to internal 
data structures.
+func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo {
        if strings.HasPrefix(name, "frozen<") {
-               return getCassandraType(strings.TrimPrefix(name[:len(name)-1], 
"frozen<"), logger)
+               return getCassandraType(unwrapCompositeTypeDefinition(name, 
"frozen", '<'), protoVer, logger)
        } else if strings.HasPrefix(name, "set<") {
                return CollectionType{
-                       NativeType: NativeType{typ: TypeSet},
-                       Elem:       
getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger),
+                       NativeType: NewNativeType(protoVer, TypeSet),
+                       Elem:       
getCassandraType(unwrapCompositeTypeDefinition(name, "set", '<'), protoVer, 
logger),
                }
        } else if strings.HasPrefix(name, "list<") {
                return CollectionType{
-                       NativeType: NativeType{typ: TypeList},
-                       Elem:       
getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger),
+                       NativeType: NewNativeType(protoVer, TypeList),
+                       Elem:       
getCassandraType(unwrapCompositeTypeDefinition(name, "list", '<'), protoVer, 
logger),
                }
        } else if strings.HasPrefix(name, "map<") {
-               names := 
splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<"))
+               names := splitCQLCompositeTypes(name, "map")
                if len(names) != 2 {
                        logger.Printf("Error parsing map type, it has %d 
subelements, expecting 2\n", len(names))
-                       return NativeType{
-                               typ: TypeCustom,
-                       }
+                       return NewNativeType(protoVer, TypeCustom)
                }
                return CollectionType{
-                       NativeType: NativeType{typ: TypeMap},
-                       Key:        getCassandraType(names[0], logger),
-                       Elem:       getCassandraType(names[1], logger),
+                       NativeType: NewNativeType(protoVer, TypeMap),
+                       Key:        getCassandraType(names[0], protoVer, 
logger),
+                       Elem:       getCassandraType(names[1], protoVer, 
logger),
                }
        } else if strings.HasPrefix(name, "tuple<") {
-               names := 
splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<"))
+               names := splitCQLCompositeTypes(name, "tuple")
                types := make([]TypeInfo, len(names))
 
                for i, name := range names {
-                       types[i] = getCassandraType(name, logger)
+                       types[i] = getCassandraType(name, protoVer, logger)
                }
 
                return TupleTypeInfo{
-                       NativeType: NativeType{typ: TypeTuple},
+                       NativeType: NewNativeType(protoVer, TypeTuple),
                        Elems:      types,
                }
+       } else if strings.HasPrefix(name, "vector<") {
+               names := splitCQLCompositeTypes(name, "vector")
+               subType := getCassandraType(strings.TrimSpace(names[0]), 
protoVer, logger)
+               dim, _ := strconv.Atoi(strings.TrimSpace(names[1]))
+
+               return VectorType{
+                       NativeType: NewCustomType(protoVer, TypeCustom, 
VECTOR_TYPE),
+                       SubType:    subType,
+                       Dimensions: dim,
+               }
        } else {
                return NativeType{
-                       typ: getCassandraBaseType(name),
+                       proto: protoVer,
+                       typ:   getCassandraBaseType(name),
                }
        }
 }
 
-func splitCompositeTypes(name string) []string {
-       if !strings.Contains(name, "<") {
-               return strings.Split(name, ", ")
+func splitCQLCompositeTypes(name string, typeName string) []string {
+       return splitCompositeTypes(name, typeName, '<', '>')
+}
+
+func splitJavaCompositeTypes(name string, typeName string) []string {
+       return splitCompositeTypes(name, typeName, '(', ')')
+}
+
+func unwrapCompositeTypeDefinition(name string, typeName string, typeOpen 
int32) string {
+       return strings.TrimPrefix(name[:len(name)-1], typeName+string(typeOpen))
+}
+
+func splitCompositeTypes(name string, typeName string, typeOpen int32, 
typeClose int32) []string {
+       def := unwrapCompositeTypeDefinition(name, typeName, typeOpen)
+       if !strings.Contains(def, string(typeOpen)) {
+               parts := strings.Split(def, ",")
+               for i := range parts {
+                       parts[i] = strings.TrimSpace(parts[i])
+               }
+               return parts
        }
        var parts []string
        lessCount := 0
        segment := ""
-       for _, char := range name {
+       for _, char := range def {
                if char == ',' && lessCount == 0 {
                        if segment != "" {
                                parts = append(parts, 
strings.TrimSpace(segment))
@@ -223,9 +331,9 @@ func splitCompositeTypes(name string) []string {
                        continue
                }
                segment += string(char)
-               if char == '<' {
+               if char == typeOpen {
                        lessCount++
-               } else if char == '>' {
+               } else if char == typeClose {
                        lessCount--
                }
        }
@@ -235,20 +343,6 @@ func splitCompositeTypes(name string) []string {
        return parts
 }
 
-func apacheToCassandraType(t string) string {
-       t = strings.Replace(t, apacheCassandraTypePrefix, "", -1)
-       t = strings.Replace(t, "(", "<", -1)
-       t = strings.Replace(t, ")", ">", -1)
-       types := strings.FieldsFunc(t, func(r rune) bool {
-               return r == '<' || r == '>' || r == ','
-       })
-       for _, typ := range types {
-               t = strings.Replace(t, typ, 
getApacheCassandraType(typ).String(), -1)
-       }
-       // This is done so it exactly matches what Cassandra returns
-       return strings.Replace(t, ",", ", ", -1)
-}
-
 func getApacheCassandraType(class string) Type {
        switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
        case "AsciiType":
@@ -297,6 +391,10 @@ func getApacheCassandraType(class string) Type {
                return TypeTuple
        case "DurationType":
                return TypeDuration
+       case "SimpleDateType":
+               return TypeDate
+       case "UserType":
+               return TypeUDT
        default:
                return TypeCustom
        }
diff --git a/helpers_test.go b/helpers_test.go
index 67922ba5..275752aa 100644
--- a/helpers_test.go
+++ b/helpers_test.go
@@ -30,7 +30,7 @@ import (
 )
 
 func TestGetCassandraType_Set(t *testing.T) {
-       typ := getCassandraType("set<text>", &defaultLogger{})
+       typ := getCassandraType("set<text>", protoVersion4, &defaultLogger{})
        set, ok := typ.(CollectionType)
        if !ok {
                t.Fatalf("expected CollectionType got %T", typ)
@@ -223,11 +223,68 @@ func TestGetCassandraType(t *testing.T) {
                                Elem:       NativeType{typ: TypeDuration},
                        },
                },
+               {
+                       "vector<float, 3>", VectorType{
+                               NativeType: NativeType{
+                                       typ:    TypeCustom,
+                                       custom: VECTOR_TYPE,
+                               },
+                               SubType:    NativeType{typ: TypeFloat},
+                               Dimensions: 3,
+                       },
+               },
+               {
+                       "vector<vector<float, 3>, 5>", VectorType{
+                               NativeType: NativeType{
+                                       typ:    TypeCustom,
+                                       custom: VECTOR_TYPE,
+                               },
+                               SubType: VectorType{
+                                       NativeType: NativeType{
+                                               typ:    TypeCustom,
+                                               custom: VECTOR_TYPE,
+                                       },
+                                       SubType:    NativeType{typ: TypeFloat},
+                                       Dimensions: 3,
+                               },
+                               Dimensions: 5,
+                       },
+               },
+               {
+                       "vector<map<uuid,timestamp>, 5>", VectorType{
+                               NativeType: NativeType{
+                                       typ:    TypeCustom,
+                                       custom: VECTOR_TYPE,
+                               },
+                               SubType: CollectionType{
+                                       NativeType: NativeType{typ: TypeMap},
+                                       Key:        NativeType{typ: TypeUUID},
+                                       Elem:       NativeType{typ: 
TypeTimestamp},
+                               },
+                               Dimensions: 5,
+                       },
+               },
+               {
+                       "vector<frozen<tuple<int, float>>, 100>", VectorType{
+                               NativeType: NativeType{
+                                       typ:    TypeCustom,
+                                       custom: VECTOR_TYPE,
+                               },
+                               SubType: TupleTypeInfo{
+                                       NativeType: NativeType{typ: TypeTuple},
+                                       Elems: []TypeInfo{
+                                               NativeType{typ: TypeInt},
+                                               NativeType{typ: TypeFloat},
+                                       },
+                               },
+                               Dimensions: 100,
+                       },
+               },
        }
 
        for _, test := range tests {
                t.Run(test.input, func(t *testing.T) {
-                       got := getCassandraType(test.input, &defaultLogger{})
+                       got := getCassandraType(test.input, 0, &defaultLogger{})
 
                        // TODO(zariel): define an equal method on the types?
                        if !reflect.DeepEqual(got, test.exp) {
diff --git a/marshal.go b/marshal.go
index 719a6228..368c7f12 100644
--- a/marshal.go
+++ b/marshal.go
@@ -172,6 +172,10 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, 
error) {
                return marshalDate(info, value)
        case TypeDuration:
                return marshalDuration(info, value)
+       case TypeCustom:
+               if vector, ok := info.(VectorType); ok {
+                       return marshalVector(vector, value)
+               }
        }
 
        // detect protocol 2 UDT
@@ -276,6 +280,10 @@ func Unmarshal(info TypeInfo, data []byte, value 
interface{}) error {
                return unmarshalDate(info, data, value)
        case TypeDuration:
                return unmarshalDuration(info, data, value)
+       case TypeCustom:
+               if vector, ok := info.(VectorType); ok {
+                       return unmarshalVector(vector, data, value)
+               }
        }
 
        // detect protocol 2 UDT
@@ -1716,6 +1724,165 @@ func unmarshalList(info TypeInfo, data []byte, value 
interface{}) error {
        return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: 
*slice, *array.", info, value)
 }
 
+// marshalVector serializes a Go slice or array as the CQL vector described by info.
+//
+// A nil value, an unset column and a nil slice all produce nil bytes. The value
+// must hold exactly info.Dimensions elements. Every element is encoded with
+// Marshal using info.SubType; when isVectorVariableLengthType reports true for
+// the sub type, the element bytes are preceded by their length as an unsigned vint.
+func marshalVector(info VectorType, value interface{}) ([]byte, error) {
+       if value == nil {
+               return nil, nil
+       }
+       if _, unset := value.(unsetColumn); unset {
+               return nil, nil
+       }
+
+       val := reflect.ValueOf(value)
+       kind := val.Kind()
+       if kind != reflect.Slice && kind != reflect.Array {
+               return nil, marshalErrorf("can not marshal %T into %s. Accepted types: slice, array.", value, info)
+       }
+       if kind == reflect.Slice && val.IsNil() {
+               return nil, nil
+       }
+
+       count := val.Len()
+       if count != info.Dimensions {
+               return nil, marshalErrorf("expected vector with %d dimensions, received %d", info.Dimensions, count)
+       }
+
+       var out bytes.Buffer
+       for idx := 0; idx < count; idx++ {
+               encoded, err := Marshal(info.SubType, val.Index(idx).Interface())
+               if err != nil {
+                       return nil, err
+               }
+               if isVectorVariableLengthType(info.SubType) {
+                       writeUnsignedVInt(&out, uint64(len(encoded)))
+               }
+               out.Write(encoded)
+       }
+       return out.Bytes(), nil
+}
+
+// unmarshalVector decodes a CQL vector from data into the slice or array that
+// value points to.
+//
+// value must be a non-nil pointer to a slice or an array. Nil data resets a
+// slice to nil and is rejected for arrays. An array must have exactly
+// info.Dimensions elements; a slice is replaced by a new one of that length.
+// Elements whose sub type is variable-length (see isVectorVariableLengthType)
+// are each preceded by their size as an unsigned vint; otherwise data is split
+// evenly between the elements.
+func unmarshalVector(info VectorType, data []byte, value interface{}) error {
+       rv := reflect.ValueOf(value)
+       if rv.Kind() != reflect.Ptr {
+               return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
+       }
+       if rv.IsNil() {
+               // Elem() of a nil pointer is the zero Value and Type() on it panics.
+               return unmarshalErrorf("can not unmarshal into nil pointer %T", value)
+       }
+       rv = rv.Elem()
+       t := rv.Type()
+       k := t.Kind()
+       if k != reflect.Slice && k != reflect.Array {
+               return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: slice, array.", info, value)
+       }
+
+       if data == nil {
+               if k == reflect.Array {
+                       return unmarshalErrorf("unmarshal vector: can not store nil in array value")
+               }
+               if !rv.IsNil() {
+                       rv.Set(reflect.Zero(t))
+               }
+               return nil
+       }
+
+       if info.Dimensions < 0 {
+               // reflect.MakeSlice panics on a negative length.
+               return unmarshalErrorf("unmarshal vector: invalid number of dimensions %d", info.Dimensions)
+       }
+       if k == reflect.Array {
+               if rv.Len() != info.Dimensions {
+                       return unmarshalErrorf("unmarshal vector: array of size %d cannot store vector of %d dimensions", rv.Len(), info.Dimensions)
+               }
+       } else {
+               rv.Set(reflect.MakeSlice(t, info.Dimensions, info.Dimensions))
+       }
+
+       variableLength := isVectorVariableLengthType(info.SubType)
+       elemSize := 0
+       if !variableLength && info.Dimensions > 0 {
+               // Fixed-length elements share the payload evenly; the guard avoids
+               // dividing by zero for a vector without dimensions.
+               elemSize = len(data) / info.Dimensions
+       }
+
+       for i := 0; i < info.Dimensions; i++ {
+               if variableLength {
+                       size, read, err := readUnsignedVInt(data)
+                       if err != nil {
+                               return err
+                       }
+                       data = data[read:]
+                       // Compare as uint64 so a huge length cannot wrap to a negative int
+                       // and slip past the bounds check.
+                       if size > uint64(len(data)) {
+                               return unmarshalErrorf("unmarshal vector: unexpected eof")
+                       }
+                       elemSize = int(size)
+               }
+               if len(data) < elemSize {
+                       return unmarshalErrorf("unmarshal vector: unexpected eof")
+               }
+
+               elem := rv.Index(i).Addr().Interface()
+               if err := Unmarshal(info.SubType, data[:elemSize], elem); err != nil {
+                       // Report the destination type, not the type of the raw bytes.
+                       return unmarshalErrorf("failed to unmarshal %s into %T: %s", info.SubType, elem, err.Error())
+               }
+               data = data[elemSize:]
+       }
+       return nil
+}
+
+// isVectorVariableLengthType reports whether elements of elemType are written
+// with an explicit length inside a vector.
+// Elements of variable-length types are preceded by their encoded size so that
+// they can be split apart again when reading; fixed-length types carry no such
+// prefix. A nested vector follows its own element type, and any other custom
+// type is treated as variable-length.
+func isVectorVariableLengthType(elemType TypeInfo) bool {
+       kind := elemType.Type()
+       if kind == TypeCustom {
+               nested, isVector := elemType.(VectorType)
+               if !isVector {
+                       return true
+               }
+               return isVectorVariableLengthType(nested.SubType)
+       }
+
+       switch kind {
+       case TypeAscii, TypeVarchar, TypeText, TypeBlob,
+               TypeTinyInt, TypeSmallInt, TypeVarint, TypeDecimal,
+               TypeCounter, TypeDate, TypeTime, TypeDuration, TypeInet,
+               TypeList, TypeSet, TypeMap, TypeTuple, TypeUDT:
+               return true
+       default:
+               return false
+       }
+}
+
+// writeUnsignedVInt appends v to buf as an unsigned variable-length integer.
+//
+// Values for which computeUnsignedVIntSize returns at most 1 are written as a
+// single byte. Larger values are written big-endian into size bytes, and the
+// top size-1 bits of the first byte are set to announce how many bytes follow.
+func writeUnsignedVInt(buf *bytes.Buffer, v uint64) {
+       size := computeUnsignedVIntSize(v)
+       if size <= 1 {
+               buf.WriteByte(byte(v))
+               return
+       }
+
+       encoded := make([]byte, size)
+       rest := v
+       for pos := size - 1; pos >= 0; pos-- {
+               encoded[pos] = byte(rest)
+               rest >>= 8
+       }
+       // One leading 1 bit per byte that follows the first one.
+       marker := ^(byte(0xff) >> uint(size-1))
+       encoded[0] |= marker
+       buf.Write(encoded)
+}
+
+// readUnsignedVInt decodes an unsigned variable-length integer from the start
+// of data and returns the value together with the number of bytes consumed.
+//
+// The count of leading 1 bits in the first byte is the number of bytes that
+// follow it; the remaining bits of the first byte are the most significant
+// bits of the value. An error is returned when data is empty or shorter than
+// the first byte announces.
+func readUnsignedVInt(data []byte) (uint64, int, error) {
+       if len(data) == 0 {
+               return 0, 0, errors.New("unexpected eof")
+       }
+
+       head := data[0]
+       if head < 0x80 {
+               // No leading 1 bit: the value is the byte itself.
+               return uint64(head), 1, nil
+       }
+
+       extra := bits.LeadingZeros8(^head)
+       total := extra + 1
+       if len(data) < total {
+               return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", total, len(data))
+       }
+
+       result := uint64(head & (0xff >> uint(extra)))
+       for _, b := range data[1:total] {
+               result = result<<8 | uint64(b)
+       }
+       return result, total, nil
+}
+
+// computeUnsignedVIntSize returns the number of bytes writeUnsignedVInt needs
+// for v: 1 for values of up to 7 significant bits, one more byte for every
+// further 7 bits, and 9 for anything above 56 bits. It returns 0 for v == 0,
+// which writeUnsignedVInt treats like the single-byte case.
+func computeUnsignedVIntSize(v uint64) int {
+       return (639 - 9*bits.LeadingZeros64(v)) / 64
+}
+
 func marshalMap(info TypeInfo, value interface{}) ([]byte, error) {
        mapInfo, ok := info.(CollectionType)
        if !ok {
@@ -2475,7 +2642,11 @@ type NativeType struct {
        custom string // only used for TypeCustom
 }
 
-func NewNativeType(proto byte, typ Type, custom string) NativeType {
+// NewNativeType returns a NativeType for the given protocol version and type
+// with an empty custom class name.
+func NewNativeType(proto byte, typ Type) NativeType {
+       return NativeType{proto, typ, ""}
+}
+
+// NewCustomType returns a NativeType that also carries a custom class name;
+// the custom field is only used for TypeCustom.
+func NewCustomType(proto byte, typ Type, custom string) NativeType {
        return NativeType{proto, typ, custom}
 }
 
@@ -2514,6 +2685,12 @@ type CollectionType struct {
        Elem TypeInfo // only used for TypeMap, TypeList and TypeSet
 }
 
+// VectorType describes a CQL vector. Marshal and Unmarshal dispatch to
+// marshalVector and unmarshalVector when a TypeCustom type info is a VectorType.
+type VectorType struct {
+       NativeType
+       SubType    TypeInfo // type used to marshal and unmarshal every element
+       Dimensions int      // number of elements; marshalVector rejects values of any other length
+}
+
 func (t CollectionType) NewWithError() (interface{}, error) {
        typ, err := goType(t)
        if err != nil {
diff --git a/marshal_test.go b/marshal_test.go
index d1101550..969d8e5f 100644
--- a/marshal_test.go
+++ b/marshal_test.go
@@ -30,6 +30,7 @@ package gocql
 import (
        "bytes"
        "encoding/binary"
+       "fmt"
        "math"
        "math/big"
        "net"
@@ -2509,6 +2510,37 @@ func TestReadCollectionSize(t *testing.T) {
        }
 }
 
+// TestReadUnsignedVInt checks that readUnsignedVInt decodes known encodings,
+// reports how many bytes it consumed, and that writeUnsignedVInt produces the
+// same bytes for the same value.
+func TestReadUnsignedVInt(t *testing.T) {
+       tests := []struct {
+               decodedInt  uint64
+               encodedVint []byte
+       }{
+               {decodedInt: 0, encodedVint: []byte{0}},
+               {decodedInt: 100, encodedVint: []byte{100}},
+               // largest value that fits in a single byte
+               {decodedInt: 127, encodedVint: []byte{127}},
+               // smallest value that needs a second byte
+               {decodedInt: 128, encodedVint: []byte{0x80, 0x80}},
+               {decodedInt: 256000, encodedVint: []byte{195, 232, 0}},
+               // every bit set: a 0xff marker followed by eight value bytes
+               {decodedInt: math.MaxUint64, encodedVint: bytes.Repeat([]byte{0xff}, 9)},
+       }
+       for _, test := range tests {
+               t.Run(fmt.Sprintf("%d", test.decodedInt), func(t *testing.T) {
+                       actual, read, err := readUnsignedVInt(test.encodedVint)
+                       if err != nil {
+                               t.Fatalf("Expected no error, got %v", err)
+                       }
+                       if actual != test.decodedInt {
+                               t.Fatalf("Expected %d, but got %d", test.decodedInt, actual)
+                       }
+                       if read != len(test.encodedVint) {
+                               t.Fatalf("Expected %d bytes to be consumed, but got %d", len(test.encodedVint), read)
+                       }
+
+                       var buf bytes.Buffer
+                       writeUnsignedVInt(&buf, test.decodedInt)
+                       if !bytes.Equal(buf.Bytes(), test.encodedVint) {
+                               t.Fatalf("Expected encoding %v, but got %v", test.encodedVint, buf.Bytes())
+                       }
+               })
+       }
+}
+
 func BenchmarkUnmarshalUUID(b *testing.B) {
        b.ReportAllocs()
        src := make([]byte, 16)
diff --git a/metadata.go b/metadata.go
index 63e27aeb..c7f8e4b9 100644
--- a/metadata.go
+++ b/metadata.go
@@ -361,13 +361,13 @@ func compileMetadata(
                col := &columns[i]
                // decode the validator for TypeInfo and order
                if col.ClusteringOrder != "" { // Cassandra 3.x+
-                       col.Type = getCassandraType(col.Validator, logger)
+                       col.Type = getCassandraType(col.Validator, 
byte(protoVersion), logger)
                        col.Order = ASC
                        if col.ClusteringOrder == "desc" {
                                col.Order = DESC
                        }
                } else {
-                       validatorParsed := parseType(col.Validator, logger)
+                       validatorParsed := parseType(col.Validator, 
byte(protoVersion), logger)
                        col.Type = validatorParsed.types[0]
                        col.Order = ASC
                        if validatorParsed.reversed[0] {
@@ -389,9 +389,9 @@ func compileMetadata(
        }
 
        if protoVersion == protoVersion1 {
-               compileV1Metadata(tables, logger)
+               compileV1Metadata(tables, protoVersion, logger)
        } else {
-               compileV2Metadata(tables, logger)
+               compileV2Metadata(tables, protoVersion, logger)
        }
 }
 
@@ -400,14 +400,14 @@ func compileMetadata(
 // column metadata as V2+ (because V1 doesn't support the "type" column in the
 // system.schema_columns table) so determining PartitionKey and ClusterColumns
 // is more complex.
-func compileV1Metadata(tables []TableMetadata, logger StdLogger) {
+func compileV1Metadata(tables []TableMetadata, protoVer int, logger StdLogger) 
{
        for i := range tables {
                table := &tables[i]
 
                // decode the key validator
-               keyValidatorParsed := parseType(table.KeyValidator, logger)
+               keyValidatorParsed := parseType(table.KeyValidator, 
byte(protoVer), logger)
                // decode the comparator
-               comparatorParsed := parseType(table.Comparator, logger)
+               comparatorParsed := parseType(table.Comparator, byte(protoVer), 
logger)
 
                // the partition key length is the same as the number of types 
in the
                // key validator
@@ -493,7 +493,7 @@ func compileV1Metadata(tables []TableMetadata, logger 
StdLogger) {
                                alias = table.ValueAlias
                        }
                        // decode the default validator
-                       defaultValidatorParsed := 
parseType(table.DefaultValidator, logger)
+                       defaultValidatorParsed := 
parseType(table.DefaultValidator, byte(protoVer), logger)
                        column := &ColumnMetadata{
                                Keyspace: table.Keyspace,
                                Table:    table.Name,
@@ -507,7 +507,7 @@ func compileV1Metadata(tables []TableMetadata, logger 
StdLogger) {
 }
 
 // The simpler compile case for V2+ protocol
-func compileV2Metadata(tables []TableMetadata, logger StdLogger) {
+func compileV2Metadata(tables []TableMetadata, protoVer int, logger StdLogger) 
{
        for i := range tables {
                table := &tables[i]
 
@@ -515,7 +515,7 @@ func compileV2Metadata(tables []TableMetadata, logger 
StdLogger) {
                table.ClusteringColumns = make([]*ColumnMetadata, 
clusteringColumnCount)
 
                if table.KeyValidator != "" {
-                       keyValidatorParsed := parseType(table.KeyValidator, 
logger)
+                       keyValidatorParsed := parseType(table.KeyValidator, 
byte(protoVer), logger)
                        table.PartitionKey = make([]*ColumnMetadata, 
len(keyValidatorParsed.types))
                } else { // Cassandra 3.x+
                        partitionKeyCount := 
componentColumnCountOfType(table.Columns, ColumnPartitionKey)
@@ -925,11 +925,11 @@ func getColumnMetadata(session *Session, keyspaceName 
string) ([]ColumnMetadata,
        return columns, nil
 }
 
-func getTypeInfo(t string, logger StdLogger) TypeInfo {
+// getTypeInfo resolves the type name t into a TypeInfo. Names that start with
+// apacheCassandraTypePrefix are handed to getCassandraLongType, all others to
+// getCassandraType; protoVer and logger are passed through unchanged.
+func getTypeInfo(t string, protoVer byte, logger StdLogger) TypeInfo {
        if strings.HasPrefix(t, apacheCassandraTypePrefix) {
-               t = apacheToCassandraType(t)
+               return getCassandraLongType(t, protoVer, logger)
        }
-       return getCassandraType(t, logger)
+       return getCassandraType(t, protoVer, logger)
 }
 
 func getUserTypeMetadata(session *Session, keyspaceName string) 
([]UserTypeMetadata, error) {
@@ -965,7 +965,7 @@ func getUserTypeMetadata(session *Session, keyspaceName 
string) ([]UserTypeMetad
                }
                uType.FieldTypes = make([]TypeInfo, len(argumentTypes))
                for i, argumentType := range argumentTypes {
-                       uType.FieldTypes[i] = getTypeInfo(argumentType, 
session.logger)
+                       uType.FieldTypes[i] = getTypeInfo(argumentType, 
byte(session.cfg.ProtoVersion), session.logger)
                }
                uTypes = append(uTypes, uType)
        }
@@ -1218,10 +1218,10 @@ func getFunctionsMetadata(session *Session, 
keyspaceName string) ([]FunctionMeta
                if err != nil {
                        return nil, err
                }
-               function.ReturnType = getTypeInfo(returnType, session.logger)
+               function.ReturnType = getTypeInfo(returnType, 
byte(session.cfg.ProtoVersion), session.logger)
                function.ArgumentTypes = make([]TypeInfo, len(argumentTypes))
                for i, argumentType := range argumentTypes {
-                       function.ArgumentTypes[i] = getTypeInfo(argumentType, 
session.logger)
+                       function.ArgumentTypes[i] = getTypeInfo(argumentType, 
byte(session.cfg.ProtoVersion), session.logger)
                }
                functions = append(functions, function)
        }
@@ -1275,11 +1275,11 @@ func getAggregatesMetadata(session *Session, 
keyspaceName string) ([]AggregateMe
                if err != nil {
                        return nil, err
                }
-               aggregate.ReturnType = getTypeInfo(returnType, session.logger)
-               aggregate.StateType = getTypeInfo(stateType, session.logger)
+               aggregate.ReturnType = getTypeInfo(returnType, 
byte(session.cfg.ProtoVersion), session.logger)
+               aggregate.StateType = getTypeInfo(stateType, 
byte(session.cfg.ProtoVersion), session.logger)
                aggregate.ArgumentTypes = make([]TypeInfo, len(argumentTypes))
                for i, argumentType := range argumentTypes {
-                       aggregate.ArgumentTypes[i] = getTypeInfo(argumentType, 
session.logger)
+                       aggregate.ArgumentTypes[i] = getTypeInfo(argumentType, 
byte(session.cfg.ProtoVersion), session.logger)
                }
                aggregates = append(aggregates, aggregate)
        }
@@ -1296,6 +1296,7 @@ type typeParser struct {
        input  string
        index  int
        logger StdLogger
+       proto  byte
 }
 
 // the type definition parser result
@@ -1307,8 +1308,8 @@ type typeParserResult struct {
 }
 
 // Parse the type definition used for validator and comparator schema data
-func parseType(def string, logger StdLogger) typeParserResult {
-       parser := &typeParser{input: def, logger: logger}
+func parseType(def string, protoVer byte, logger StdLogger) typeParserResult {
+       parser := &typeParser{input: def, proto: protoVer, logger: logger}
        return parser.parse()
 }
 
@@ -1319,6 +1320,9 @@ const (
        LIST_TYPE       = "org.apache.cassandra.db.marshal.ListType"
        SET_TYPE        = "org.apache.cassandra.db.marshal.SetType"
        MAP_TYPE        = "org.apache.cassandra.db.marshal.MapType"
+       UDT_TYPE        = "org.apache.cassandra.db.marshal.UserType"
+       TUPLE_TYPE      = "org.apache.cassandra.db.marshal.TupleType"
+       VECTOR_TYPE     = "org.apache.cassandra.db.marshal.VectorType"
 )
 
 // represents a class specification in the type def AST
@@ -1327,6 +1331,7 @@ type typeParserClassNode struct {
        params []typeParserParamNode
        // this is the segment of the input string that defined this node
        input string
+       proto byte
 }
 
 // represents a class parameter in the type def AST
@@ -1346,6 +1351,7 @@ func (t *typeParser) parse() typeParserResult {
                                NativeType{
                                        typ:    TypeCustom,
                                        custom: t.input,
+                                       proto:  t.proto,
                                },
                        },
                        reversed:    []bool{false},
@@ -1423,7 +1429,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo {
                elem := class.params[0].class.asTypeInfo()
                return CollectionType{
                        NativeType: NativeType{
-                               typ: TypeList,
+                               typ:   TypeList,
+                               proto: class.proto,
                        },
                        Elem: elem,
                }
@@ -1432,7 +1439,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo {
                elem := class.params[0].class.asTypeInfo()
                return CollectionType{
                        NativeType: NativeType{
-                               typ: TypeSet,
+                               typ:   TypeSet,
+                               proto: class.proto,
                        },
                        Elem: elem,
                }
@@ -1442,7 +1450,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo {
                elem := class.params[1].class.asTypeInfo()
                return CollectionType{
                        NativeType: NativeType{
-                               typ: TypeMap,
+                               typ:   TypeMap,
+                               proto: class.proto,
                        },
                        Key:  key,
                        Elem: elem,
@@ -1450,7 +1459,7 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo {
        }
 
        // must be a simple type or custom type
-       info := NativeType{typ: getApacheCassandraType(class.name)}
+       info := NativeType{typ: getApacheCassandraType(class.name), proto: 
class.proto}
        if info.typ == TypeCustom {
                // add the entire class definition
                info.custom = class.input
@@ -1480,6 +1489,7 @@ func (t *typeParser) parseClassNode() (node 
*typeParserClassNode, ok bool) {
                name:   name,
                params: params,
                input:  t.input[startIndex:endIndex],
+               proto:  t.proto,
        }
        return node, true
 }
diff --git a/metadata_test.go b/metadata_test.go
index 6e3633cc..6b3d1198 100644
--- a/metadata_test.go
+++ b/metadata_test.go
@@ -636,12 +636,14 @@ func TestTypeParser(t *testing.T) {
                },
        )
 
-       // custom
+       // udt
        assertParseNonCompositeType(
                t,
                
"org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)",
-               assertTypeInfo{Type: TypeCustom, Custom: 
"org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)"},
+               assertTypeInfo{Type: TypeUDT, Custom: ""},
        )
+
+       // custom
        assertParseNonCompositeType(
                t,
                
"org.apache.cassandra.db.marshal.DynamicCompositeType(u=>org.apache.cassandra.db.marshal.UUIDType,d=>org.apache.cassandra.db.marshal.DateType,t=>org.apache.cassandra.db.marshal.TimeUUIDType,b=>org.apache.cassandra.db.marshal.BytesType,s=>org.apache.cassandra.db.marshal.UTF8Type,B=>org.apache.cassandra.db.marshal.BooleanType,a=>org.apache.cassandra.db.marshal.AsciiType,l=>org.apache.cassandra.db.marshal.LongType,i=>org.apache.cassandra.db.marshal.IntegerType,x=>org.apache.cassandra.db.m
 [...]
@@ -700,7 +702,7 @@ func assertParseNonCompositeType(
 ) {
 
        log := &defaultLogger{}
-       result := parseType(def, log)
+       result := parseType(def, protoVersion4, log)
        if len(result.reversed) != 1 {
                t.Errorf("%s expected %d reversed values but there were %d", 
def, 1, len(result.reversed))
        }
@@ -731,7 +733,7 @@ func assertParseCompositeType(
 ) {
 
        log := &defaultLogger{}
-       result := parseType(def, log)
+       result := parseType(def, protoVersion4, log)
        if len(result.reversed) != len(typesExpected) {
                t.Errorf("%s expected %d reversed values but there were %d", 
def, len(typesExpected), len(result.reversed))
        }
@@ -747,7 +749,7 @@ func assertParseCompositeType(
        if !result.isComposite {
                t.Errorf("%s: Expected composite", def)
        }
-       if result.collections == nil {
+       if result.collections == nil && collectionsExpected != nil {
                t.Errorf("%s: Expected non-nil collections: %v", def, 
result.collections)
        }
 
diff --git a/vector_test.go b/vector_test.go
new file mode 100644
index 00000000..4e52a885
--- /dev/null
+++ b/vector_test.go
@@ -0,0 +1,401 @@
+//go:build all || cassandra
+// +build all cassandra
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*
+ * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40
+ * Copyright (c) 2016, The Gocql authors,
+ * provided under the BSD-3-Clause License.
+ * See the NOTICE file distributed with this work for additional information.
+ */
+
+package gocql
+
+import (
+       "fmt"
+       "github.com/stretchr/testify/require"
+       "gopkg.in/inf.v0"
+       "net"
+       "reflect"
+       "testing"
+       "time"
+)
+
+// person is the Go counterpart of a CQL user-defined type; the cql tags name
+// the UDT fields each struct field maps to.
+type person struct {
+       FirstName string `cql:"first_name"`
+       LastName  string `cql:"last_name"`
+       Age       int    `cql:"age"`
+}
+
+// String renders the person with all three fields for use in messages.
+func (p person) String() string {
+       return fmt.Sprintf("Person{firstName: %s, lastName: %s, Age: %d}", 
p.FirstName, p.LastName, p.Age)
+}
+
+// TestVector_Marshaler round-trips one vector with fixed-length elements
+// (float) and one with variable-length elements (text) through Cassandra.
+func TestVector_Marshaler(t *testing.T) {
+       // Check the version first so that a skipped test never opens a session.
+       if flagCassVersion.Before(5, 0, 0) {
+               t.Skip("Vector types have been introduced in Cassandra 5.0")
+       }
+
+       session := createSession(t)
+       defer session.Close()
+
+       err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed(id int primary key, vec vector<float, 3>);`)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_variable(id int primary key, vec vector<text, 4>);`)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       insertFixVec := []float32{8, 2.5, -5.0}
+       err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, insertFixVec).Exec()
+       if err != nil {
+               t.Fatal(err)
+       }
+       var selectFixVec []float32
+       err = session.Query("SELECT vec FROM vector_fixed WHERE id = ?", 1).Scan(&selectFixVec)
+       if err != nil {
+               t.Fatal(err)
+       }
+       assertDeepEqual(t, "fixed size element vector", insertFixVec, selectFixVec)
+
+       // A 500 character element needs a multi-byte length prefix.
+       longText := randomText(500)
+       insertVarVec := []string{"apache", "cassandra", longText, "gocql"}
+       err = session.Query("INSERT INTO vector_variable(id, vec) VALUES(?, ?)", 1, insertVarVec).Exec()
+       if err != nil {
+               t.Fatal(err)
+       }
+       var selectVarVec []string
+       err = session.Query("SELECT vec FROM vector_variable WHERE id = ?", 1).Scan(&selectVarVec)
+       if err != nil {
+               t.Fatal(err)
+       }
+       assertDeepEqual(t, "variable size element vector", insertVarVec, selectVarVec)
+}
+
+// TestVector_Types creates one table per element type with a three-dimensional
+// vector column, writes a value and reads it back. Results are compared with
+// assertDeepEqual unless the case supplies its own comparator.
+func TestVector_Types(t *testing.T) {
+       // Check the version first so that a skipped test never opens a session.
+       if flagCassVersion.Before(5, 0, 0) {
+               t.Skip("Vector types have been introduced in Cassandra 5.0")
+       }
+
+       session := createSession(t)
+       defer session.Close()
+
+       // Layouts and inputs are constants, so the parse errors are ignored.
+       timestamp1, _ := time.Parse("2006-01-02", "2000-01-01")
+       timestamp2, _ := time.Parse("2006-01-02 15:04:05", "2024-01-01 10:31:45")
+       timestamp3, _ := time.Parse("2006-01-02 15:04:05.000", "2024-05-01 10:31:45.987")
+
+       date1, _ := time.Parse("2006-01-02", "2000-01-01")
+       date2, _ := time.Parse("2006-01-02", "2022-03-14")
+       date3, _ := time.Parse("2006-01-02", "2024-12-31")
+
+       // CQL time values are expressed as time.Duration offsets since midnight.
+       time1 := time.Hour
+       time2 := 15*time.Hour + 23*time.Minute + 59*time.Second
+       time3 := 10*time.Hour + 31*time.Minute + 45*time.Second + 987*time.Millisecond
+
+       duration1 := Duration{0, 1, 1920000000000}
+       duration2 := Duration{1, 1, 1920000000000}
+       duration3 := Duration{31, 0, 60000000000}
+
+       map1 := make(map[string]int)
+       map1["a"] = 1
+       map1["b"] = 2
+       map1["c"] = 3
+       map2 := make(map[string]int)
+       map2["abc"] = 123
+       map3 := make(map[string]int)
+
+       tests := []struct {
+               name    string
+               cqlType string
+               value   interface{}
+               // comparator receives the subtest's t so that failures are
+               // reported on, and stop, the subtest that produced them.
+               comparator func(t *testing.T, expected interface{}, actual interface{})
+       }{
+               {name: "ascii", cqlType: TypeAscii.String(), value: []string{"a", "1", "Z"}},
+               {name: "bigint", cqlType: TypeBigInt.String(), value: []int64{1, 2, 3}},
+               {name: "blob", cqlType: TypeBlob.String(), value: [][]byte{{1, 2, 3}, {4, 5, 6, 7}, {8, 9}}},
+               {name: "boolean", cqlType: TypeBoolean.String(), value: []bool{true, false, true}},
+               {name: "counter", cqlType: TypeCounter.String(), value: []int64{5, 6, 7}},
+               {name: "decimal", cqlType: TypeDecimal.String(), value: []inf.Dec{*inf.NewDec(1, 0), *inf.NewDec(2, 1), *inf.NewDec(-3, 2)}},
+               {name: "double", cqlType: TypeDouble.String(), value: []float64{0.1, -1.2, 3}},
+               {name: "float", cqlType: TypeFloat.String(), value: []float32{0.1, -1.2, 3}},
+               {name: "int", cqlType: TypeInt.String(), value: []int32{1, 2, 3}},
+               {name: "text", cqlType: TypeText.String(), value: []string{"a", "b", "c"}},
+               {name: "timestamp", cqlType: TypeTimestamp.String(), value: []time.Time{timestamp1, timestamp2, timestamp3}},
+               {name: "uuid", cqlType: TypeUUID.String(), value: []UUID{MustRandomUUID(), MustRandomUUID(), MustRandomUUID()}},
+               {name: "varchar", cqlType: TypeVarchar.String(), value: []string{"abc", "def", "ghi"}},
+               {name: "varint", cqlType: TypeVarint.String(), value: []uint64{uint64(1234), uint64(123498765), uint64(18446744073709551615)}},
+               {name: "timeuuid", cqlType: TypeTimeUUID.String(), value: []UUID{TimeUUID(), TimeUUID(), TimeUUID()}},
+               {
+                       name:    "inet",
+                       cqlType: TypeInet.String(),
+                       value:   []net.IP{net.IPv4(127, 0, 0, 1), net.IPv4(192, 168, 1, 1), net.IPv4(8, 8, 8, 8)},
+                       // net.IP values may differ in byte length while being equal,
+                       // so compare with Equal instead of reflect.DeepEqual.
+                       comparator: func(t *testing.T, e interface{}, a interface{}) {
+                               expected := e.([]net.IP)
+                               actual := a.([]net.IP)
+                               assertEqual(t, "vector size", len(expected), len(actual))
+                               for i := range expected {
+                                       assertTrue(t, "vector", expected[i].Equal(actual[i]))
+                               }
+                       },
+               },
+               {name: "date", cqlType: TypeDate.String(), value: []time.Time{date1, date2, date3}},
+               {name: "time", cqlType: TypeTime.String(), value: []time.Duration{time1, time2, time3}},
+               {name: "smallint", cqlType: TypeSmallInt.String(), value: []int16{127, 256, -1234}},
+               {name: "tinyint", cqlType: TypeTinyInt.String(), value: []int8{127, 9, -123}},
+               {name: "duration", cqlType: TypeDuration.String(), value: []Duration{duration1, duration2, duration3}},
+               {name: "vector_vector_float", cqlType: "vector<float, 5>", value: [][]float32{{0.1, -1.2, 3, 5, 5}, {10.1, -122222.0002, 35.0, 1, 1}, {0, 0, 0, 0, 0}}},
+               {name: "vector_vector_set_float", cqlType: "vector<set<float>, 5>", value: [][][]float32{
+                       {{1, 2}, {2, -1}, {3}, {0}, {-1.3}},
+                       {{2, 3}, {2, -1}, {3}, {0}, {-1.3}},
+                       {{1, 1000.0}, {0}, {}, {12, 14, 15, 16}, {-1.3}},
+               }},
+               {name: "vector_tuple_text_int_float", cqlType: "tuple<text, int, float>", value: [][]interface{}{{"a", 1, float32(0.5)}, {"b", 2, float32(-1.2)}, {"c", 3, float32(0)}}},
+               {name: "vector_tuple_text_list_text", cqlType: "tuple<text, list<text>>", value: [][]interface{}{{"a", []string{"b", "c"}}, {"d", []string{"e", "f", "g"}}, {"h", []string{"i"}}}},
+               {name: "vector_set_text", cqlType: "set<text>", value: [][]string{{"a", "b"}, {"c", "d"}, {"e", "f"}}},
+               {name: "vector_list_int", cqlType: "list<int>", value: [][]int32{{1, 2, 3}, {-1, -2, -3}, {0, 0, 0}}},
+               {name: "vector_map_text_int", cqlType: "map<text, int>", value: []map[string]int{map1, map2, map3}},
+       }
+
+       for _, test := range tests {
+               t.Run(test.name, func(t *testing.T) {
+                       tableName := fmt.Sprintf("vector_%s", test.name)
+                       err := createTable(session, fmt.Sprintf(`CREATE TABLE IF NOT EXISTS gocql_test.%s(id int primary key, vec vector<%s, 3>);`, tableName, test.cqlType))
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+
+                       err = session.Query(fmt.Sprintf("INSERT INTO %s(id, vec) VALUES(?, ?)", tableName), 1, test.value).Exec()
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+
+                       // Scan into a fresh value of the same Go type that was inserted.
+                       v := reflect.New(reflect.TypeOf(test.value))
+                       err = session.Query(fmt.Sprintf("SELECT vec FROM %s WHERE id = ?", tableName), 1).Scan(v.Interface())
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       if test.comparator != nil {
+                               test.comparator(t, test.value, v.Elem().Interface())
+                       } else {
+                               assertDeepEqual(t, "vector", test.value, v.Elem().Interface())
+                       }
+               })
+       }
+}
+
+// TestVector_MarshalerUDT round-trips a two-element vector of the person UDT
+// through the gocql_test.vector_relatives table and checks that the slice read
+// back is deeply equal to the slice that was written.
+func TestVector_MarshalerUDT(t *testing.T) {
+       session := createSession(t)
+       defer session.Close()
+
+       if flagCassVersion.Before(5, 0, 0) {
+               t.Skip("Vector types have been introduced in Cassandra 5.0")
+       }
+
+       // IF NOT EXISTS keeps the schema setup repeatable, in line with the
+       // other vector tests in this file.
+       err := createTable(session, `CREATE TYPE IF NOT EXISTS gocql_test.person(
+               first_name text,
+               last_name text,
+               age int);`)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_relatives(
+               id int,
+               couple vector<person, 2>,
+               primary key(id)
+       );`)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       p1 := person{"Johny", "Bravo", 25}
+       p2 := person{"Capitan", "Planet", 5}
+       insVec := []person{p1, p2}
+
+       // Write the vector, then read the same row back into a fresh slice.
+       err = session.Query("INSERT INTO vector_relatives(id, couple) VALUES(?, ?)", 1, insVec).Exec()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       var selVec []person
+
+       err = session.Query("SELECT couple FROM vector_relatives WHERE id = ?", 1).Scan(&selVec)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       assertDeepEqual(t, "udt", &insVec, &selVec)
+}
+
+// TestVector_Empty inserts rows that set only the primary key and checks that
+// scanning the unset vector column yields a nil slice, once for a
+// vector<float, 3> column and once for a vector<text, 4> column.
+func TestVector_Empty(t *testing.T) {
+       session := createSession(t)
+       defer session.Close()
+
+       if flagCassVersion.Before(5, 0, 0) {
+               t.Skip("Vector types have been introduced in Cassandra 5.0")
+       }
+
+       err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed_null(id int primary key, vec vector<float, 3>);`)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_variable_null(id int primary key, vec vector<text, 4>);`)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       // Float element type: the vec column is left out of the INSERT.
+       err = session.Query("INSERT INTO vector_fixed_null(id) VALUES(?)", 1).Exec()
+       if err != nil {
+               t.Fatal(err)
+       }
+       var selectFixVec []float32
+       err = session.Query("SELECT vec FROM vector_fixed_null WHERE id = ?", 1).Scan(&selectFixVec)
+       if err != nil {
+               t.Fatal(err)
+       }
+       assertTrue(t, "fixed size element vector is empty", selectFixVec == nil)
+
+       // Text element type: same scenario against the second table.
+       err = session.Query("INSERT INTO vector_variable_null(id) VALUES(?)", 1).Exec()
+       if err != nil {
+               t.Fatal(err)
+       }
+       var selectVarVec []string
+       err = session.Query("SELECT vec FROM vector_variable_null WHERE id = ?", 1).Scan(&selectVarVec)
+       if err != nil {
+               t.Fatal(err)
+       }
+       assertTrue(t, "variable size element vector is empty", selectVarVec == nil)
+}
+
+// TestVector_MissingDimension checks that inserting a slice whose length
+// differs from the declared dimension of a vector<float, 3> column fails, both
+// when the slice is too short and when it is too long.
+func TestVector_MissingDimension(t *testing.T) {
+       session := createSession(t)
+       defer session.Close()
+
+       if flagCassVersion.Before(5, 0, 0) {
+               t.Skip("Vector types have been introduced in Cassandra 5.0")
+       }
+
+       err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed(id int primary key, vec vector<float, 3>);`)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       // require.Error treats its third argument as a failure message, so it
+       // never compared the text; ErrorContains asserts on the error itself.
+       // NOTE(review): assumes the driver's error text contains these fragments
+       // verbatim - confirm against the vector marshalling code.
+       err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0}).Exec()
+       require.ErrorContains(t, err, "expected vector with 3 dimensions, received 2")
+
+       err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0, 1, 3}).Exec()
+       require.ErrorContains(t, err, "expected vector with 3 dimensions, received 4")
+}
+
+func TestVector_SubTypeParsing(t *testing.T) {
+       tests := []struct {
+               name     string
+               custom   string
+               expected TypeInfo
+       }{
+               {name: "text", custom: 
"org.apache.cassandra.db.marshal.UTF8Type", expected: NativeType{typ: 
TypeVarchar}},
+               {name: "set_int", custom: 
"org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.Int32Type)",
 expected: CollectionType{NativeType{typ: TypeSet}, nil, NativeType{typ: 
TypeInt}}},
+               {
+                       name:   "udt",
+                       custom: 
"org.apache.cassandra.db.marshal.UserType(gocql_test,706572736f6e,66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,6c6173745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,616765:org.apache.cassandra.db.marshal.Int32Type)",
+                       expected: UDTTypeInfo{
+                               NativeType{typ: TypeUDT},
+                               "gocql_test",
+                               "person",
+                               []UDTField{
+                                       UDTField{"first_name", NativeType{typ: 
TypeVarchar}},
+                                       UDTField{"last_name", NativeType{typ: 
TypeVarchar}},
+                                       UDTField{"age", NativeType{typ: 
TypeInt}},
+                               },
+                       },
+               },
+               {
+                       name:   "tuple",
+                       custom: 
"org.apache.cassandra.db.marshal.TupleType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type)",
+                       expected: TupleTypeInfo{
+                               NativeType{typ: TypeTuple},
+                               []TypeInfo{
+                                       NativeType{typ: TypeVarchar},
+                                       NativeType{typ: TypeInt},
+                                       NativeType{typ: TypeVarchar},
+                               },
+                       },
+               },
+               {
+                       name:   "vector_vector_inet",
+                       custom: 
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType,
 2), 3)",
+                       expected: VectorType{
+                               NativeType{typ: TypeCustom, custom: 
VECTOR_TYPE},
+                               VectorType{
+                                       NativeType{typ: TypeCustom, custom: 
VECTOR_TYPE},
+                                       NativeType{typ: TypeInet},
+                                       2,
+                               },
+                               3,
+                       },
+               },
+               {
+                       name:   "map_int_vector_text",
+                       custom: 
"org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type,
 10))",
+                       expected: CollectionType{
+                               NativeType{typ: TypeMap},
+                               NativeType{typ: TypeInt},
+                               VectorType{
+                                       NativeType{typ: TypeCustom, custom: 
VECTOR_TYPE},
+                                       NativeType{typ: TypeVarchar},
+                                       10,
+                               },
+                       },
+               },
+               {
+                       name:   "set_map_vector_text_text",
+                       custom: 
"org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type,
 10),org.apache.cassandra.db.marshal.UTF8Type))",
+                       expected: CollectionType{
+                               NativeType{typ: TypeSet},
+                               nil,
+                               CollectionType{
+                                       NativeType{typ: TypeMap},
+                                       VectorType{
+                                               NativeType{typ: TypeCustom, 
custom: VECTOR_TYPE},
+                                               NativeType{typ: TypeInt},
+                                               10,
+                                       },
+                                       NativeType{typ: TypeVarchar},
+                               },
+                       },
+               },
+       }
+
+       for _, test := range tests {
+               t.Run(test.name, func(t *testing.T) {
+                       f := newFramer(nil, 0)
+                       f.writeShort(0)
+                       
f.writeString(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", 
test.custom))
+                       parsedType := f.readTypeInfo()
+                       require.IsType(t, parsedType, VectorType{})
+                       vectorType := parsedType.(VectorType)
+                       assertEqual(t, "dimensions", 2, vectorType.Dimensions)
+                       assertDeepEqual(t, "vector", test.expected, 
vectorType.SubType)
+               })
+       }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

Reply via email to