This is an automated email from the ASF dual-hosted git repository.

kszucs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 1fcbc6d  ARROW-9478: [C++] Improve error message for unsupported casts
1fcbc6d is described below

commit 1fcbc6dc0cb3e39f880aa43e95b35c8bcf6e8d62
Author: Antoine Pitrou <[email protected]>
AuthorDate: Wed Jul 15 23:30:26 2020 +0200

    ARROW-9478: [C++] Improve error message for unsupported casts
    
    Mention both input type and target type, as far as possible.
    
    Closes #7773 from pitrou/ARROW-9478-better-cast-error-message
    
    Lead-authored-by: Antoine Pitrou <[email protected]>
    Co-authored-by: Krisztián Szűcs <[email protected]>
    Signed-off-by: Krisztián Szűcs <[email protected]>
---
 cpp/src/arrow/compute/cast.cc                     |  45 ++++++--
 cpp/src/arrow/compute/kernels/scalar_cast_test.cc |  39 ++++++-
 cpp/src/arrow/testing/gtest_util.cc               |  39 +++++++
 cpp/src/arrow/testing/gtest_util.h                |   3 +
 cpp/src/arrow/type.cc                             | 132 ++++++++++------------
 cpp/src/arrow/type.h                              | 132 +---------------------
 cpp/src/arrow/type_fwd.h                          | 127 +++++++++++++++++++++
 cpp/src/arrow/type_test.cc                        |  39 +++++++
 cpp/src/arrow/type_traits.h                       |  52 +++++++++
 cpp/src/arrow/visitor_inline.h                    |  19 ++++
 r/src/arrow_exports.h                             |   6 -
 11 files changed, 410 insertions(+), 223 deletions(-)

diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc
index 211e5a2..a9700f3 100644
--- a/cpp/src/arrow/compute/cast.cc
+++ b/cpp/src/arrow/compute/cast.cc
@@ -32,6 +32,9 @@
 #include "arrow/util/logging.h"
 
 namespace arrow {
+
+using internal::ToTypeName;
+
 namespace compute {
 namespace internal {
 
@@ -54,6 +57,29 @@ void InitCastTable() {
 
 void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); }
 
+namespace {
+
+// Private version of GetCastFunction with better error reporting
+// if the input type is known.
+Result<std::shared_ptr<CastFunction>> GetCastFunctionInternal(
+    const std::shared_ptr<DataType>& to_type, const DataType* from_type = nullptr) {
+  internal::EnsureInitCastTable();
+  auto it = internal::g_cast_table.find(static_cast<int>(to_type->id()));
+  if (it == internal::g_cast_table.end()) {
+    if (from_type != nullptr) {
+      return Status::NotImplemented("Unsupported cast from ", *from_type, " to ",
+                                    *to_type,
+                                    " (no available cast function for target type)");
+    } else {
+      return Status::NotImplemented("Unsupported cast to ", *to_type,
+                                    " (no available cast function for target type)");
+    }
+  }
+  return it->second;
+}
+
+}  // namespace
+
 // Metafunction for dispatching to appropraite CastFunction. This corresponds
 // to the standard SQL CAST(expr AS target_type)
 class CastMetaFunction : public MetaFunction {
@@ -79,8 +105,9 @@ class CastMetaFunction : public MetaFunction {
     if (args[0].type()->Equals(*cast_options->to_type)) {
       return args[0];
     }
-    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<CastFunction> cast_func,
-                          GetCastFunction(cast_options->to_type));
+    ARROW_ASSIGN_OR_RAISE(
+        std::shared_ptr<CastFunction> cast_func,
+        GetCastFunctionInternal(cast_options->to_type, args[0].type().get()));
     return cast_func->Execute(args, options, ctx);
   }
 };
@@ -147,9 +174,9 @@ Result<const ScalarKernel*> CastFunction::DispatchExact(
   }
 
   if (candidate_kernels.size() == 0) {
-    return Status::NotImplemented("Function ", this->name(),
-                                  " has no kernel matching input type ",
-                                  values[0].ToString());
+    return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(),
+                                  " to ", ToTypeName(impl_->out_type), " using function ",
+                                  this->name());
   } else if (candidate_kernels.size() == 1) {
     // One match, return it
     return candidate_kernels[0];
@@ -188,13 +215,7 @@ Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType
 
 Result<std::shared_ptr<CastFunction>> GetCastFunction(
     const std::shared_ptr<DataType>& to_type) {
-  internal::EnsureInitCastTable();
-  auto it = internal::g_cast_table.find(static_cast<int>(to_type->id()));
-  if (it == internal::g_cast_table.end()) {
-    return Status::NotImplemented("No cast function available to cast to ",
-                                  to_type->ToString());
-  }
-  return it->second;
+  return internal::GetCastFunctionInternal(to_type);
 }
 
 bool CanCast(const DataType& from_type, const DataType& to_type) {
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
index 083e12e..bea9a0e 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
@@ -22,6 +22,7 @@
 #include <string>
 #include <vector>
 
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
 #include "arrow/array.h"
@@ -1457,14 +1458,40 @@ TEST_F(TestCast, ChunkedArray) {
   ASSERT_TRUE(out.chunked_array()->Equals(*ex_carr));
 }
 
-TEST_F(TestCast, UnsupportedTarget) {
-  std::vector<bool> is_valid = {true, false, true, true, true};
-  std::vector<int32_t> v1 = {0, 1, 2, 3, 4};
+TEST_F(TestCast, UnsupportedInputType) {
+  // Casting to a supported target type, but with an unsupported input type
+  // for the target type.
+  const auto arr = ArrayFromJSON(int32(), "[1, 2, 3]");
 
-  std::shared_ptr<Array> arr;
-  ArrayFromVector<Int32Type>(int32(), is_valid, v1, &arr);
+  const auto to_type = list(utf8());
+  const char* expected_message = "Unsupported cast from int32 to list";
+
+  // Try through concrete API
+  EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
+                                  Cast(*arr, to_type));
+
+  // Try through general kernel API
+  CastOptions options;
+  options.to_type = to_type;
+  EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
+                                  CallFunction("cast", {arr}, &options));
+}
+
+TEST_F(TestCast, UnsupportedTargetType) {
+  // Casting to an unsupported target type
+  const auto arr = ArrayFromJSON(int32(), "[1, 2, 3]");
+  const auto to_type = dense_union({field("a", int32())});
 
-  ASSERT_RAISES(NotImplemented, Cast(*arr, list(utf8())));
+  // Try through concrete API
+  const char* expected_message = "Unsupported cast from int32 to dense_union";
+  EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
+                                  Cast(*arr, to_type));
+
+  // Try through general kernel API
+  CastOptions options;
+  options.to_type = to_type;
+  EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
+                                  CallFunction("cast", {arr}, &options));
 }
 
 TEST_F(TestCast, DateTimeZeroCopy) {
diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index 99c8e31..de5b87a 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -54,6 +54,45 @@ namespace arrow {
 using internal::checked_cast;
 using internal::checked_pointer_cast;
 
+std::vector<Type::type> AllTypeIds() {
+  return {Type::NA,
+          Type::BOOL,
+          Type::INT8,
+          Type::INT16,
+          Type::INT32,
+          Type::INT64,
+          Type::UINT8,
+          Type::UINT16,
+          Type::UINT32,
+          Type::UINT64,
+          Type::HALF_FLOAT,
+          Type::FLOAT,
+          Type::DOUBLE,
+          Type::DECIMAL,
+          Type::DATE32,
+          Type::DATE64,
+          Type::TIME32,
+          Type::TIME64,
+          Type::TIMESTAMP,
+          Type::INTERVAL_DAY_TIME,
+          Type::INTERVAL_MONTHS,
+          Type::DURATION,
+          Type::STRING,
+          Type::BINARY,
+          Type::LARGE_STRING,
+          Type::LARGE_BINARY,
+          Type::FIXED_SIZE_BINARY,
+          Type::STRUCT,
+          Type::LIST,
+          Type::LARGE_LIST,
+          Type::FIXED_SIZE_LIST,
+          Type::MAP,
+          Type::DENSE_UNION,
+          Type::SPARSE_UNION,
+          Type::DICTIONARY,
+          Type::EXTENSION};
+}
+
 template <typename T, typename CompareFunctor>
 void AssertTsSame(const T& expected, const T& actual, CompareFunctor&& compare) {
   if (!compare(actual, expected)) {
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index 9291d9e..1411e70 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -153,6 +153,9 @@ class RecordBatch;
 class Table;
 struct Datum;
 
+ARROW_TESTING_EXPORT
+std::vector<Type::type> AllTypeIds();
+
 #define ASSERT_ARRAYS_EQUAL(lhs, rhs) AssertArraysEqual((lhs), (rhs))
 #define ASSERT_BATCHES_EQUAL(lhs, rhs) AssertBatchesEqual((lhs), (rhs))
 #define ASSERT_BATCHES_APPROX_EQUAL(lhs, rhs) AssertBatchesApproxEqual((lhs), (rhs))
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 52f50c3..d534d58 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -93,79 +93,71 @@ constexpr Type::type DictionaryType::type_id;
 
 namespace internal {
 
+struct TypeIdToTypeNameVisitor {
+  std::string out;
+
+  template <typename ArrowType>
+  Status Visit(const ArrowType*) {
+    out = ArrowType::type_name();
+    return Status::OK();
+  }
+};
+
+std::string ToTypeName(Type::type id) {
+  TypeIdToTypeNameVisitor visitor;
+
+  ARROW_CHECK_OK(VisitTypeIdInline(id, &visitor));
+  return std::move(visitor.out);
+}
+
 std::string ToString(Type::type id) {
   switch (id) {
-    case Type::NA:
-      return "NA";
-    case Type::BOOL:
-      return "BOOL";
-    case Type::UINT8:
-      return "UINT8";
-    case Type::INT8:
-      return "INT8";
-    case Type::UINT16:
-      return "UINT16";
-    case Type::INT16:
-      return "INT16";
-    case Type::UINT32:
-      return "UINT32";
-    case Type::INT32:
-      return "INT32";
-    case Type::UINT64:
-      return "UINT64";
-    case Type::INT64:
-      return "INT64";
-    case Type::HALF_FLOAT:
-      return "HALF_FLOAT";
-    case Type::FLOAT:
-      return "FLOAT";
-    case Type::DOUBLE:
-      return "DOUBLE";
-    case Type::STRING:
-      return "UTF8";
-    case Type::BINARY:
-      return "BINARY";
-    case Type::FIXED_SIZE_BINARY:
-      return "FIXED_SIZE_BINARY";
-    case Type::DATE64:
-      return "DATE64";
-    case Type::TIMESTAMP:
-      return "TIMESTAMP";
-    case Type::TIME32:
-      return "TIME32";
-    case Type::TIME64:
-      return "TIME64";
-    case Type::INTERVAL_MONTHS:
-      return "INTERVAL_MONTHS";
-    case Type::INTERVAL_DAY_TIME:
-      return "INTERVAL_DAY_TIME";
-    case Type::DECIMAL:
-      return "DECIMAL";
-    case Type::LIST:
-      return "LIST";
-    case Type::STRUCT:
-      return "STRUCT";
-    case Type::SPARSE_UNION:
-      return "SPARSE_UNION";
-    case Type::DENSE_UNION:
-      return "DENSE_UNION";
-    case Type::DICTIONARY:
-      return "DICTIONARY";
-    case Type::MAP:
-      return "MAP";
-    case Type::EXTENSION:
-      return "EXTENSION";
-    case Type::FIXED_SIZE_LIST:
-      return "FIXED_SIZE_LIST";
-    case Type::DURATION:
-      return "DURATION";
-    case Type::LARGE_BINARY:
-      return "LARGE_BINARY";
-    case Type::LARGE_LIST:
-      return "LARGE_LIST";
+#define TO_STRING_CASE(_id) \
+  case Type::_id:           \
+    return ARROW_STRINGIFY(_id);
+
+    TO_STRING_CASE(NA)
+    TO_STRING_CASE(BOOL)
+    TO_STRING_CASE(INT8)
+    TO_STRING_CASE(INT16)
+    TO_STRING_CASE(INT32)
+    TO_STRING_CASE(INT64)
+    TO_STRING_CASE(UINT8)
+    TO_STRING_CASE(UINT16)
+    TO_STRING_CASE(UINT32)
+    TO_STRING_CASE(UINT64)
+    TO_STRING_CASE(HALF_FLOAT)
+    TO_STRING_CASE(FLOAT)
+    TO_STRING_CASE(DOUBLE)
+    TO_STRING_CASE(DECIMAL)
+    TO_STRING_CASE(DATE32)
+    TO_STRING_CASE(DATE64)
+    TO_STRING_CASE(TIME32)
+    TO_STRING_CASE(TIME64)
+    TO_STRING_CASE(TIMESTAMP)
+    TO_STRING_CASE(INTERVAL_DAY_TIME)
+    TO_STRING_CASE(INTERVAL_MONTHS)
+    TO_STRING_CASE(DURATION)
+    TO_STRING_CASE(STRING)
+    TO_STRING_CASE(BINARY)
+    TO_STRING_CASE(LARGE_STRING)
+    TO_STRING_CASE(LARGE_BINARY)
+    TO_STRING_CASE(FIXED_SIZE_BINARY)
+    TO_STRING_CASE(STRUCT)
+    TO_STRING_CASE(LIST)
+    TO_STRING_CASE(LARGE_LIST)
+    TO_STRING_CASE(FIXED_SIZE_LIST)
+    TO_STRING_CASE(MAP)
+    TO_STRING_CASE(DENSE_UNION)
+    TO_STRING_CASE(SPARSE_UNION)
+    TO_STRING_CASE(DICTIONARY)
+    TO_STRING_CASE(EXTENSION)
+
+#undef TO_STRING_CASE
+
     default:
-      DCHECK(false) << "Should not be able to reach here";
-      return "unknown";
+      ARROW_LOG(FATAL) << "Unhandled type id: " << id;
+      return "";
   }
 }
 
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 1eb06cd..86d8a79 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -36,135 +36,6 @@
 #include "arrow/visitor.h"  // IWYU pragma: keep
 
 namespace arrow {
-
-class Array;
-class Field;
-class MemoryPool;
-
-struct Type {
-  /// \brief Main data type enumeration
-  ///
-  /// This enumeration provides a quick way to interrogate the category
-  /// of a DataType instance.
-  enum type {
-    /// A NULL type having no physical storage
-    NA,
-
-    /// Boolean as 1 bit, LSB bit-packed ordering
-    BOOL,
-
-    /// Unsigned 8-bit little-endian integer
-    UINT8,
-
-    /// Signed 8-bit little-endian integer
-    INT8,
-
-    /// Unsigned 16-bit little-endian integer
-    UINT16,
-
-    /// Signed 16-bit little-endian integer
-    INT16,
-
-    /// Unsigned 32-bit little-endian integer
-    UINT32,
-
-    /// Signed 32-bit little-endian integer
-    INT32,
-
-    /// Unsigned 64-bit little-endian integer
-    UINT64,
-
-    /// Signed 64-bit little-endian integer
-    INT64,
-
-    /// 2-byte floating point value
-    HALF_FLOAT,
-
-    /// 4-byte floating point value
-    FLOAT,
-
-    /// 8-byte floating point value
-    DOUBLE,
-
-    /// UTF8 variable-length string as List<Char>
-    STRING,
-
-    /// Variable-length bytes (no guarantee of UTF8-ness)
-    BINARY,
-
-    /// Fixed-size binary. Each value occupies the same number of bytes
-    FIXED_SIZE_BINARY,
-
-    /// int32_t days since the UNIX epoch
-    DATE32,
-
-    /// int64_t milliseconds since the UNIX epoch
-    DATE64,
-
-    /// Exact timestamp encoded with int64 since UNIX epoch
-    /// Default unit millisecond
-    TIMESTAMP,
-
-    /// Time as signed 32-bit integer, representing either seconds or
-    /// milliseconds since midnight
-    TIME32,
-
-    /// Time as signed 64-bit integer, representing either microseconds or
-    /// nanoseconds since midnight
-    TIME64,
-
-    /// YEAR_MONTH interval in SQL style
-    INTERVAL_MONTHS,
-
-    /// DAY_TIME interval in SQL style
-    INTERVAL_DAY_TIME,
-
-    /// Precision- and scale-based decimal type. Storage type depends on the
-    /// parameters.
-    DECIMAL,
-
-    /// A list of some logical data type
-    LIST,
-
-    /// Struct of logical types
-    STRUCT,
-
-    /// Sparse unions of logical types
-    SPARSE_UNION,
-
-    /// Dense unions of logical types
-    DENSE_UNION,
-
-    /// Dictionary-encoded type, also called "categorical" or "factor"
-    /// in other programming languages. Holds the dictionary value
-    /// type but not the dictionary itself, which is part of the
-    /// ArrayData struct
-    DICTIONARY,
-
-    /// Map, a repeated struct logical type
-    MAP,
-
-    /// Custom data type, implemented by user
-    EXTENSION,
-
-    /// Fixed size list of some logical type
-    FIXED_SIZE_LIST,
-
-    /// Measure of elapsed time in either seconds, milliseconds, microseconds
-    /// or nanoseconds.
-    DURATION,
-
-    /// Like STRING, but with 64-bit offsets
-    LARGE_STRING,
-
-    /// Like BINARY, but with 64-bit offsets
-    LARGE_BINARY,
-
-    /// Like LIST, but with 64-bit offsets
-    LARGE_LIST
-  };
-};
-
 namespace detail {
 
 class ARROW_EXPORT Fingerprintable {
@@ -1937,6 +1808,9 @@ ARROW_EXPORT
 std::string ToString(Type::type id);
 
 ARROW_EXPORT
+std::string ToTypeName(Type::type id);
+
+ARROW_EXPORT
 std::string ToString(TimeUnit::type unit);
 
 ARROW_EXPORT
diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h
index 8c2c1c7..fc25b27 100644
--- a/cpp/src/arrow/type_fwd.h
+++ b/cpp/src/arrow/type_fwd.h
@@ -248,6 +248,133 @@ struct ExtensionScalar;
 
 // ----------------------------------------------------------------------
 
+struct Type {
+  /// \brief Main data type enumeration
+  ///
+  /// This enumeration provides a quick way to interrogate the category
+  /// of a DataType instance.
+  enum type {
+    /// A NULL type having no physical storage
+    NA = 0,
+
+    /// Boolean as 1 bit, LSB bit-packed ordering
+    BOOL,
+
+    /// Unsigned 8-bit little-endian integer
+    UINT8,
+
+    /// Signed 8-bit little-endian integer
+    INT8,
+
+    /// Unsigned 16-bit little-endian integer
+    UINT16,
+
+    /// Signed 16-bit little-endian integer
+    INT16,
+
+    /// Unsigned 32-bit little-endian integer
+    UINT32,
+
+    /// Signed 32-bit little-endian integer
+    INT32,
+
+    /// Unsigned 64-bit little-endian integer
+    UINT64,
+
+    /// Signed 64-bit little-endian integer
+    INT64,
+
+    /// 2-byte floating point value
+    HALF_FLOAT,
+
+    /// 4-byte floating point value
+    FLOAT,
+
+    /// 8-byte floating point value
+    DOUBLE,
+
+    /// UTF8 variable-length string as List<Char>
+    STRING,
+
+    /// Variable-length bytes (no guarantee of UTF8-ness)
+    BINARY,
+
+    /// Fixed-size binary. Each value occupies the same number of bytes
+    FIXED_SIZE_BINARY,
+
+    /// int32_t days since the UNIX epoch
+    DATE32,
+
+    /// int64_t milliseconds since the UNIX epoch
+    DATE64,
+
+    /// Exact timestamp encoded with int64 since UNIX epoch
+    /// Default unit millisecond
+    TIMESTAMP,
+
+    /// Time as signed 32-bit integer, representing either seconds or
+    /// milliseconds since midnight
+    TIME32,
+
+    /// Time as signed 64-bit integer, representing either microseconds or
+    /// nanoseconds since midnight
+    TIME64,
+
+    /// YEAR_MONTH interval in SQL style
+    INTERVAL_MONTHS,
+
+    /// DAY_TIME interval in SQL style
+    INTERVAL_DAY_TIME,
+
+    /// Precision- and scale-based decimal type. Storage type depends on the
+    /// parameters.
+    DECIMAL,
+
+    /// A list of some logical data type
+    LIST,
+
+    /// Struct of logical types
+    STRUCT,
+
+    /// Sparse unions of logical types
+    SPARSE_UNION,
+
+    /// Dense unions of logical types
+    DENSE_UNION,
+
+    /// Dictionary-encoded type, also called "categorical" or "factor"
+    /// in other programming languages. Holds the dictionary value
+    /// type but not the dictionary itself, which is part of the
+    /// ArrayData struct
+    DICTIONARY,
+
+    /// Map, a repeated struct logical type
+    MAP,
+
+    /// Custom data type, implemented by user
+    EXTENSION,
+
+    /// Fixed size list of some logical type
+    FIXED_SIZE_LIST,
+
+    /// Measure of elapsed time in either seconds, milliseconds, microseconds
+    /// or nanoseconds.
+    DURATION,
+
+    /// Like STRING, but with 64-bit offsets
+    LARGE_STRING,
+
+    /// Like BINARY, but with 64-bit offsets
+    LARGE_BINARY,
+
+    /// Like LIST, but with 64-bit offsets
+    LARGE_LIST,
+
+    // Leave this at the end
+    MAX_ID
+  };
+};
+
 /// \defgroup type-factories Factory functions for creating data types
 ///
 /// Factory functions for creating data types
diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc
index ebe42d3..e53d259 100644
--- a/cpp/src/arrow/type_test.cc
+++ b/cpp/src/arrow/type_test.cc
@@ -17,9 +17,12 @@
 
 // Unit tests for DataType (and subclasses), Field, and Schema
 
+#include <algorithm>
+#include <cctype>
 #include <cstdint>
 #include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 
 #include <gmock/gmock.h>
@@ -39,6 +42,42 @@ using testing::ElementsAre;
 using internal::checked_cast;
 using internal::checked_pointer_cast;
 
+TEST(TestTypeId, AllTypeIds) {
+  const auto all_ids = AllTypeIds();
+  ASSERT_EQ(static_cast<int>(all_ids.size()), Type::MAX_ID);
+}
+
+template <typename ReprFunc>
+void CheckTypeIdReprs(ReprFunc&& repr_func, bool expect_uppercase) {
+  std::unordered_set<std::string> unique_reprs;
+  const auto all_ids = AllTypeIds();
+  for (const auto id : all_ids) {
+    std::string repr = repr_func(id);
+    ASSERT_TRUE(std::all_of(repr.begin(), repr.end(),
+                            [=](const char c) {
+                              return c == '_' || std::isdigit(c) ||
+                                     (expect_uppercase ? std::isupper(c)
+                                                       : std::islower(c));
+                            }))
+        << "Invalid type id repr: '" << repr << "'";
+    unique_reprs.insert(std::move(repr));
+  }
+  // No duplicates
+  ASSERT_EQ(unique_reprs.size(), all_ids.size());
+}
+
+TEST(TestTypeId, ToString) {
+  // Should be all uppercase strings (corresponding to the enum member names)
+  CheckTypeIdReprs([](Type::type id) { return internal::ToString(id); },
+                   /* expect_uppercase=*/true);
+}
+
+TEST(TestTypeId, ToTypeName) {
+  // Should be all lowercase strings (corresponding to TypeClass::type_name())
+  CheckTypeIdReprs([](Type::type id) { return internal::ToTypeName(id); },
+                   /* expect_uppercase=*/false);
+}
+
 TEST(TestField, Basics) {
   Field f0("f0", int32());
   Field f0_nn("f0", int32(), false);
diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h
index f6fd4dd..7cf9503 100644
--- a/cpp/src/arrow/type_traits.h
+++ b/cpp/src/arrow/type_traits.h
@@ -28,6 +28,58 @@
 namespace arrow {
 
 //
+// Per-type id type lookup
+//
+
+template <Type::type id>
+struct TypeIdTraits {};
+
+#define TYPE_ID_TRAIT(_id, _typeclass) \
+  template <>                          \
+  struct TypeIdTraits<Type::_id> {     \
+    using Type = _typeclass;           \
+  };
+
+TYPE_ID_TRAIT(NA, NullType)
+TYPE_ID_TRAIT(BOOL, BooleanType)
+TYPE_ID_TRAIT(INT8, Int8Type)
+TYPE_ID_TRAIT(INT16, Int16Type)
+TYPE_ID_TRAIT(INT32, Int32Type)
+TYPE_ID_TRAIT(INT64, Int64Type)
+TYPE_ID_TRAIT(UINT8, UInt8Type)
+TYPE_ID_TRAIT(UINT16, UInt16Type)
+TYPE_ID_TRAIT(UINT32, UInt32Type)
+TYPE_ID_TRAIT(UINT64, UInt64Type)
+TYPE_ID_TRAIT(HALF_FLOAT, HalfFloatType)
+TYPE_ID_TRAIT(FLOAT, FloatType)
+TYPE_ID_TRAIT(DOUBLE, DoubleType)
+TYPE_ID_TRAIT(STRING, StringType)
+TYPE_ID_TRAIT(BINARY, BinaryType)
+TYPE_ID_TRAIT(LARGE_STRING, LargeStringType)
+TYPE_ID_TRAIT(LARGE_BINARY, LargeBinaryType)
+TYPE_ID_TRAIT(FIXED_SIZE_BINARY, FixedSizeBinaryType)
+TYPE_ID_TRAIT(DATE32, Date32Type)
+TYPE_ID_TRAIT(DATE64, Date64Type)
+TYPE_ID_TRAIT(TIME32, Time32Type)
+TYPE_ID_TRAIT(TIME64, Time64Type)
+TYPE_ID_TRAIT(TIMESTAMP, TimestampType)
+TYPE_ID_TRAIT(INTERVAL_DAY_TIME, DayTimeIntervalType)
+TYPE_ID_TRAIT(INTERVAL_MONTHS, MonthIntervalType)
+TYPE_ID_TRAIT(DURATION, DurationType)
+TYPE_ID_TRAIT(DECIMAL, Decimal128Type)  // XXX or DecimalType?
+TYPE_ID_TRAIT(STRUCT, StructType)
+TYPE_ID_TRAIT(LIST, ListType)
+TYPE_ID_TRAIT(LARGE_LIST, LargeListType)
+TYPE_ID_TRAIT(FIXED_SIZE_LIST, FixedSizeListType)
+TYPE_ID_TRAIT(MAP, MapType)
+TYPE_ID_TRAIT(DENSE_UNION, DenseUnionType)
+TYPE_ID_TRAIT(SPARSE_UNION, SparseUnionType)
+TYPE_ID_TRAIT(DICTIONARY, DictionaryType)
+TYPE_ID_TRAIT(EXTENSION, ExtensionType)
+
+#undef TYPE_ID_TRAIT
+
+//
 // Per-type type traits
 //
 
diff --git a/cpp/src/arrow/visitor_inline.h b/cpp/src/arrow/visitor_inline.h
index 719b801..233b105 100644
--- a/cpp/src/arrow/visitor_inline.h
+++ b/cpp/src/arrow/visitor_inline.h
@@ -94,6 +94,25 @@ inline Status VisitTypeInline(const DataType& type, VISITOR* visitor) {
 
 #undef TYPE_VISIT_INLINE
 
+#define TYPE_ID_VISIT_INLINE(TYPE_CLASS)            \
+  case TYPE_CLASS##Type::type_id: {                 \
+    const TYPE_CLASS##Type* concrete_ptr = nullptr; \
+    return visitor->Visit(concrete_ptr);            \
+  }
+
+// Calls `visitor` with a nullptr of the corresponding concrete type class
+template <typename VISITOR>
+inline Status VisitTypeIdInline(Type::type id, VISITOR* visitor) {
+  switch (id) {
+    ARROW_GENERATE_FOR_ALL_TYPES(TYPE_ID_VISIT_INLINE);
+    default:
+      break;
+  }
+  return Status::NotImplemented("Type not implemented");
+}
+
+#undef TYPE_ID_VISIT_INLINE
+
 #define ARRAY_VISIT_INLINE(TYPE_CLASS)                                                   \
   case TYPE_CLASS##Type::type_id:                                                        \
     return visitor->Visit(                                                               \
diff --git a/r/src/arrow_exports.h b/r/src/arrow_exports.h
index 5811840..6b85995 100644
--- a/r/src/arrow_exports.h
+++ b/r/src/arrow_exports.h
@@ -32,12 +32,6 @@
 
 namespace arrow {
 
-struct Type {
-  enum type {
-    // forward declaration
-  };
-};
-
 namespace compute {
 class CastOptions;
 

Reply via email to