This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 8a4d611133 ARROW-16323: [Go] Implement Dictionary Scalars (#13575)
8a4d611133 is described below

commit 8a4d611133fbc337e48053b14539a37783614f91
Author: Matt Topol <[email protected]>
AuthorDate: Wed Jul 20 10:41:39 2022 -0400

    ARROW-16323: [Go] Implement Dictionary Scalars (#13575)
    
    Lead-authored-by: Matt Topol <[email protected]>
    Co-authored-by: Matthew Topol <[email protected]>
    Signed-off-by: Matthew Topol <[email protected]>
---
 go/arrow/array.go              |   3 +
 go/arrow/array/extension.go    |   4 +
 go/arrow/datatype.go           |  10 +++
 go/arrow/scalar/nested.go      | 184 +++++++++++++++++++++++++++++++++++++++++
 go/arrow/scalar/parse.go       |  15 ++++
 go/arrow/scalar/scalar.go      |  18 +++-
 go/arrow/scalar/scalar_test.go | 165 ++++++++++++++++++++++++++++++++++++
 7 files changed, 398 insertions(+), 1 deletion(-)

diff --git a/go/arrow/array.go b/go/arrow/array.go
index 0742ae11f0..d8983c29f2 100644
--- a/go/arrow/array.go
+++ b/go/arrow/array.go
@@ -18,6 +18,7 @@ package arrow
 
 import (
        "encoding/json"
+       "fmt"
 
        "github.com/apache/arrow/go/v9/arrow/memory"
 )
@@ -86,6 +87,8 @@ type ArrayData interface {
 type Array interface {
        json.Marshaler
 
+       fmt.Stringer // NOTE(review): String() string is now required of every Array; breaks out-of-tree implementations
+
        // DataType returns the type metadata for this instance.
        DataType() DataType
 
diff --git a/go/arrow/array/extension.go b/go/arrow/array/extension.go
index 94babc3829..3fb3e77b85 100644
--- a/go/arrow/array/extension.go
+++ b/go/arrow/array/extension.go
@@ -130,6 +130,10 @@ type ExtensionArrayBase struct {
        storage arraymarshal
 }
 
+func (e *ExtensionArrayBase) String() string { // String formats as "(" + extension data type + ")" + storage array.
+       return fmt.Sprintf("(%s)%s", e.data.dtype, e.storage)
+}
+
 func (e *ExtensionArrayBase) getOneForMarshal(i int) interface{} {
        return e.storage.getOneForMarshal(i)
 }
diff --git a/go/arrow/datatype.go b/go/arrow/datatype.go
index ed6e803aea..1503f655e7 100644
--- a/go/arrow/datatype.go
+++ b/go/arrow/datatype.go
@@ -255,6 +255,16 @@ func IsInteger(t Type) bool {
        return false
 }
 
+// IsUnsignedInteger reports whether the type ID t is one of the unsigned
+// integral types: UINT8, UINT16, UINT32 or UINT64. Signed integer IDs and
+// every non-integer ID yield false.
+func IsUnsignedInteger(t Type) bool {
+       return t == UINT8 ||
+               t == UINT16 ||
+               t == UINT32 ||
+               t == UINT64
+}
+
 // IsPrimitive returns true if the provided type ID represents a fixed width
 // primitive type.
 func IsPrimitive(t Type) bool {
diff --git a/go/arrow/scalar/nested.go b/go/arrow/scalar/nested.go
index fb86dbe7e1..0a3bedad0b 100644
--- a/go/arrow/scalar/nested.go
+++ b/go/arrow/scalar/nested.go
@@ -18,6 +18,7 @@ package scalar
 
 import (
        "bytes"
+       "errors"
        "fmt"
 
        "github.com/apache/arrow/go/v9/arrow"
@@ -322,3 +323,237 @@ func NewStructScalarWithNames(val []Scalar, names []string) (*Struct, error) {
        }
        return NewStructScalar(val, arrow.StructOf(fields...)), nil
 }
+
+// Dictionary is a scalar holding one dictionary-encoded value: an index
+// scalar together with the dictionary array that the index points into.
+//
+// A valid Dictionary must have a valid index; the dictionary entry it
+// selects may itself still be null.
+type Dictionary struct {
+       scalar
+
+       Value struct {
+               Index Scalar
+               Dict  arrow.Array
+       }
+}
+
+// NewNullDictScalar returns a null Dictionary scalar of type dt, which must
+// be an *arrow.DictionaryType (any other type panics). Its index is a null
+// scalar of dt's index type and no dictionary array is attached.
+func NewNullDictScalar(dt arrow.DataType) *Dictionary {
+       ret := &Dictionary{scalar: scalar{dt, false}}
+       ret.Value.Index = MakeNullScalar(dt.(*arrow.DictionaryType).IndexType)
+       ret.Value.Dict = nil
+       return ret
+}
+
+// NewDictScalar returns a Dictionary scalar for index into dict. Its type is
+// built from the index and dictionary data types and it is valid exactly
+// when index is valid. index (when Releasable) and dict are retained, so the
+// caller should Release the result once done with it.
+func NewDictScalar(index Scalar, dict arrow.Array) *Dictionary {
+       ret := &Dictionary{scalar: scalar{&arrow.DictionaryType{IndexType: index.DataType(), ValueType: dict.DataType()}, index.IsValid()}}
+       ret.Value.Index = index
+       ret.Value.Dict = dict
+       ret.Retain()
+       return ret
+}
+
+// Retain increments the reference count of the index scalar (when it is
+// Releasable) and of the dictionary array (when one is attached).
+func (s *Dictionary) Retain() {
+       if r, ok := s.Value.Index.(Releasable); ok {
+               r.Retain()
+       }
+       if s.Value.Dict != (arrow.Array)(nil) {
+               s.Value.Dict.Retain()
+       }
+}
+
+// Release decrements the reference count of the index scalar (when it is
+// Releasable) and of the dictionary array (when one is attached).
+func (s *Dictionary) Release() {
+       if r, ok := s.Value.Index.(Releasable); ok {
+               r.Release()
+       }
+       if s.Value.Dict != (arrow.Array)(nil) {
+               s.Value.Dict.Release()
+       }
+}
+
+// Validate performs the cheap structural checks: the type must be a
+// dictionary type; the index must be present, self-valid and of the declared
+// index type; the scalar's validity must match the index's validity; and a
+// valid scalar must carry a dictionary of the declared value type. Index
+// bounds are only checked by ValidateFull.
+func (s *Dictionary) Validate() (err error) {
+       dt, ok := s.Type.(*arrow.DictionaryType)
+       if !ok {
+               return errors.New("arrow/scalar: dictionary scalar should have type Dictionary")
+       }
+
+       if s.Value.Index == (Scalar)(nil) {
+               return fmt.Errorf("%s scalar doesn't have an index value", dt)
+       }
+
+       if err = s.Value.Index.Validate(); err != nil {
+               return fmt.Errorf("%s scalar fails validation for index value: %w", dt, err)
+       }
+
+       if !arrow.TypeEqual(s.Value.Index.DataType(), dt.IndexType) {
+               return fmt.Errorf("%s scalar should have an index value of type %s, got %s",
+                       dt, dt.IndexType, s.Value.Index.DataType())
+       }
+
+       if s.IsValid() && !s.Value.Index.IsValid() {
+               return fmt.Errorf("non-null %s scalar has null index value", dt)
+       }
+
+       if !s.IsValid() && s.Value.Index.IsValid() {
+               return fmt.Errorf("null %s scalar has non-null index value", dt)
+       }
+
+       if !s.IsValid() {
+               return
+       }
+
+       if s.Value.Dict == (arrow.Array)(nil) {
+               return fmt.Errorf("%s scalar doesn't have a dictionary value", dt)
+       }
+
+       if !arrow.TypeEqual(s.Value.Dict.DataType(), dt.ValueType) {
+               return fmt.Errorf("%s scalar's value type doesn't match dict type: got %s", dt, s.Value.Dict.DataType())
+       }
+
+       return
+}
+
+// ValidateFull runs Validate and additionally checks that a non-null index
+// lies within [0, Dict.Len()).
+func (s *Dictionary) ValidateFull() (err error) {
+       if err = s.Validate(); err != nil {
+               return
+       }
+
+       if !s.Value.Index.IsValid() {
+               return nil
+       }
+
+       // Compare in 64-bit space: converting a large unsigned index straight to
+       // int could wrap negative and bypass the upper-bound check.
+       idx := s.Value.Index.value()
+       if _, inBounds, ok := dictIndexPosition(idx, s.Value.Dict.Len()); ok && !inBounds {
+               return fmt.Errorf("%s scalar index value out of bounds: %d", s.DataType(), idx)
+       }
+       return nil
+}
+
+// String renders a null scalar as "null" and a valid one as the dictionary
+// array followed by the index in square brackets.
+func (s *Dictionary) String() string {
+       if !s.Valid {
+               return "null"
+       }
+
+       return s.Value.Dict.String() + "[" + s.Value.Index.String() + "]"
+}
+
+// equals reports whether the index scalars and the dictionary arrays are
+// both equal; rhs must be a *Dictionary.
+func (s *Dictionary) equals(rhs Scalar) bool {
+       return s.Value.Index.equals(rhs.(*Dictionary).Value.Index) &&
+               array.Equal(s.Value.Dict, rhs.(*Dictionary).Value.Dict)
+}
+
+// CastTo is not implemented for dictionary scalars and always returns an
+// error.
+func (s *Dictionary) CastTo(arrow.DataType) (Scalar, error) {
+       return nil, fmt.Errorf("cast from scalar %s not implemented", s.DataType())
+}
+
+// GetEncodedValue returns the dictionary entry selected by the index. A null
+// scalar yields a null scalar of the dictionary's value type. An error is
+// returned for a non-integer index type, a missing index or dictionary, an
+// index value not matching its type, or an index outside [0, Dict.Len()).
+func (s *Dictionary) GetEncodedValue() (Scalar, error) {
+       dt := s.Type.(*arrow.DictionaryType)
+       if !s.IsValid() {
+               return MakeNullScalar(dt.ValueType), nil
+       }
+
+       if !arrow.IsInteger(dt.IndexType.ID()) {
+               return nil, fmt.Errorf("unimplemented dictionary type %s", dt.IndexType)
+       }
+
+       switch {
+       case s.Value.Index == (Scalar)(nil):
+               return nil, fmt.Errorf("%s scalar doesn't have an index value", dt)
+       case s.Value.Dict == (arrow.Array)(nil):
+               return nil, fmt.Errorf("%s scalar doesn't have a dictionary value", dt)
+       }
+
+       idx := s.Value.Index.value()
+       pos, inBounds, ok := dictIndexPosition(idx, s.Value.Dict.Len())
+       switch {
+       case !ok:
+               return nil, fmt.Errorf("%s scalar index value %v does not match index type %s", dt, idx, dt.IndexType)
+       case !inBounds:
+               return nil, fmt.Errorf("%s scalar index value out of bounds: %d", dt, idx)
+       }
+       return GetScalar(s.Value.Dict, pos)
+}
+
+// value returns the raw Go value of the index scalar.
+func (s *Dictionary) value() interface{} {
+       return s.Value.Index.value()
+}
+
+// dictIndexPosition converts idx, the raw Go value of an integer index
+// scalar, to an int position and reports whether it lies within [0, n).
+// The bounds check runs in 64-bit arithmetic before any narrowing, so a
+// large unsigned value (e.g. a uint64 >= 1<<63) cannot wrap to a negative
+// int and slip past it. ok is false when idx is not a sized Go integer;
+// pos is meaningful only when inBounds is true. A negative n is treated as 0.
+func dictIndexPosition(idx interface{}, n int) (pos int, inBounds, ok bool) {
+       if n < 0 {
+               n = 0
+       }
+
+       var (
+               signed   int64
+               unsigned uint64
+               isSigned bool
+       )
+       switch v := idx.(type) {
+       case int8:
+               signed, isSigned = int64(v), true
+       case int16:
+               signed, isSigned = int64(v), true
+       case int32:
+               signed, isSigned = int64(v), true
+       case int64:
+               signed, isSigned = v, true
+       case uint8:
+               unsigned = uint64(v)
+       case uint16:
+               unsigned = uint64(v)
+       case uint32:
+               unsigned = uint64(v)
+       case uint64:
+               unsigned = v
+       default:
+               return 0, false, false
+       }
+
+       if isSigned {
+               if signed < 0 || signed >= int64(n) {
+                       return 0, false, true
+               }
+               return int(signed), true, true
+       }
+       if unsigned >= uint64(n) {
+               return 0, false, true
+       }
+       return int(unsigned), true, true
+}
diff --git a/go/arrow/scalar/parse.go b/go/arrow/scalar/parse.go
index 805ade427d..9680589bd2 100644
--- a/go/arrow/scalar/parse.go
+++ b/go/arrow/scalar/parse.go
@@ -420,6 +420,34 @@ func MakeScalarParam(val interface{}, dt arrow.DataType) (Scalar, error) {
                         return NewMapScalar(v), nil
                 }
         }
+
+       if arrow.IsInteger(dt.ID()) {
+               // Accept any Go value convertible to int64/uint64 (e.g. a plain
+               // int). Check convertibility first, since reflect.Value.Convert
+               // panics on a nil or non-convertible value.
+               target := reflect.TypeOf(int64(0))
+               if arrow.IsUnsignedInteger(dt.ID()) {
+                       target = reflect.TypeOf(uint64(0))
+               }
+
+               rv := reflect.ValueOf(val)
+               if !rv.IsValid() || !rv.Type().ConvertibleTo(target) {
+                       return nil, fmt.Errorf("arrow/scalar: cannot make %s scalar from value of type %T", dt, val)
+               }
+
+               bits := dt.(arrow.FixedWidthDataType).BitWidth()
+               rv = rv.Convert(target)
+               if target.Kind() == reflect.Uint64 {
+                       return MakeUnsignedIntegerScalar(rv.Uint(), bits)
+               }
+               return MakeIntegerScalar(rv.Int(), bits)
+       }
+
+       // A dictionary type yields a scalar of its value type rather than a
+       // *Dictionary scalar.
+       if dt.ID() == arrow.DICTIONARY {
+               return MakeScalarParam(val, dt.(*arrow.DictionaryType).ValueType)
+       }
         return MakeScalar(val), nil
  }
 
@@ -631,6 +644,10 @@ func ParseScalar(dt arrow.DataType, val string) (Scalar, error) {
                 }
 
                 return NewTime64Scalar(tm, dt), nil
+       case arrow.DICTIONARY:
+               // Dictionary-typed text is parsed as the dictionary's value type.
+               valueType := dt.(*arrow.DictionaryType).ValueType
+               return ParseScalar(valueType, val)
         }
 
         return nil, fmt.Errorf("parsing of scalar for type %s not implemented", dt)
diff --git a/go/arrow/scalar/scalar.go b/go/arrow/scalar/scalar.go
index 9f75869cf7..db1ec4da44 100644
--- a/go/arrow/scalar/scalar.go
+++ b/go/arrow/scalar/scalar.go
@@ -444,7 +444,7 @@ func init() {
                arrow.STRUCT:                  func(dt arrow.DataType) Scalar { 
return &Struct{scalar: scalar{dt, false}} },
                arrow.SPARSE_UNION:            unsupportedScalarType,
                arrow.DENSE_UNION:             unsupportedScalarType,
-               arrow.DICTIONARY:              unsupportedScalarType,
+               arrow.DICTIONARY:              func(dt arrow.DataType) Scalar { 
return NewNullDictScalar(dt) }, // null index scalar, no dictionary array attached
                arrow.LARGE_STRING:            unsupportedScalarType,
                arrow.LARGE_BINARY:            unsupportedScalarType,
                arrow.LARGE_LIST:              unsupportedScalarType,
@@ -558,6 +558,26 @@ func GetScalar(arr arrow.Array, idx int) (Scalar, error) {
                 return NewTime64Scalar(arr.Value(idx), arr.DataType()), nil
         case *array.Timestamp:
                 return NewTimestampScalar(arr.Value(idx), arr.DataType()), nil
+       case *array.Dictionary:
+               ty := arr.DataType().(*arrow.DictionaryType)
+               valid := arr.IsValid(idx)
+               // A null slot must carry a null index scalar: the index stored
+               // under a null slot is not meaningful, and Dictionary.Validate
+               // rejects a null scalar whose index is non-null.
+               ret := &Dictionary{scalar: scalar{ty, valid}}
+               if valid {
+                       index, err := MakeScalarParam(arr.GetValueIndex(idx), ty.IndexType)
+                       if err != nil {
+                               return nil, err
+                       }
+                       ret.Value.Index = index
+               } else {
+                       ret.Value.Index = MakeNullScalar(ty.IndexType)
+               }
+
+               ret.Value.Dict = arr.Dictionary()
+               ret.Value.Dict.Retain()
+               return ret, nil
         }
 
         return nil, fmt.Errorf("cannot create scalar from array of type %s", arr.DataType())
@@ -819,6 +831,10 @@ func Hash(seed maphash.Seed, s Scalar) uint64 {
                                 out ^= Hash(seed, c)
                         }
                 }
+       case *Dictionary:
+               if idx := s.Value.Index; idx.IsValid() { // only the index is hashed, never the dictionary array
+                       out ^= Hash(seed, idx)
+               }
         }
 
         return out
diff --git a/go/arrow/scalar/scalar_test.go b/go/arrow/scalar/scalar_test.go
index 22290664b3..9b4f458c44 100644
--- a/go/arrow/scalar/scalar_test.go
+++ b/go/arrow/scalar/scalar_test.go
@@ -18,8 +18,10 @@ package scalar_test
 
 import (
        "bytes"
+       "fmt"
        "hash/maphash"
        "math/bits"
+       "strings"
        "testing"
        "time"
 
@@ -978,3 +980,166 @@ func TestToScalar(t *testing.T) {
 
        assert.Equal(t, expected, sc.String())
 }
+
+var dictIndexTypes = []arrow.DataType{ // index types exercised by TestDictionaryScalarBasics
+       arrow.PrimitiveTypes.Int8,
+       arrow.PrimitiveTypes.Uint8,
+       arrow.PrimitiveTypes.Int16,
+       arrow.PrimitiveTypes.Uint16,
+       arrow.PrimitiveTypes.Int32,
+       arrow.PrimitiveTypes.Uint32,
+       arrow.PrimitiveTypes.Int64,
+       arrow.PrimitiveTypes.Uint64,
+}
+
+func TestDictionaryScalarBasics(t *testing.T) {
+       for _, indexType := range dictIndexTypes {
+               t.Run(fmt.Sprint(indexType), func(t *testing.T) {
+                       mem := 
memory.NewCheckedAllocator(memory.DefaultAllocator)
+                       defer mem.AssertSize(t, 0)
+
+                       ty := &arrow.DictionaryType{IndexType: indexType, 
ValueType: arrow.BinaryTypes.String}
+                       dict, _, _ := array.FromJSON(mem, 
arrow.BinaryTypes.String, strings.NewReader(`["alpha", null, "gamma"]`))
+                       defer dict.Release()
+
+                       idxScalar, _ := scalar.MakeScalarParam(0, indexType) // error ignored: 0 converts to every type in dictIndexTypes
+                       alpha := scalar.NewDictScalar(idxScalar, dict)
+                       defer alpha.Release()
+
+                       idxScalar, _ = scalar.MakeScalarParam(2, indexType)
+                       gamma := scalar.NewDictScalar(idxScalar, dict)
+                       defer gamma.Release()
+
+                       idxScalar, _ = scalar.MakeScalarParam(1, indexType)
+                       nullVal := scalar.NewDictScalar(idxScalar, dict)
+                       defer nullVal.Release()
+
+                       scalarNull := scalar.MakeNullScalar(ty)
+                       scalarNull.(*scalar.Dictionary).Value.Dict = dict
+                       dict.Retain() // dropped again by scalarNull's deferred Release below
+                       defer scalarNull.(*scalar.Dictionary).Release()
+
+                       assert.NoError(t, scalarNull.ValidateFull())
+                       assert.NoError(t, alpha.ValidateFull())
+                       assert.NoError(t, gamma.ValidateFull())
+
+                       // index is valid, corresponding value is null
+                       assert.NoError(t, nullVal.ValidateFull())
+
+                       encodedNull, err := 
scalarNull.(*scalar.Dictionary).GetEncodedValue()
+                       assert.NoError(t, err)
+                       assert.NoError(t, encodedNull.ValidateFull())
+                       assert.True(t, scalar.Equals(encodedNull, 
scalar.MakeNullScalar(arrow.BinaryTypes.String)))
+
+                       encodedNullVal, err := nullVal.GetEncodedValue()
+                       assert.NoError(t, err)
+                       assert.NoError(t, encodedNullVal.ValidateFull())
+                       assert.True(t, scalar.Equals(encodedNullVal, 
scalar.MakeNullScalar(arrow.BinaryTypes.String)))
+
+                       encodedAlpha, err := alpha.GetEncodedValue()
+                       assert.NoError(t, err)
+                       assert.NoError(t, encodedAlpha.ValidateFull())
+                       assert.True(t, scalar.Equals(encodedAlpha, 
scalar.MakeScalar("alpha")))
+
+                       encodedGamma, err := gamma.GetEncodedValue()
+                       assert.NoError(t, err)
+                       assert.NoError(t, encodedGamma.ValidateFull())
+                       assert.True(t, scalar.Equals(encodedGamma, 
scalar.MakeScalar("gamma")))
+
+                       idxArr, _, _ := array.FromJSON(mem, indexType, 
strings.NewReader(`[2, 0, 1, null]`))
+                       defer idxArr.Release()
+                       arr := array.NewDictionaryArray(ty, idxArr, dict)
+                       defer arr.Release()
+
+                       first, err := scalar.GetScalar(arr, 0)
+                       assert.NoError(t, err)
+                       second, err := scalar.GetScalar(arr, 1)
+                       assert.NoError(t, err)
+                       third, err := scalar.GetScalar(arr, 2)
+                       assert.NoError(t, err)
+                       last, err := scalar.GetScalar(arr, 3)
+                       assert.NoError(t, err)
+
+                       defer func() { // GetScalar retained the dictionary for each of these results
+                               first.(*scalar.Dictionary).Release()
+                               second.(*scalar.Dictionary).Release()
+                               third.(*scalar.Dictionary).Release()
+                               last.(*scalar.Dictionary).Release()
+                       }()
+
+                       assert.NoError(t, first.ValidateFull())
+                       assert.NoError(t, second.ValidateFull())
+                       assert.NoError(t, third.ValidateFull())
+                       assert.NoError(t, last.ValidateFull()) // a null slot must yield a null index for Validate to pass
+
+                       assert.True(t, first.IsValid())
+                       assert.True(t, second.IsValid())
+                       assert.True(t, third.IsValid()) // valid because of 
valid index despite null value
+                       assert.False(t, last.IsValid()) // idxArr[3] is null
+
+                       assert.True(t, scalar.Equals(first, gamma))
+                       assert.True(t, scalar.Equals(second, alpha))
+                       assert.True(t, scalar.Equals(third, nullVal))
+                       assert.True(t, scalar.Equals(last, scalarNull))
+
+                       assert.Same(t, first.(*scalar.Dictionary).Value.Dict, 
arr.Dictionary())
+                       assert.Same(t, second.(*scalar.Dictionary).Value.Dict, 
arr.Dictionary())
+               })
+       }
+}
+
+func TestDictionaryScalarValidateErrors(t *testing.T) {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(t, 0)
+
+       var (
+               indexTy = arrow.PrimitiveTypes.Int16
+               valueTy = arrow.BinaryTypes.String
+               dictTy  = &arrow.DictionaryType{IndexType: indexTy, ValueType: 
valueTy}
+       )
+
+       dict, _, _ := array.FromJSON(mem, valueTy, strings.NewReader(`["alpha", 
null, "gamma"]`))
+       defer dict.Release()
+
+       alpha := scalar.NewDictScalar(scalar.MakeScalar(int16(0)), dict)
+       defer alpha.Release()
+
+       // Valid index, null underlying value
+       nullVal := scalar.NewDictScalar(scalar.MakeScalar(int16(1)), dict)
+       defer nullVal.Release()
+
+       // inconsistent index type
+       dictSc := scalar.NewDictScalar(alpha.Value.Index, dict)
+       defer dictSc.Release()
+       dictSc.Type = &arrow.DictionaryType{IndexType: 
arrow.PrimitiveTypes.Int32, ValueType: valueTy}
+       assert.Error(t, dictSc.Validate())
+
+       // inconsistent value type between dict and type
+       dictSc.Type = &arrow.DictionaryType{IndexType: indexTy, ValueType: 
arrow.BinaryTypes.Binary}
+       assert.Error(t, dictSc.Validate())
+
+       // inconsistent Valid/Value
+       dictSc.Type = dictTy
+       assert.NoError(t, dictSc.ValidateFull())
+       dictSc.Valid = false // index stays valid, so Validate must reject the mismatch
+       assert.Error(t, dictSc.ValidateFull())
+
+       assert.NoError(t, nullVal.ValidateFull())
+       nullVal.Valid = false
+       assert.Error(t, nullVal.ValidateFull())
+
+       dictSc = scalar.NewNullDictScalar(dictTy) // the deferred dictSc.Release above still targets the earlier scalar
+       dictSc.Valid = true
+       assert.Error(t, dictSc.ValidateFull())
+       dictSc.Valid = false
+       assert.NoError(t, dictSc.ValidateFull())
+
+       // index value out of bounds
+       for _, idx := range []int16{-1, 3} {
+               invalid := scalar.NewDictScalar(scalar.MakeScalar(idx), dict)
+               defer invalid.Release() // runs when the test returns, not at the end of each iteration
+
+               assert.NoError(t, invalid.Validate()) // Validate skips bounds checks; ValidateFull enforces them
+               assert.Error(t, invalid.ValidateFull())
+       }
+}

Reply via email to