This is an automated email from the ASF dual-hosted git repository. joaoreis pushed a commit to branch trunk in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git
The following commit(s) were added to refs/heads/trunk by this push: new d32a392e CASSGO-11 Support vector type Support marshalling and unmarshalling of vector custom type. d32a392e is described below commit d32a392ecef472341baece6192280a2cde80d9f1 Author: Lukasz Antoniak <lukasz.anton...@gmail.com> AuthorDate: Thu Oct 10 07:50:37 2024 +0200 CASSGO-11 Support vector type Support marshalling and unmarshalling of vector custom type. patched by Lukasz Antoniak; reviewed by João Reis, Stanislav Bychkov, Oleksandr Luzhniy, Mykyta Oleksiienko and Bohdan Siryk for CASSGO-11 --- .github/workflows/main.yml | 4 +- CHANGELOG.md | 2 + cassandra_test.go | 57 ++++--- common_test.go | 13 ++ frame.go | 17 ++ helpers.go | 172 ++++++++++++++----- helpers_test.go | 61 ++++++- marshal.go | 179 +++++++++++++++++++- marshal_test.go | 32 ++++ metadata.go | 60 ++++--- metadata_test.go | 12 +- vector_test.go | 401 +++++++++++++++++++++++++++++++++++++++++++++ 12 files changed, 911 insertions(+), 99 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ab4c1366..db4228dc 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -8,7 +8,7 @@ on: types: [ opened, synchronize, reopened ] env: - CCM_VERSION: "6e71061146f7ae67b84ccd2b1d90d7319b640e4c" + CCM_VERSION: "4621dfee5ad73956b831091a8b863d100d25c610" jobs: build: @@ -35,7 +35,7 @@ jobs: fail-fast: false matrix: go: [ '1.22', '1.23' ] - cassandra_version: [ '4.0.13', '4.1.6' ] + cassandra_version: [ '4.1.6', '5.0.3' ] auth: [ "false" ] compressor: [ "no-compression", "snappy", "lz4" ] tags: [ "cassandra", "integration", "ccm" ] diff --git a/CHANGELOG.md b/CHANGELOG.md index 37ae55e3..ffaa6629 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Support vector type [CASSGO-11](https://issues.apache.org/jira/browse/CASSGO-11) + - Allow SERIAL and LOCAL_SERIAL on SELECT statements [CASSGO-26](https://issues.apache.org/jira/browse/CASSGO-26) - Support of sending queries to the specific node with Query.SetHostID() (CASSGO-4) diff --git a/cassandra_test.go b/cassandra_test.go index 54a54f42..46909902 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -481,15 +481,15 @@ func TestCAS(t *testing.T) { } insertBatch := session.Batch(LoggedBatch) - insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") - insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") + insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, TOTIMESTAMP(NOW()))") + insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, TOTIMESTAMP(NOW()))") if err := session.ExecuteBatch(insertBatch); err != nil { t.Fatal("insert:", err) } failBatch = session.Batch(LoggedBatch) - failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());") - failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());") + failBatch.Query("UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=TOTIMESTAMP(NOW());") + failBatch.Query("UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=TOTIMESTAMP(NOW());") if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) } else if applied { @@ -533,21 +533,21 @@ func TestCAS(t *testing.T) { } failBatch = session.Batch(LoggedBatch) - failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified) + failBatch.Query("UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified) if _, _, err := session.ExecuteBatchCAS(failBatch, new(bool)); err == nil { t.Fatal("update should have errored") } // make sure MapScanCAS does not panic when MapScan fails casMap = make(map[string]interface{}) casMap["last_modified"] = false - if _, err := session.Query(`UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?`, + if _, err := session.Query(`UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?`, modified).MapScanCAS(casMap); err == nil { t.Fatal("update should hvae errored", err) } // make sure MapExecuteBatchCAS does not panic when MapScan fails failBatch = session.Batch(LoggedBatch) - failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified) + failBatch.Query("UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified) casMap = make(map[string]interface{}) casMap["last_modified"] = false if _, _, err := session.MapExecuteBatchCAS(failBatch, casMap); err == nil { @@ -2516,18 +2516,19 @@ func TestAggregateMetadata(t *testing.T) { t.Fatal("expected two aggregates") } + protoVer := byte(session.cfg.ProtoVersion) expectedAggregrate := AggregateMetadata{ Keyspace: "gocql_test", Name: "average", - ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt}}, + ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt, proto: protoVer}}, InitCond: "(0, 0)", - ReturnType: NativeType{typ: TypeDouble}, + ReturnType: NativeType{typ: TypeDouble, proto: protoVer}, StateType: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, stateFunc: "avgstate", @@ -2566,28 +2567,29 @@ func TestFunctionMetadata(t *testing.T) { avgState := functions[1] avgFinal := functions[0] + protoVer := byte(session.cfg.ProtoVersion) avgStateBody := "if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;" expectedAvgState := FunctionMetadata{ Keyspace: "gocql_test", Name: "avgstate", ArgumentTypes: []TypeInfo{ TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, - NativeType{typ: TypeInt}, + NativeType{typ: TypeInt, proto: protoVer}, }, ArgumentNames: []string{"state", "val"}, ReturnType: TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, CalledOnNullInput: true, @@ -2604,16 +2606,16 @@ func TestFunctionMetadata(t *testing.T) { Name: "avgfinal", ArgumentTypes: []TypeInfo{ TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NativeType{typ: TypeTuple, proto: protoVer}, Elems: []TypeInfo{ - NativeType{typ: TypeInt}, - NativeType{typ: TypeBigInt}, + NativeType{typ: TypeInt, proto: protoVer}, + NativeType{typ: TypeBigInt, proto: protoVer}, }, }, }, ArgumentNames: []string{"state"}, - ReturnType: NativeType{typ: TypeDouble}, + ReturnType: NativeType{typ: TypeDouble, proto: protoVer}, CalledOnNullInput: true, Language: "java", Body: finalStateBody, @@ -2717,15 +2719,16 @@ func TestKeyspaceMetadata(t *testing.T) { if flagCassVersion.Before(3, 0, 0) { textType = TypeVarchar } + protoVer := byte(session.cfg.ProtoVersion) expectedType := UserTypeMetadata{ Keyspace: "gocql_test", Name: "basicview", FieldNames: []string{"birthday", "nationality", "weight", "height"}, FieldTypes: []TypeInfo{ - NativeType{typ: TypeTimestamp}, - NativeType{typ: textType}, - NativeType{typ: textType}, - NativeType{typ: textType}, + NativeType{typ: TypeTimestamp, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, + NativeType{typ: textType, proto: protoVer}, }, } if !reflect.DeepEqual(*keyspaceMetadata.UserTypes["basicview"], expectedType) { diff --git a/common_test.go b/common_test.go index e420bbf1..5e4f0cc0 100644 --- a/common_test.go +++ b/common_test.go @@ -28,6 +28,7 @@ import ( "flag" "fmt" "log" + "math/rand" "net" "reflect" "strings" @@ -54,6 +55,10 @@ var ( flagCassVersion cassVersion ) +var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) + +const randCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + func init() { flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") @@ -281,6 +286,14 @@ func assertTrue(t *testing.T, description string, value bool) { } } +func randomText(size int) string { + result := make([]byte, size) + for i := range result { + result[i] = randCharset[rand.Intn(len(randCharset))] + } + return string(result) +} + func assertEqual(t *testing.T, description string, expected, actual interface{}) { t.Helper() if expected != actual { diff --git a/frame.go b/frame.go index 99b07e28..926c54b8 100644 --- a/frame.go +++ b/frame.go @@ -34,6 +34,7 @@ import ( "io/ioutil" "net" "runtime" + "strconv" "strings" "time" ) @@ -905,6 +906,22 @@ func (f *framer) readTypeInfo() TypeInfo { collection.Elem = f.readTypeInfo() return collection + case TypeCustom: + if strings.HasPrefix(simple.custom, VECTOR_TYPE) { + spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) + spec = spec[1 : len(spec)-1] // remove parenthesis + idx := strings.LastIndex(spec, ",") + typeStr := spec[:idx] + dimStr := spec[idx+1:] + subType := getCassandraLongType(strings.TrimSpace(typeStr), f.proto, nopLogger{}) + dim, _ := strconv.Atoi(strings.TrimSpace(dimStr)) + vector := VectorType{ + NativeType: simple, + SubType: subType, + Dimensions: dim, + } + return vector + } } return simple diff --git a/helpers.go b/helpers.go index 823c1068..79842b7f 100644 --- a/helpers.go +++ b/helpers.go @@ -25,10 +25,12 @@ package gocql import ( + "encoding/hex" "fmt" "math/big" "net" "reflect" + "strconv" "strings" "time" @@ -162,59 +164,165 @@ func getCassandraBaseType(name string) Type { } } -func getCassandraType(name string, logger StdLogger) TypeInfo { +// TODO: Cover with unit tests. +// Parses long Java-style type definition to internal data structures. +func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo { + if strings.HasPrefix(name, SET_TYPE) { + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeSet), + Elem: getCassandraLongType(unwrapCompositeTypeDefinition(name, SET_TYPE, '('), protoVer, logger), + } + } else if strings.HasPrefix(name, LIST_TYPE) { + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeList), + Elem: getCassandraLongType(unwrapCompositeTypeDefinition(name, LIST_TYPE, '('), protoVer, logger), + } + } else if strings.HasPrefix(name, MAP_TYPE) { + names := splitJavaCompositeTypes(name, MAP_TYPE) + if len(names) != 2 { + logger.Printf("gocql: error parsing map type, it has %d subelements, expecting 2\n", len(names)) + return NewNativeType(protoVer, TypeCustom) + } + return CollectionType{ + NativeType: NewNativeType(protoVer, TypeMap), + Key: getCassandraLongType(names[0], protoVer, logger), + Elem: getCassandraLongType(names[1], protoVer, logger), + } + } else if strings.HasPrefix(name, TUPLE_TYPE) { + names := splitJavaCompositeTypes(name, TUPLE_TYPE) + types := make([]TypeInfo, len(names)) + + for i, name := range names { + types[i] = getCassandraLongType(name, protoVer, logger) + } + + return TupleTypeInfo{ + NativeType: NewNativeType(protoVer, TypeTuple), + Elems: types, + } + } else if strings.HasPrefix(name, UDT_TYPE) { + names := splitJavaCompositeTypes(name, UDT_TYPE) + fields := make([]UDTField, len(names)-2) + + for i := 2; i < len(names); i++ { + spec := strings.Split(names[i], ":") + fieldName, _ := hex.DecodeString(spec[0]) + fields[i-2] = UDTField{ + Name: string(fieldName), + Type: getCassandraLongType(spec[1], protoVer, logger), + } + } + + udtName, _ := hex.DecodeString(names[1]) + return UDTTypeInfo{ + NativeType: NewNativeType(protoVer, TypeUDT), + KeySpace: names[0], + Name: string(udtName), + Elements: fields, + } + } else if strings.HasPrefix(name, VECTOR_TYPE) { + names := splitJavaCompositeTypes(name, VECTOR_TYPE) + subType := getCassandraLongType(strings.TrimSpace(names[0]), protoVer, logger) + dim, err := strconv.Atoi(strings.TrimSpace(names[1])) + if err != nil { + logger.Printf("gocql: error parsing vector dimensions: %v\n", err) + return NewNativeType(protoVer, TypeCustom) + } + + return VectorType{ + NativeType: NewCustomType(protoVer, TypeCustom, VECTOR_TYPE), + SubType: subType, + Dimensions: dim, + } + } else { + // basic type + return NativeType{ + proto: protoVer, + typ: getApacheCassandraType(name), + } + } +} + +// Parses short CQL type representation (e.g. map<text, text>) to internal data structures. +func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(name, "frozen<") { - return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger) + return getCassandraType(unwrapCompositeTypeDefinition(name, "frozen", '<'), protoVer, logger) } else if strings.HasPrefix(name, "set<") { return CollectionType{ - NativeType: NativeType{typ: TypeSet}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger), + NativeType: NewNativeType(protoVer, TypeSet), + Elem: getCassandraType(unwrapCompositeTypeDefinition(name, "set", '<'), protoVer, logger), } } else if strings.HasPrefix(name, "list<") { return CollectionType{ - NativeType: NativeType{typ: TypeList}, - Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), + NativeType: NewNativeType(protoVer, TypeList), + Elem: getCassandraType(unwrapCompositeTypeDefinition(name, "list", '<'), protoVer, logger), } } else if strings.HasPrefix(name, "map<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) + names := splitCQLCompositeTypes(name, "map") if len(names) != 2 { logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) - return NativeType{ - typ: TypeCustom, - } + return NewNativeType(protoVer, TypeCustom) } return CollectionType{ - NativeType: NativeType{typ: TypeMap}, - Key: getCassandraType(names[0], logger), - Elem: getCassandraType(names[1], logger), + NativeType: NewNativeType(protoVer, TypeMap), + Key: getCassandraType(names[0], protoVer, logger), + Elem: getCassandraType(names[1], protoVer, logger), } } else if strings.HasPrefix(name, "tuple<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) + names := splitCQLCompositeTypes(name, "tuple") types := make([]TypeInfo, len(names)) for i, name := range names { - types[i] = getCassandraType(name, logger) + types[i] = getCassandraType(name, protoVer, logger) } return TupleTypeInfo{ - NativeType: NativeType{typ: TypeTuple}, + NativeType: NewNativeType(protoVer, TypeTuple), Elems: types, } + } else if strings.HasPrefix(name, "vector<") { + names := splitCQLCompositeTypes(name, "vector") + subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger) + dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) + + return VectorType{ + NativeType: NewCustomType(protoVer, TypeCustom, VECTOR_TYPE), + SubType: subType, + Dimensions: dim, + } } else { return NativeType{ - typ: getCassandraBaseType(name), + proto: protoVer, + typ: getCassandraBaseType(name), } } } -func splitCompositeTypes(name string) []string { - if !strings.Contains(name, "<") { - return strings.Split(name, ", ") +func splitCQLCompositeTypes(name string, typeName string) []string { + return splitCompositeTypes(name, typeName, '<', '>') +} + +func splitJavaCompositeTypes(name string, typeName string) []string { + return splitCompositeTypes(name, typeName, '(', ')') +} + +func unwrapCompositeTypeDefinition(name string, typeName string, typeOpen int32) string { + return strings.TrimPrefix(name[:len(name)-1], typeName+string(typeOpen)) +} + +func splitCompositeTypes(name string, typeName string, typeOpen int32, typeClose int32) []string { + def := unwrapCompositeTypeDefinition(name, typeName, typeOpen) + if !strings.Contains(def, string(typeOpen)) { + parts := strings.Split(def, ",") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + return parts } var parts []string lessCount := 0 segment := "" - for _, char := range name { + for _, char := range def { if char == ',' && lessCount == 0 { if segment != "" { parts = append(parts, strings.TrimSpace(segment)) @@ -223,9 +331,9 @@ func splitCompositeTypes(name string) []string { continue } segment += string(char) - if char == '<' { + if char == typeOpen { lessCount++ - } else if char == '>' { + } else if char == typeClose { lessCount-- } } @@ -235,20 +343,6 @@ func splitCompositeTypes(name string) []string { return parts } -func apacheToCassandraType(t string) string { - t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) - t = strings.Replace(t, "(", "<", -1) - t = strings.Replace(t, ")", ">", -1) - types := strings.FieldsFunc(t, func(r rune) bool { - return r == '<' || r == '>' || r == ',' - }) - for _, typ := range types { - t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1) - } - // This is done so it exactly matches what Cassandra returns - return strings.Replace(t, ",", ", ", -1) -} - func getApacheCassandraType(class string) Type { switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { case "AsciiType": @@ -297,6 +391,10 @@ func getApacheCassandraType(class string) Type { return TypeTuple case "DurationType": return TypeDuration + case "SimpleDateType": + return TypeDate + case "UserType": + return TypeUDT default: return TypeCustom } diff --git a/helpers_test.go b/helpers_test.go index 67922ba5..275752aa 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -30,7 +30,7 @@ import ( ) func TestGetCassandraType_Set(t *testing.T) { - typ := getCassandraType("set<text>", &defaultLogger{}) + typ := getCassandraType("set<text>", protoVersion4, &defaultLogger{}) set, ok := typ.(CollectionType) if !ok { t.Fatalf("expected CollectionType got %T", typ) @@ -223,11 +223,68 @@ func TestGetCassandraType(t *testing.T) { Elem: NativeType{typ: TypeDuration}, }, }, + { + "vector<float, 3>", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: NativeType{typ: TypeFloat}, + Dimensions: 3, + }, + }, + { + "vector<vector<float, 3>, 5>", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: NativeType{typ: TypeFloat}, + Dimensions: 3, + }, + Dimensions: 5, + }, + }, + { + "vector<map<uuid,timestamp>, 5>", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: CollectionType{ + NativeType: NativeType{typ: TypeMap}, + Key: NativeType{typ: TypeUUID}, + Elem: NativeType{typ: TypeTimestamp}, + }, + Dimensions: 5, + }, + }, + { + "vector<frozen<tuple<int, float>>, 100>", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: TupleTypeInfo{ + NativeType: NativeType{typ: TypeTuple}, + Elems: []TypeInfo{ + NativeType{typ: TypeInt}, + NativeType{typ: TypeFloat}, + }, + }, + Dimensions: 100, + }, + }, } for _, test := range tests { t.Run(test.input, func(t *testing.T) { - got := getCassandraType(test.input, &defaultLogger{}) + got := getCassandraType(test.input, 0, &defaultLogger{}) // TODO(zariel): define an equal method on the types? if !reflect.DeepEqual(got, test.exp) { diff --git a/marshal.go b/marshal.go index 719a6228..368c7f12 100644 --- a/marshal.go +++ b/marshal.go @@ -172,6 +172,10 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { return marshalDate(info, value) case TypeDuration: return marshalDuration(info, value) + case TypeCustom: + if vector, ok := info.(VectorType); ok { + return marshalVector(vector, value) + } } // detect protocol 2 UDT @@ -276,6 +280,10 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { return unmarshalDate(info, data, value) case TypeDuration: return unmarshalDuration(info, data, value) + case TypeCustom: + if vector, ok := info.(VectorType); ok { + return unmarshalVector(vector, data, value) + } } // detect protocol 2 UDT @@ -1716,6 +1724,165 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: *slice, *array.", info, value) } +func marshalVector(info VectorType, value interface{}) ([]byte, error) { + if value == nil { + return nil, nil + } else if _, ok := value.(unsetColumn); ok { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + if k == reflect.Slice && rv.IsNil() { + return nil, nil + } + + switch k { + case reflect.Slice, reflect.Array: + buf := &bytes.Buffer{} + n := rv.Len() + if n != info.Dimensions { + return nil, marshalErrorf("expected vector with %d dimensions, received %d", info.Dimensions, n) + } + + for i := 0; i < n; i++ { + item, err := Marshal(info.SubType, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + if isVectorVariableLengthType(info.SubType) { + writeUnsignedVInt(buf, uint64(len(item))) + } + buf.Write(item) + } + return buf.Bytes(), nil + } + return nil, marshalErrorf("can not marshal %T into %s. Accepted types: slice, array.", value, info) +} + +func unmarshalVector(info VectorType, data []byte, value interface{}) error { + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + switch k { + case reflect.Slice, reflect.Array: + if data == nil { + if k == reflect.Array { + return unmarshalErrorf("unmarshal vector: can not store nil in array value") + } + if rv.IsNil() { + return nil + } + rv.Set(reflect.Zero(t)) + return nil + } + if k == reflect.Array { + if rv.Len() != info.Dimensions { + return unmarshalErrorf("unmarshal vector: array of size %d cannot store vector of %d dimensions", rv.Len(), info.Dimensions) + } + } else { + rv.Set(reflect.MakeSlice(t, info.Dimensions, info.Dimensions)) + } + elemSize := len(data) / info.Dimensions + for i := 0; i < info.Dimensions; i++ { + offset := 0 + if isVectorVariableLengthType(info.SubType) { + m, p, err := readUnsignedVInt(data) + if err != nil { + return err + } + elemSize = int(m) + offset = p + } + if offset > 0 { + data = data[offset:] + } + var unmarshalData []byte + if elemSize >= 0 { + if len(data) < elemSize { + return unmarshalErrorf("unmarshal vector: unexpected eof") + } + unmarshalData = data[:elemSize] + data = data[elemSize:] + } + err := Unmarshal(info.SubType, unmarshalData, rv.Index(i).Addr().Interface()) + if err != nil { + return unmarshalErrorf("failed to unmarshal %s into %T: %s", info.SubType, unmarshalData, err.Error()) + } + } + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: slice, array.", info, value) +} + +// isVectorVariableLengthType determines if a type requires explicit length serialization within a vector. +// Variable-length types need their length encoded before the actual data to allow proper deserialization. +// Fixed-length types, on the other hand, don't require this kind of length prefix. +func isVectorVariableLengthType(elemType TypeInfo) bool { + switch elemType.Type() { + case TypeVarchar, TypeAscii, TypeBlob, TypeText, + TypeCounter, + TypeDuration, TypeDate, TypeTime, + TypeDecimal, TypeSmallInt, TypeTinyInt, TypeVarint, + TypeInet, + TypeList, TypeSet, TypeMap, TypeUDT, TypeTuple: + return true + case TypeCustom: + if vecType, ok := elemType.(VectorType); ok { + return isVectorVariableLengthType(vecType.SubType) + } + return true + } + return false +} + +func writeUnsignedVInt(buf *bytes.Buffer, v uint64) { + numBytes := computeUnsignedVIntSize(v) + if numBytes <= 1 { + buf.WriteByte(byte(v)) + return + } + + extraBytes := numBytes - 1 + var tmp = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + tmp[i] = byte(v) + v >>= 8 + } + tmp[0] |= byte(^(0xff >> uint(extraBytes))) + buf.Write(tmp) +} + +func readUnsignedVInt(data []byte) (uint64, int, error) { + if len(data) <= 0 { + return 0, 0, errors.New("unexpected eof") + } + firstByte := data[0] + if firstByte&0x80 == 0 { + return uint64(firstByte), 1, nil + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + if len(data) < numBytes+1 { + return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", numBytes+1, len(data)) + } + for i := 0; i < numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return ret, numBytes + 1, nil +} + +func computeUnsignedVIntSize(v uint64) int { + lead0 := bits.LeadingZeros64(v) + return (639 - lead0*9) >> 6 +} + func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { mapInfo, ok := info.(CollectionType) if !ok { @@ -2475,7 +2642,11 @@ type NativeType struct { custom string // only used for TypeCustom } -func NewNativeType(proto byte, typ Type, custom string) NativeType { +func NewNativeType(proto byte, typ Type) NativeType { + return NativeType{proto, typ, ""} +} + +func NewCustomType(proto byte, typ Type, custom string) NativeType { return NativeType{proto, typ, custom} } @@ -2514,6 +2685,12 @@ type CollectionType struct { Elem TypeInfo // only used for TypeMap, TypeList and TypeSet } +type VectorType struct { + NativeType + SubType TypeInfo + Dimensions int +} + func (t CollectionType) NewWithError() (interface{}, error) { typ, err := goType(t) if err != nil { diff --git a/marshal_test.go b/marshal_test.go index d1101550..969d8e5f 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -30,6 +30,7 @@ package gocql import ( "bytes" "encoding/binary" + "fmt" "math" "math/big" "net" @@ -2509,6 +2510,37 @@ func TestReadCollectionSize(t *testing.T) { } } +func TestReadUnsignedVInt(t *testing.T) { + tests := []struct { + decodedInt uint64 + encodedVint []byte + }{ + { + decodedInt: 0, + encodedVint: []byte{0}, + }, + { + decodedInt: 100, + encodedVint: []byte{100}, + }, + { + decodedInt: 256000, + encodedVint: []byte{195, 232, 0}, + }, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%d", test.decodedInt), func(t *testing.T) { + actual, _, err := readUnsignedVInt(test.encodedVint) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if actual != test.decodedInt { + t.Fatalf("Expected %d, but got %d", test.decodedInt, actual) + } + }) + } +} + func BenchmarkUnmarshalUUID(b *testing.B) { b.ReportAllocs() src := make([]byte, 16) diff --git a/metadata.go b/metadata.go index 63e27aeb..c7f8e4b9 100644 --- a/metadata.go +++ b/metadata.go @@ -361,13 +361,13 @@ func compileMetadata( col := &columns[i] // decode the validator for TypeInfo and order if col.ClusteringOrder != "" { // Cassandra 3.x+ - col.Type = getCassandraType(col.Validator, logger) + col.Type = getCassandraType(col.Validator, byte(protoVersion), logger) col.Order = ASC if col.ClusteringOrder == "desc" { col.Order = DESC } } else { - validatorParsed := parseType(col.Validator, logger) + validatorParsed := parseType(col.Validator, byte(protoVersion), logger) col.Type = validatorParsed.types[0] col.Order = ASC if validatorParsed.reversed[0] { @@ -389,9 +389,9 @@ func compileMetadata( } if protoVersion == protoVersion1 { - compileV1Metadata(tables, logger) + compileV1Metadata(tables, protoVersion, logger) } else { - compileV2Metadata(tables, logger) + compileV2Metadata(tables, protoVersion, logger) } } @@ -400,14 +400,14 @@ func compileMetadata( // column metadata as V2+ (because V1 doesn't support the "type" column in the // system.schema_columns table) so determining PartitionKey and ClusterColumns // is more complex. -func compileV1Metadata(tables []TableMetadata, logger StdLogger) { +func compileV1Metadata(tables []TableMetadata, protoVer int, logger StdLogger) { for i := range tables { table := &tables[i] // decode the key validator - keyValidatorParsed := parseType(table.KeyValidator, logger) + keyValidatorParsed := parseType(table.KeyValidator, byte(protoVer), logger) // decode the comparator - comparatorParsed := parseType(table.Comparator, logger) + comparatorParsed := parseType(table.Comparator, byte(protoVer), logger) // the partition key length is the same as the number of types in the // key validator @@ -493,7 +493,7 @@ func compileV1Metadata(tables []TableMetadata, logger StdLogger) { alias = table.ValueAlias } // decode the default validator - defaultValidatorParsed := parseType(table.DefaultValidator, logger) + defaultValidatorParsed := parseType(table.DefaultValidator, byte(protoVer), logger) column := &ColumnMetadata{ Keyspace: table.Keyspace, Table: table.Name, @@ -507,7 +507,7 @@ func compileV1Metadata(tables []TableMetadata, logger StdLogger) { } // The simpler compile case for V2+ protocol -func compileV2Metadata(tables []TableMetadata, logger StdLogger) { +func compileV2Metadata(tables []TableMetadata, protoVer int, logger StdLogger) { for i := range tables { table := &tables[i] @@ -515,7 +515,7 @@ func compileV2Metadata(tables []TableMetadata, logger StdLogger) { table.ClusteringColumns = make([]*ColumnMetadata, clusteringColumnCount) if table.KeyValidator != "" { - keyValidatorParsed := parseType(table.KeyValidator, logger) + keyValidatorParsed := parseType(table.KeyValidator, byte(protoVer), logger) table.PartitionKey = make([]*ColumnMetadata, len(keyValidatorParsed.types)) } else { // Cassandra 3.x+ partitionKeyCount := componentColumnCountOfType(table.Columns, ColumnPartitionKey) @@ -925,11 +925,11 @@ func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata, return columns, nil } -func getTypeInfo(t string, logger StdLogger) TypeInfo { +func getTypeInfo(t string, protoVer byte, logger StdLogger) TypeInfo { if strings.HasPrefix(t, apacheCassandraTypePrefix) { - t = apacheToCassandraType(t) + return getCassandraLongType(t, protoVer, logger) } - return getCassandraType(t, logger) + return getCassandraType(t, protoVer, logger) } func getUserTypeMetadata(session *Session, keyspaceName string) ([]UserTypeMetadata, error) { @@ -965,7 +965,7 @@ func getUserTypeMetadata(session *Session, keyspaceName string) ([]UserTypeMetad } uType.FieldTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - uType.FieldTypes[i] = getTypeInfo(argumentType, session.logger) + uType.FieldTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) } uTypes = append(uTypes, uType) } @@ -1218,10 +1218,10 @@ func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMeta if err != nil { return nil, err } - function.ReturnType = getTypeInfo(returnType, session.logger) + function.ReturnType = getTypeInfo(returnType, byte(session.cfg.ProtoVersion), session.logger) function.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - function.ArgumentTypes[i] = getTypeInfo(argumentType, session.logger) + function.ArgumentTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) } functions = append(functions, function) } @@ -1275,11 +1275,11 @@ func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMe if err != nil { return nil, err } - aggregate.ReturnType = getTypeInfo(returnType, session.logger) - aggregate.StateType = getTypeInfo(stateType, session.logger) + aggregate.ReturnType = getTypeInfo(returnType, byte(session.cfg.ProtoVersion), session.logger) + aggregate.StateType = getTypeInfo(stateType, byte(session.cfg.ProtoVersion), session.logger) aggregate.ArgumentTypes = make([]TypeInfo, len(argumentTypes)) for i, argumentType := range argumentTypes { - aggregate.ArgumentTypes[i] = getTypeInfo(argumentType, session.logger) + aggregate.ArgumentTypes[i] = getTypeInfo(argumentType, byte(session.cfg.ProtoVersion), session.logger) } aggregates = append(aggregates, aggregate) } @@ -1296,6 +1296,7 @@ type typeParser struct { input string index int logger StdLogger + proto byte } // the type definition parser result @@ -1307,8 +1308,8 @@ type typeParserResult struct { } // Parse the type definition used for validator and comparator schema data -func parseType(def string, logger StdLogger) typeParserResult { - parser := &typeParser{input: def, logger: logger} +func parseType(def string, protoVer byte, logger StdLogger) typeParserResult { + parser := &typeParser{input: def, proto: protoVer, logger: logger} return parser.parse() } @@ -1319,6 +1320,9 @@ const ( LIST_TYPE = "org.apache.cassandra.db.marshal.ListType" SET_TYPE = "org.apache.cassandra.db.marshal.SetType" MAP_TYPE = "org.apache.cassandra.db.marshal.MapType" + UDT_TYPE = "org.apache.cassandra.db.marshal.UserType" + TUPLE_TYPE = "org.apache.cassandra.db.marshal.TupleType" + VECTOR_TYPE = "org.apache.cassandra.db.marshal.VectorType" ) // represents a class specification in the type def AST @@ -1327,6 +1331,7 @@ type typeParserClassNode struct { params []typeParserParamNode // this is the segment of the input string that defined this node input string + proto byte } // represents a class parameter in the type def AST @@ -1346,6 +1351,7 @@ func (t *typeParser) parse() typeParserResult { NativeType{ typ: TypeCustom, custom: t.input, + proto: t.proto, }, }, reversed: []bool{false}, @@ -1423,7 +1429,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { elem := class.params[0].class.asTypeInfo() return CollectionType{ NativeType: NativeType{ - typ: TypeList, + typ: TypeList, + proto: class.proto, }, Elem: elem, } @@ -1432,7 +1439,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { elem := class.params[0].class.asTypeInfo() return CollectionType{ NativeType: NativeType{ - typ: TypeSet, + typ: TypeSet, + proto: class.proto, }, Elem: elem, } @@ -1442,7 +1450,8 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { elem := class.params[1].class.asTypeInfo() return CollectionType{ NativeType: NativeType{ - typ: TypeMap, + typ: TypeMap, + proto: class.proto, }, Key: key, Elem: elem, @@ -1450,7 +1459,7 @@ func (class *typeParserClassNode) asTypeInfo() TypeInfo { } // must be a simple type or custom type - info := NativeType{typ: getApacheCassandraType(class.name)} + info := NativeType{typ: getApacheCassandraType(class.name), proto: class.proto} if info.typ == TypeCustom { // add the entire class definition info.custom = class.input @@ -1480,6 +1489,7 @@ func (t *typeParser) parseClassNode() (node *typeParserClassNode, ok bool) { name: name, params: params, input: t.input[startIndex:endIndex], + proto: t.proto, } return node, true } diff --git a/metadata_test.go b/metadata_test.go index 6e3633cc..6b3d1198 100644 --- a/metadata_test.go +++ b/metadata_test.go @@ -636,12 +636,14 @@ func TestTypeParser(t *testing.T) { }, ) - // custom + // udt assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)", - assertTypeInfo{Type: TypeCustom, Custom: "org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)"}, + assertTypeInfo{Type: TypeUDT, Custom: ""}, ) + + // custom assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.DynamicCompositeType(u=>org.apache.cassandra.db.marshal.UUIDType,d=>org.apache.cassandra.db.marshal.DateType,t=>org.apache.cassandra.db.marshal.TimeUUIDType,b=>org.apache.cassandra.db.marshal.BytesType,s=>org.apache.cassandra.db.marshal.UTF8Type,B=>org.apache.cassandra.db.marshal.BooleanType,a=>org.apache.cassandra.db.marshal.AsciiType,l=>org.apache.cassandra.db.marshal.LongType,i=>org.apache.cassandra.db.marshal.IntegerType,x=>org.apache.cassandra.db.m [...] @@ -700,7 +702,7 @@ func assertParseNonCompositeType( ) { log := &defaultLogger{} - result := parseType(def, log) + result := parseType(def, protoVersion4, log) if len(result.reversed) != 1 { t.Errorf("%s expected %d reversed values but there were %d", def, 1, len(result.reversed)) } @@ -731,7 +733,7 @@ func assertParseCompositeType( ) { log := &defaultLogger{} - result := parseType(def, log) + result := parseType(def, protoVersion4, log) if len(result.reversed) != len(typesExpected) { t.Errorf("%s expected %d reversed values but there were %d", def, len(typesExpected), len(result.reversed)) } @@ -747,7 +749,7 @@ func assertParseCompositeType( if !result.isComposite { t.Errorf("%s: Expected composite", def) } - if result.collections == nil { + if result.collections == nil && collectionsExpected != nil { t.Errorf("%s: Expected non-nil collections: %v", def, result.collections) } diff --git a/vector_test.go b/vector_test.go new file mode 100644 index 00000000..4e52a885 --- /dev/null +++ b/vector_test.go @@ -0,0 +1,401 @@ +//go:build all || cassandra +// +build all cassandra + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 + * Copyright (c) 2016, The Gocql authors, + * provided under the BSD-3-Clause License. + * See the NOTICE file distributed with this work for additional information. + */ + +package gocql + +import ( + "fmt" + "github.com/stretchr/testify/require" + "gopkg.in/inf.v0" + "net" + "reflect" + "testing" + "time" +) + +type person struct { + FirstName string `cql:"first_name"` + LastName string `cql:"last_name"` + Age int `cql:"age"` +} + +func (p person) String() string { + return fmt.Sprintf("Person{firstName: %s, lastName: %s, Age: %d}", p.FirstName, p.LastName, p.Age) +} + +func TestVector_Marshaler(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed(id int primary key, vec vector<float, 3>);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_variable(id int primary key, vec vector<text, 4>);`) + if err != nil { + t.Fatal(err) + } + + insertFixVec := []float32{8, 2.5, -5.0} + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, insertFixVec).Exec() + if err != nil { + t.Fatal(err) + } + var selectFixVec []float32 + err = session.Query("SELECT vec FROM vector_fixed WHERE id = ?", 1).Scan(&selectFixVec) + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "fixed size element vector", insertFixVec, selectFixVec) + + longText := randomText(500) + insertVarVec := []string{"apache", "cassandra", longText, "gocql"} + err = session.Query("INSERT INTO vector_variable(id, vec) VALUES(?, ?)", 1, insertVarVec).Exec() + if err != nil { + t.Fatal(err) + } + var selectVarVec []string + err = session.Query("SELECT vec FROM vector_variable WHERE id = ?", 1).Scan(&selectVarVec) + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "variable size element vector", insertVarVec, selectVarVec) +} + +func TestVector_Types(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + timestamp1, _ := time.Parse("2006-01-02", "2000-01-01") + timestamp2, _ := time.Parse("2006-01-02 15:04:05", "2024-01-01 10:31:45") + timestamp3, _ := time.Parse("2006-01-02 15:04:05.000", "2024-05-01 10:31:45.987") + + date1, _ := time.Parse("2006-01-02", "2000-01-01") + date2, _ := time.Parse("2006-01-02", "2022-03-14") + date3, _ := time.Parse("2006-01-02", "2024-12-31") + + time1, _ := time.Parse("15:04:05", "01:00:00") + time2, _ := time.Parse("15:04:05", "15:23:59") + time3, _ := time.Parse("15:04:05.000", "10:31:45.987") + + duration1 := Duration{0, 1, 1920000000000} + duration2 := Duration{1, 1, 1920000000000} + duration3 := Duration{31, 0, 60000000000} + + map1 := make(map[string]int) + map1["a"] = 1 + map1["b"] = 2 + map1["c"] = 3 + map2 := make(map[string]int) + map2["abc"] = 123 + map3 := make(map[string]int) + + tests := []struct { + name string + cqlType string + value interface{} + comparator func(interface{}, interface{}) + }{ + {name: "ascii", cqlType: TypeAscii.String(), value: []string{"a", "1", "Z"}}, + {name: "bigint", cqlType: TypeBigInt.String(), value: []int64{1, 2, 3}}, + {name: "blob", cqlType: TypeBlob.String(), value: [][]byte{[]byte{1, 2, 3}, []byte{4, 5, 6, 7}, []byte{8, 9}}}, + {name: "boolean", cqlType: TypeBoolean.String(), value: []bool{true, false, true}}, + {name: "counter", cqlType: TypeCounter.String(), value: []int64{5, 6, 7}}, + {name: "decimal", cqlType: TypeDecimal.String(), value: []inf.Dec{*inf.NewDec(1, 0), *inf.NewDec(2, 1), *inf.NewDec(-3, 2)}}, + {name: "double", cqlType: TypeDouble.String(), value: []float64{0.1, -1.2, 3}}, + {name: "float", cqlType: TypeFloat.String(), value: []float32{0.1, -1.2, 3}}, + {name: "int", cqlType: TypeInt.String(), value: []int32{1, 2, 3}}, + {name: "text", cqlType: TypeText.String(), value: []string{"a", "b", "c"}}, + {name: "timestamp", cqlType: TypeTimestamp.String(), value: []time.Time{timestamp1, timestamp2, timestamp3}}, + {name: "uuid", cqlType: TypeUUID.String(), value: []UUID{MustRandomUUID(), MustRandomUUID(), MustRandomUUID()}}, + {name: "varchar", cqlType: TypeVarchar.String(), value: []string{"abc", "def", "ghi"}}, + {name: "varint", cqlType: TypeVarint.String(), value: []uint64{uint64(1234), uint64(123498765), uint64(18446744073709551615)}}, + {name: "timeuuid", cqlType: TypeTimeUUID.String(), value: []UUID{TimeUUID(), TimeUUID(), TimeUUID()}}, + { + name: "inet", + cqlType: TypeInet.String(), + value: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv4(192, 168, 1, 1), net.IPv4(8, 8, 8, 8)}, + comparator: func(e interface{}, a interface{}) { + expected := e.([]net.IP) + actual := a.([]net.IP) + assertEqual(t, "vector size", len(expected), len(actual)) + for i, _ := range expected { + assertTrue(t, "vector", expected[i].Equal(actual[i])) + } + }, + }, + {name: "date", cqlType: TypeDate.String(), value: []time.Time{date1, date2, date3}}, + {name: "time", cqlType: TypeTimestamp.String(), value: []time.Time{time1, time2, time3}}, + {name: "smallint", cqlType: TypeSmallInt.String(), value: []int16{127, 256, -1234}}, + {name: "tinyint", cqlType: TypeTinyInt.String(), value: []int8{127, 9, -123}}, + {name: "duration", cqlType: TypeDuration.String(), value: []Duration{duration1, duration2, duration3}}, + {name: "vector_vector_float", cqlType: "vector<float, 5>", value: [][]float32{{0.1, -1.2, 3, 5, 5}, {10.1, -122222.0002, 35.0, 1, 1}, {0, 0, 0, 0, 0}}}, + {name: "vector_vector_set_float", cqlType: "vector<set<float>, 5>", value: [][][]float32{ + {{1, 2}, {2, -1}, {3}, {0}, {-1.3}}, + {{2, 3}, {2, -1}, {3}, {0}, {-1.3}}, + {{1, 1000.0}, {0}, {}, {12, 14, 15, 16}, {-1.3}}, + }}, + {name: "vector_tuple_text_int_float", cqlType: "tuple<text, int, float>", value: [][]interface{}{{"a", 1, float32(0.5)}, {"b", 2, float32(-1.2)}, {"c", 3, float32(0)}}}, + {name: "vector_tuple_text_list_text", cqlType: "tuple<text, list<text>>", value: [][]interface{}{{"a", []string{"b", "c"}}, {"d", []string{"e", "f", "g"}}, {"h", []string{"i"}}}}, + {name: "vector_set_text", cqlType: "set<text>", value: [][]string{{"a", "b"}, {"c", "d"}, {"e", "f"}}}, + {name: "vector_list_int", cqlType: "list<int>", value: [][]int32{{1, 2, 3}, {-1, -2, -3}, {0, 0, 0}}}, + {name: "vector_map_text_int", cqlType: "map<text, int>", value: []map[string]int{map1, map2, map3}}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tableName := fmt.Sprintf("vector_%s", test.name) + err := createTable(session, fmt.Sprintf(`CREATE TABLE IF NOT EXISTS gocql_test.%s(id int primary key, vec vector<%s, 3>);`, tableName, test.cqlType)) + if err != nil { + t.Fatal(err) + } + + err = session.Query(fmt.Sprintf("INSERT INTO %s(id, vec) VALUES(?, ?)", tableName), 1, test.value).Exec() + if err != nil { + t.Fatal(err) + } + + v := reflect.New(reflect.TypeOf(test.value)) + err = session.Query(fmt.Sprintf("SELECT vec FROM %s WHERE id = ?", tableName), 1).Scan(v.Interface()) + if err != nil { + t.Fatal(err) + } + if test.comparator != nil { + test.comparator(test.value, v.Elem().Interface()) + } else { + assertDeepEqual(t, "vector", test.value, v.Elem().Interface()) + } + }) + } +} + +func TestVector_MarshalerUDT(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TYPE gocql_test.person( + first_name text, + last_name text, + age int);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE gocql_test.vector_relatives( + id int, + couple vector<person, 2>, + primary key(id) + );`) + if err != nil { + t.Fatal(err) + } + + p1 := person{"Johny", "Bravo", 25} + p2 := person{"Capitan", "Planet", 5} + insVec := []person{p1, p2} + + err = session.Query("INSERT INTO vector_relatives(id, couple) VALUES(?, ?)", 1, insVec).Exec() + if err != nil { + t.Fatal(err) + } + + var selVec []person + + err = session.Query("SELECT couple FROM vector_relatives WHERE id = ?", 1).Scan(&selVec) + if err != nil { + t.Fatal(err) + } + + assertDeepEqual(t, "udt", &insVec, &selVec) +} + +func TestVector_Empty(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed_null(id int primary key, vec vector<float, 3>);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_variable_null(id int primary key, vec vector<text, 4>);`) + if err != nil { + t.Fatal(err) + } + + err = session.Query("INSERT INTO vector_fixed_null(id) VALUES(?)", 1).Exec() + if err != nil { + t.Fatal(err) + } + var selectFixVec []float32 + err = session.Query("SELECT vec FROM vector_fixed_null WHERE id = ?", 1).Scan(&selectFixVec) + if err != nil { + t.Fatal(err) + } + assertTrue(t, "fixed size element vector is empty", selectFixVec == nil) + + err = session.Query("INSERT INTO vector_variable_null(id) VALUES(?)", 1).Exec() + if err != nil { + t.Fatal(err) + } + var selectVarVec []string + err = session.Query("SELECT vec FROM vector_variable_null WHERE id = ?", 1).Scan(&selectVarVec) + if err != nil { + t.Fatal(err) + } + assertTrue(t, "variable size element vector is empty", selectVarVec == nil) +} + +func TestVector_MissingDimension(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.vector_fixed(id int primary key, vec vector<float, 3>);`) + if err != nil { + t.Fatal(err) + } + + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0}).Exec() + require.Error(t, err, "expected vector with 3 dimensions, received 2") + + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, []float32{8, -5.0, 1, 3}).Exec() + require.Error(t, err, "expected vector with 3 dimensions, received 4") +} + +func TestVector_SubTypeParsing(t *testing.T) { + tests := []struct { + name string + custom string + expected TypeInfo + }{ + {name: "text", custom: "org.apache.cassandra.db.marshal.UTF8Type", expected: NativeType{typ: TypeVarchar}}, + {name: "set_int", custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.Int32Type)", expected: CollectionType{NativeType{typ: TypeSet}, nil, NativeType{typ: TypeInt}}}, + { + name: "udt", + custom: "org.apache.cassandra.db.marshal.UserType(gocql_test,706572736f6e,66697273745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,6c6173745f6e616d65:org.apache.cassandra.db.marshal.UTF8Type,616765:org.apache.cassandra.db.marshal.Int32Type)", + expected: UDTTypeInfo{ + NativeType{typ: TypeUDT}, + "gocql_test", + "person", + []UDTField{ + UDTField{"first_name", NativeType{typ: TypeVarchar}}, + UDTField{"last_name", NativeType{typ: TypeVarchar}}, + UDTField{"age", NativeType{typ: TypeInt}}, + }, + }, + }, + { + name: "tuple", + custom: "org.apache.cassandra.db.marshal.TupleType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type)", + expected: TupleTypeInfo{ + NativeType{typ: TypeTuple}, + []TypeInfo{ + NativeType{typ: TypeVarchar}, + NativeType{typ: TypeInt}, + NativeType{typ: TypeVarchar}, + }, + }, + }, + { + name: "vector_vector_inet", + custom: "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.InetAddressType, 2), 3)", + expected: VectorType{ + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, + VectorType{ + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, + NativeType{typ: TypeInet}, + 2, + }, + 3, + }, + }, + { + name: "map_int_vector_text", + custom: "org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 10))", + expected: CollectionType{ + NativeType{typ: TypeMap}, + NativeType{typ: TypeInt}, + VectorType{ + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, + NativeType{typ: TypeVarchar}, + 10, + }, + }, + }, + { + name: "set_map_vector_text_text", + custom: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 10),org.apache.cassandra.db.marshal.UTF8Type))", + expected: CollectionType{ + NativeType{typ: TypeSet}, + nil, + CollectionType{ + NativeType{typ: TypeMap}, + VectorType{ + NativeType{typ: TypeCustom, custom: VECTOR_TYPE}, + NativeType{typ: TypeInt}, + 10, + }, + NativeType{typ: TypeVarchar}, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := newFramer(nil, 0) + f.writeShort(0) + f.writeString(fmt.Sprintf("org.apache.cassandra.db.marshal.VectorType(%s, 2)", test.custom)) + parsedType := f.readTypeInfo() + require.IsType(t, parsedType, VectorType{}) + vectorType := parsedType.(VectorType) + assertEqual(t, "dimensions", 2, vectorType.Dimensions) + assertDeepEqual(t, "vector", test.expected, vectorType.SubType) + }) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org For additional commands, e-mail: commits-h...@cassandra.apache.org