This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 8a4d611133 ARROW-16323: [Go] Implement Dictionary Scalars (#13575)
8a4d611133 is described below
commit 8a4d611133fbc337e48053b14539a37783614f91
Author: Matt Topol <[email protected]>
AuthorDate: Wed Jul 20 10:41:39 2022 -0400
ARROW-16323: [Go] Implement Dictionary Scalars (#13575)
Lead-authored-by: Matt Topol <[email protected]>
Co-authored-by: Matthew Topol <[email protected]>
Signed-off-by: Matthew Topol <[email protected]>
---
go/arrow/array.go | 3 +
go/arrow/array/extension.go | 4 +
go/arrow/datatype.go | 10 +++
go/arrow/scalar/nested.go | 184 +++++++++++++++++++++++++++++++++++++++++
go/arrow/scalar/parse.go | 15 ++++
go/arrow/scalar/scalar.go | 18 +++-
go/arrow/scalar/scalar_test.go | 165 ++++++++++++++++++++++++++++++++++++
7 files changed, 398 insertions(+), 1 deletion(-)
diff --git a/go/arrow/array.go b/go/arrow/array.go
index 0742ae11f0..d8983c29f2 100644
--- a/go/arrow/array.go
+++ b/go/arrow/array.go
@@ -18,6 +18,7 @@ package arrow
import (
"encoding/json"
+ "fmt"
"github.com/apache/arrow/go/v9/arrow/memory"
)
@@ -86,6 +87,8 @@ type ArrayData interface {
type Array interface {
json.Marshaler
+ fmt.Stringer
+
// DataType returns the type metadata for this instance.
DataType() DataType
diff --git a/go/arrow/array/extension.go b/go/arrow/array/extension.go
index 94babc3829..3fb3e77b85 100644
--- a/go/arrow/array/extension.go
+++ b/go/arrow/array/extension.go
@@ -130,6 +130,10 @@ type ExtensionArrayBase struct {
storage arraymarshal
}
+// String formats the array as "(<data type>)<storage>", i.e. the array's data
+// type in parentheses immediately followed by the underlying storage array.
+func (e *ExtensionArrayBase) String() string {
+	dtype, storage := e.data.dtype, e.storage
+	return fmt.Sprintf("(%s)%s", dtype, storage)
+}
+
func (e *ExtensionArrayBase) getOneForMarshal(i int) interface{} {
return e.storage.getOneForMarshal(i)
}
diff --git a/go/arrow/datatype.go b/go/arrow/datatype.go
index ed6e803aea..1503f655e7 100644
--- a/go/arrow/datatype.go
+++ b/go/arrow/datatype.go
@@ -255,6 +255,16 @@ func IsInteger(t Type) bool {
return false
}
+// IsUnsignedInteger reports whether the provided type ID is one of the
+// unsigned integral types (uint8, uint16, uint32 or uint64).
+func IsUnsignedInteger(t Type) bool {
+	return t == UINT8 || t == UINT16 || t == UINT32 || t == UINT64
+}
+
// IsPrimitive returns true if the provided type ID represents a fixed width
// primitive type.
func IsPrimitive(t Type) bool {
diff --git a/go/arrow/scalar/nested.go b/go/arrow/scalar/nested.go
index fb86dbe7e1..0a3bedad0b 100644
--- a/go/arrow/scalar/nested.go
+++ b/go/arrow/scalar/nested.go
@@ -18,6 +18,7 @@ package scalar
import (
"bytes"
+ "errors"
"fmt"
"github.com/apache/arrow/go/v9/arrow"
@@ -322,3 +323,186 @@ func NewStructScalarWithNames(val []Scalar, names
[]string) (*Struct, error) {
}
return NewStructScalar(val, arrow.StructOf(fields...)), nil
}
+
+// Dictionary is a scalar of a dictionary-encoded type. Its value is an index
+// scalar paired with the dictionary array that the index refers into.
+type Dictionary struct {
+ scalar
+
+ // Value holds the two halves of the dictionary scalar.
+ Value struct {
+ // Index is the position within Dict; Validate requires its type to
+ // equal the IndexType of the scalar's dictionary type.
+ Index Scalar
+ // Dict is the dictionary array; NewNullDictScalar leaves it nil.
+ Dict arrow.Array
+ }
+}
+
+// NewNullDictScalar builds a null dictionary scalar of type dt. The index is
+// a null scalar of dt's index type and no dictionary array is attached.
+// dt must be an *arrow.DictionaryType, otherwise the type assertion panics.
+func NewNullDictScalar(dt arrow.DataType) *Dictionary {
+	indexType := dt.(*arrow.DictionaryType).IndexType
+
+	// Value.Dict is deliberately left at its zero value (nil).
+	nullDict := &Dictionary{scalar: scalar{dt, false}}
+	nullDict.Value.Index = MakeNullScalar(indexType)
+	return nullDict
+}
+
+// NewDictScalar builds a dictionary scalar from an index scalar and a
+// dictionary array. The dictionary type is derived from the data types of the
+// two arguments and the scalar's validity is taken from the index. The new
+// scalar calls Retain on itself, so the caller should Release it when done.
+func NewDictScalar(index Scalar, dict arrow.Array) *Dictionary {
+	dictType := &arrow.DictionaryType{IndexType: index.DataType(), ValueType: dict.DataType()}
+
+	out := &Dictionary{scalar: scalar{dictType, index.IsValid()}}
+	out.Value.Index, out.Value.Dict = index, dict
+	out.Retain()
+	return out
+}
+
+// Retain forwards to Retain on the index scalar (when it implements
+// Releasable) and on the dictionary array (when one is set).
+func (s *Dictionary) Retain() {
+	if index, isReleasable := s.Value.Index.(Releasable); isReleasable {
+		index.Retain()
+	}
+	if dict := s.Value.Dict; dict != nil {
+		dict.Retain()
+	}
+}
+
+// Release forwards to Release on the index scalar (when it implements
+// Releasable) and on the dictionary array (when one is set).
+func (s *Dictionary) Release() {
+	if index, isReleasable := s.Value.Index.(Releasable); isReleasable {
+		index.Release()
+	}
+	if dict := s.Value.Dict; dict != nil {
+		dict.Release()
+	}
+}
+
+func (s *Dictionary) Validate() (err error) {
+ dt, ok := s.Type.(*arrow.DictionaryType)
+ if !ok {
+ return errors.New("arrow/scalar: dictionary scalar should have
type Dictionary")
+ }
+
+ if s.Value.Index == (Scalar)(nil) {
+ return fmt.Errorf("%s scalar doesn't have an index value", dt)
+ }
+
+ if err = s.Value.Index.Validate(); err != nil {
+ return fmt.Errorf("%s scalar fails validation for index value:
%w", dt, err)
+ }
+
+ if !arrow.TypeEqual(s.Value.Index.DataType(), dt.IndexType) {
+ return fmt.Errorf("%s scalar should have an index value of type
%s, got %s",
+ dt, dt.IndexType, s.Value.Index.DataType())
+ }
+
+ if s.IsValid() && !s.Value.Index.IsValid() {
+ return fmt.Errorf("non-null %s scalar has null index value", dt)
+ }
+
+ if !s.IsValid() && s.Value.Index.IsValid() {
+ return fmt.Errorf("null %s scalar has non-null index value", dt)
+ }
+
+ if !s.IsValid() {
+ return
+ }
+
+ if s.Value.Dict == (arrow.Array)(nil) {
+ return fmt.Errorf("%s scalar doesn't have a dictionary value",
dt)
+ }
+
+ if !arrow.TypeEqual(s.Value.Dict.DataType(), dt.ValueType) {
+ return fmt.Errorf("%s scalar's value type doesn't match dict
type: got %s", dt, s.Value.Dict.DataType())
+ }
+
+ return
+}
+
+// ValidateFull runs Validate and then, for a scalar with a valid index,
+// checks that the index refers to an existing entry of the dictionary.
+//
+// The index is widened to uint64 before being compared with the dictionary
+// length. Narrowing it to int instead would let a uint64 index larger than
+// the maximum int wrap to a negative value and slip past the bounds check.
+// A negative signed index converts to a uint64 far above any possible
+// dictionary length, so the single comparison below rejects it as well.
+func (s *Dictionary) ValidateFull() (err error) {
+	if err = s.Validate(); err != nil {
+		return err
+	}
+
+	if !s.Value.Index.IsValid() {
+		return nil
+	}
+
+	var idx uint64
+	switch v := s.Value.Index.value().(type) {
+	case int8:
+		idx = uint64(v)
+	case uint8:
+		idx = uint64(v)
+	case int16:
+		idx = uint64(v)
+	case uint16:
+		idx = uint64(v)
+	case int32:
+		idx = uint64(v)
+	case uint32:
+		idx = uint64(v)
+	case int64:
+		idx = uint64(v)
+	case uint64:
+		idx = v
+	default:
+		// Index values of any other Go type are not bounds checked.
+		return nil
+	}
+
+	// Validate guarantees a non-nil dictionary for a scalar with a valid index.
+	if idx >= uint64(s.Value.Dict.Len()) {
+		return fmt.Errorf("%s scalar index value out of bounds: %d", s.DataType(), s.Value.Index.value())
+	}
+	return nil
+}
+
+// String returns "null" for a null scalar; otherwise it returns the string
+// form of the dictionary followed by the string form of the index in square
+// brackets.
+func (s *Dictionary) String() string {
+	if s.Valid {
+		dict, index := s.Value.Dict.String(), s.Value.Index.String()
+		return dict + "[" + index + "]"
+	}
+	return "null"
+}
+
+// equals reports whether rhs has an equal index and an equal dictionary
+// array. rhs must be a *Dictionary, otherwise the type assertion panics.
+func (s *Dictionary) equals(rhs Scalar) bool {
+	other := rhs.(*Dictionary)
+	if !s.Value.Index.equals(other.Value.Index) {
+		return false
+	}
+	return array.Equal(s.Value.Dict, other.Value.Dict)
+}
+
+// CastTo is not implemented for dictionary scalars; it always returns a nil
+// scalar and an error naming the scalar's data type.
+func (s *Dictionary) CastTo(arrow.DataType) (Scalar, error) {
+	err := fmt.Errorf("cast from scalar %s not implemented", s.DataType())
+	return nil, err
+}
+
+// GetEncodedValue returns the dictionary entry that the index refers to, as a
+// scalar. A null dictionary scalar yields a null scalar of the dictionary's
+// value type.
+//
+// An error is returned when the index type is not one of the eight integer
+// types, when no dictionary is attached, or when the index does not refer to
+// an existing dictionary entry. The range check is done on a uint64 so that
+// negative signed indices (which convert to very large unsigned values) and
+// unsigned indices too large for an int are both rejected rather than being
+// used to index the dictionary.
+func (s *Dictionary) GetEncodedValue() (Scalar, error) {
+	dt := s.Type.(*arrow.DictionaryType)
+	if !s.IsValid() {
+		return MakeNullScalar(dt.ValueType), nil
+	}
+
+	var idx uint64
+	switch dt.IndexType.ID() {
+	case arrow.INT8:
+		idx = uint64(s.Value.Index.value().(int8))
+	case arrow.UINT8:
+		idx = uint64(s.Value.Index.value().(uint8))
+	case arrow.INT16:
+		idx = uint64(s.Value.Index.value().(int16))
+	case arrow.UINT16:
+		idx = uint64(s.Value.Index.value().(uint16))
+	case arrow.INT32:
+		idx = uint64(s.Value.Index.value().(int32))
+	case arrow.UINT32:
+		idx = uint64(s.Value.Index.value().(uint32))
+	case arrow.INT64:
+		idx = uint64(s.Value.Index.value().(int64))
+	case arrow.UINT64:
+		idx = s.Value.Index.value().(uint64)
+	default:
+		return nil, fmt.Errorf("unimplemented dictionary type %s", dt.IndexType)
+	}
+
+	if s.Value.Dict == nil {
+		return nil, fmt.Errorf("%s scalar doesn't have a dictionary value", dt)
+	}
+	if idx >= uint64(s.Value.Dict.Len()) {
+		return nil, fmt.Errorf("%s scalar index value out of bounds: %d", dt, s.Value.Index.value())
+	}
+	return GetScalar(s.Value.Dict, int(idx))
+}
+
+// value returns the raw value of the index scalar.
+func (s *Dictionary) value() interface{} {
+	index := s.Value.Index
+	return index.value()
+}
diff --git a/go/arrow/scalar/parse.go b/go/arrow/scalar/parse.go
index 805ade427d..9680589bd2 100644
--- a/go/arrow/scalar/parse.go
+++ b/go/arrow/scalar/parse.go
@@ -420,6 +420,19 @@ func MakeScalarParam(val interface{}, dt arrow.DataType)
(Scalar, error) {
return NewMapScalar(v), nil
}
}
+
+ if arrow.IsInteger(dt.ID()) {
+ bits := dt.(arrow.FixedWidthDataType).BitWidth()
+ val := reflect.ValueOf(val)
+ if arrow.IsUnsignedInteger(dt.ID()) {
+ return
MakeUnsignedIntegerScalar(val.Convert(reflect.TypeOf(uint64(0))).Uint(), bits)
+ }
+ return
MakeIntegerScalar(val.Convert(reflect.TypeOf(int64(0))).Int(), bits)
+ }
+
+ if dt.ID() == arrow.DICTIONARY {
+ return MakeScalarParam(val,
dt.(*arrow.DictionaryType).ValueType)
+ }
return MakeScalar(val), nil
}
@@ -631,6 +644,8 @@ func ParseScalar(dt arrow.DataType, val string) (Scalar,
error) {
}
return NewTime64Scalar(tm, dt), nil
+ case arrow.DICTIONARY:
+ return ParseScalar(dt.(*arrow.DictionaryType).ValueType, val)
}
return nil, fmt.Errorf("parsing of scalar for type %s not implemented",
dt)
diff --git a/go/arrow/scalar/scalar.go b/go/arrow/scalar/scalar.go
index 9f75869cf7..db1ec4da44 100644
--- a/go/arrow/scalar/scalar.go
+++ b/go/arrow/scalar/scalar.go
@@ -444,7 +444,7 @@ func init() {
arrow.STRUCT: func(dt arrow.DataType) Scalar {
return &Struct{scalar: scalar{dt, false}} },
arrow.SPARSE_UNION: unsupportedScalarType,
arrow.DENSE_UNION: unsupportedScalarType,
- arrow.DICTIONARY: unsupportedScalarType,
+ arrow.DICTIONARY: func(dt arrow.DataType) Scalar {
return NewNullDictScalar(dt) },
arrow.LARGE_STRING: unsupportedScalarType,
arrow.LARGE_BINARY: unsupportedScalarType,
arrow.LARGE_LIST: unsupportedScalarType,
@@ -558,6 +558,18 @@ func GetScalar(arr arrow.Array, idx int) (Scalar, error) {
return NewTime64Scalar(arr.Value(idx), arr.DataType()), nil
case *array.Timestamp:
return NewTimestampScalar(arr.Value(idx), arr.DataType()), nil
+ case *array.Dictionary:
+ ty := arr.DataType().(*arrow.DictionaryType)
+ index, err := MakeScalarParam(arr.GetValueIndex(idx),
ty.IndexType)
+ if err != nil {
+ return nil, err
+ }
+
+ scalar := &Dictionary{scalar: scalar{ty, arr.IsValid(idx)}}
+ scalar.Value.Index = index
+ scalar.Value.Dict = arr.Dictionary()
+ scalar.Value.Dict.Retain()
+ return scalar, nil
}
return nil, fmt.Errorf("cannot create scalar from array of type %s",
arr.DataType())
@@ -819,6 +831,10 @@ func Hash(seed maphash.Seed, s Scalar) uint64 {
out ^= Hash(seed, c)
}
}
+ case *Dictionary:
+ if s.Value.Index.IsValid() {
+ out ^= Hash(seed, s.Value.Index)
+ }
}
return out
diff --git a/go/arrow/scalar/scalar_test.go b/go/arrow/scalar/scalar_test.go
index 22290664b3..9b4f458c44 100644
--- a/go/arrow/scalar/scalar_test.go
+++ b/go/arrow/scalar/scalar_test.go
@@ -18,8 +18,10 @@ package scalar_test
import (
"bytes"
+ "fmt"
"hash/maphash"
"math/bits"
+ "strings"
"testing"
"time"
@@ -978,3 +980,166 @@ func TestToScalar(t *testing.T) {
assert.Equal(t, expected, sc.String())
}
+
+// dictIndexTypes lists the signed and unsigned 8- to 64-bit integer types;
+// TestDictionaryScalarBasics runs one subtest per entry, using it as the
+// dictionary index type.
+var dictIndexTypes = []arrow.DataType{
+ arrow.PrimitiveTypes.Int8,
+ arrow.PrimitiveTypes.Uint8,
+ arrow.PrimitiveTypes.Int16,
+ arrow.PrimitiveTypes.Uint16,
+ arrow.PrimitiveTypes.Int32,
+ arrow.PrimitiveTypes.Uint32,
+ arrow.PrimitiveTypes.Int64,
+ arrow.PrimitiveTypes.Uint64,
+}
+
+// TestDictionaryScalarBasics exercises construction, validation, decoding and
+// equality of dictionary scalars for every index type in dictIndexTypes, and
+// checks that scalars extracted from a dictionary array share its dictionary.
+// Setup errors abort the subtest instead of being discarded, so a failure is
+// reported where it happens rather than as a later nil dereference.
+func TestDictionaryScalarBasics(t *testing.T) {
+	for _, indexType := range dictIndexTypes {
+		t.Run(fmt.Sprint(indexType), func(t *testing.T) {
+			mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+			defer mem.AssertSize(t, 0)
+
+			ty := &arrow.DictionaryType{IndexType: indexType, ValueType: arrow.BinaryTypes.String}
+			dict, _, err := array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["alpha", null, "gamma"]`))
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer dict.Release()
+
+			// newDictScalar wraps dict entry idx in a dictionary scalar
+			// whose index has the index type under test.
+			newDictScalar := func(idx int) *scalar.Dictionary {
+				idxScalar, err := scalar.MakeScalarParam(idx, indexType)
+				if err != nil {
+					t.Fatal(err)
+				}
+				return scalar.NewDictScalar(idxScalar, dict)
+			}
+
+			alpha := newDictScalar(0)
+			defer alpha.Release()
+
+			gamma := newDictScalar(2)
+			defer gamma.Release()
+
+			nullVal := newDictScalar(1)
+			defer nullVal.Release()
+
+			// MakeNullScalar attaches no dictionary; attach one by hand and
+			// take the reference that the deferred Release will drop.
+			scalarNull := scalar.MakeNullScalar(ty).(*scalar.Dictionary)
+			scalarNull.Value.Dict = dict
+			dict.Retain()
+			defer scalarNull.Release()
+
+			assert.NoError(t, scalarNull.ValidateFull())
+			assert.NoError(t, alpha.ValidateFull())
+			assert.NoError(t, gamma.ValidateFull())
+
+			// index is valid, corresponding value is null
+			assert.NoError(t, nullVal.ValidateFull())
+
+			encodedNull, err := scalarNull.GetEncodedValue()
+			assert.NoError(t, err)
+			assert.NoError(t, encodedNull.ValidateFull())
+			assert.True(t, scalar.Equals(encodedNull, scalar.MakeNullScalar(arrow.BinaryTypes.String)))
+
+			encodedNullVal, err := nullVal.GetEncodedValue()
+			assert.NoError(t, err)
+			assert.NoError(t, encodedNullVal.ValidateFull())
+			assert.True(t, scalar.Equals(encodedNullVal, scalar.MakeNullScalar(arrow.BinaryTypes.String)))
+
+			encodedAlpha, err := alpha.GetEncodedValue()
+			assert.NoError(t, err)
+			assert.NoError(t, encodedAlpha.ValidateFull())
+			assert.True(t, scalar.Equals(encodedAlpha, scalar.MakeScalar("alpha")))
+
+			encodedGamma, err := gamma.GetEncodedValue()
+			assert.NoError(t, err)
+			assert.NoError(t, encodedGamma.ValidateFull())
+			assert.True(t, scalar.Equals(encodedGamma, scalar.MakeScalar("gamma")))
+
+			idxArr, _, err := array.FromJSON(mem, indexType, strings.NewReader(`[2, 0, 1, null]`))
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer idxArr.Release()
+			arr := array.NewDictionaryArray(ty, idxArr, dict)
+			defer arr.Release()
+
+			// Extract every element; each scalar is released when the
+			// subtest returns, including on an early t.Fatal.
+			elems := make([]scalar.Scalar, arr.Len())
+			for i := range elems {
+				elems[i], err = scalar.GetScalar(arr, i)
+				if err != nil {
+					t.Fatal(err)
+				}
+				defer elems[i].(*scalar.Dictionary).Release()
+			}
+			first, second, third, last := elems[0], elems[1], elems[2], elems[3]
+
+			assert.NoError(t, first.ValidateFull())
+			assert.NoError(t, second.ValidateFull())
+			assert.NoError(t, third.ValidateFull())
+			assert.NoError(t, last.ValidateFull())
+
+			assert.True(t, first.IsValid())
+			assert.True(t, second.IsValid())
+			assert.True(t, third.IsValid()) // valid because of valid index despite null value
+			assert.False(t, last.IsValid())
+
+			assert.True(t, scalar.Equals(first, gamma))
+			assert.True(t, scalar.Equals(second, alpha))
+			assert.True(t, scalar.Equals(third, nullVal))
+			assert.True(t, scalar.Equals(last, scalarNull))
+
+			assert.Same(t, first.(*scalar.Dictionary).Value.Dict, arr.Dictionary())
+			assert.Same(t, second.(*scalar.Dictionary).Value.Dict, arr.Dictionary())
+		})
+	}
+}
+
+// TestDictionaryScalarValidateErrors checks that Validate and ValidateFull
+// reject dictionary scalars whose type, validity flag or index is
+// inconsistent with their contents. A failure to build the dictionary array
+// aborts the test instead of being discarded.
+func TestDictionaryScalarValidateErrors(t *testing.T) {
+	mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+	defer mem.AssertSize(t, 0)
+
+	var (
+		indexTy = arrow.PrimitiveTypes.Int16
+		valueTy = arrow.BinaryTypes.String
+		dictTy  = &arrow.DictionaryType{IndexType: indexTy, ValueType: valueTy}
+	)
+
+	dict, _, err := array.FromJSON(mem, valueTy, strings.NewReader(`["alpha", null, "gamma"]`))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer dict.Release()
+
+	alpha := scalar.NewDictScalar(scalar.MakeScalar(int16(0)), dict)
+	defer alpha.Release()
+
+	// Valid index, null underlying value
+	nullVal := scalar.NewDictScalar(scalar.MakeScalar(int16(1)), dict)
+	defer nullVal.Release()
+
+	// inconsistent index type
+	dictSc := scalar.NewDictScalar(alpha.Value.Index, dict)
+	// The receiver is bound here, so this releases this scalar even though
+	// dictSc is reassigned further down.
+	defer dictSc.Release()
+	dictSc.Type = &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: valueTy}
+	assert.Error(t, dictSc.Validate())
+
+	// inconsistent value type between dict and type
+	dictSc.Type = &arrow.DictionaryType{IndexType: indexTy, ValueType: arrow.BinaryTypes.Binary}
+	assert.Error(t, dictSc.Validate())
+
+	// inconsistent Valid/Value
+	dictSc.Type = dictTy
+	assert.NoError(t, dictSc.ValidateFull())
+	dictSc.Valid = false
+	assert.Error(t, dictSc.ValidateFull())
+
+	assert.NoError(t, nullVal.ValidateFull())
+	nullVal.Valid = false
+	assert.Error(t, nullVal.ValidateFull())
+
+	dictSc = scalar.NewNullDictScalar(dictTy)
+	dictSc.Valid = true
+	assert.Error(t, dictSc.ValidateFull())
+	dictSc.Valid = false
+	assert.NoError(t, dictSc.ValidateFull())
+
+	// index value out of bounds
+	for _, idx := range []int16{-1, 3} {
+		invalid := scalar.NewDictScalar(scalar.MakeScalar(idx), dict)
+		defer invalid.Release()
+
+		assert.NoError(t, invalid.Validate())
+		assert.Error(t, invalid.ValidateFull())
+	}
+}