This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e8c998  [FEAT] Optimize Expected<T> for minimal compiled code and efficiency (#599)
1e8c998 is described below

commit 1e8c998269e8667b35c576bbaeb37fadb94593b1
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu May 28 13:23:27 2026 -0400

    [FEAT] Optimize Expected<T> for minimal compiled code and efficiency (#599)
    
    ## Summary
    
    Optimize `Expected<T>` so the happy paths compile down to a handful of
    instructions. The previous implementation routed `is_ok` / `is_err`
    through `as<Error>()`, which bumped the refcount of the held Error just
    to test its type, and `value()` went through `cast<T>()`, which performed
    a redundant runtime type check on every read.
    
    ## Changes
    
    - `is_ok()` / `is_err()` compare the type-index discriminator against
    `kTVMFFIError` directly, with no intermediate cast.
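    - `value()` and `value_or()` use `TVM_FFI_PREDICT_TRUE` to hint the happy
    path; `error()` intentionally carries no branch hint, since it is itself a
    cold path. All of them read the held value through the `*AfterCheck`
    type-traits variants, because the `Expected` invariant already guarantees
    the held type.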
    - `value_or()` gains an rvalue overload that moves from the held value
    instead of copying when the `Expected` is consumed as a temporary.
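    - Add a `TVM_FFI_UNSAFE_ASSUME` macro (alongside `TVM_FFI_PREDICT_*`)
    expanding to `__builtin_assume` (Clang), `__attribute__((assume(...)))`
    (GCC 13+), or `__assume` (MSVC), and to a no-op on other compilers. It is
    applied inside the non-Object `CopyFromAnyViewAfterCheck` implementations
    to tell the compiler about the type-index invariant.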
    - Add `MoveFromAnyAfterCheck` to `TypeTraits<DLDataType>`; its absence
    prevented `Expected<DLDataType>::value() &&` from compiling. Since
    `DLDataType` is a POD type, it simply delegates to `CopyFromAnyViewAfterCheck`.
    - Add `MoveFromAnyAfterCheck` to `TypeTraits<TypedFunction<FType>>`;
    delegate to `TypeTraits<Function>::MoveFromAnyAfterCheck`.
    
    ## Codegen (clang++-17 -O2, lines of asm, lower is better)
    
    | Probe | Before | After |
    |---|---|---|
    | is_ok / is_err (int) | 173 | 5 |
    | value() lvalue (int) | 696 | 37 |
    | value() rvalue (int) | 725 | 34 |
    | value() rvalue (String) | 810 | 96 |
    | value() rvalue (Array<int>) | 1763 | 107 |
    | make_ok().value() | 738 | 50 |
    | std::move(make_ok()).value() | 738 | 50 |
    
    `make_ok().value()` compiles to code bit-identical to
    `std::move(make_ok()).value()`: calling a member function on the prvalue
    triggers temporary materialization, producing an xvalue that binds the
    `&&`-qualified overload directly. No `std::move` is needed at the call site.
---
 include/tvm/ffi/base_details.h | 24 ++++++++++++
 include/tvm/ffi/dtype.h        |  6 +++
 include/tvm/ffi/expected.h     | 88 +++++++++++++++++++++++++-----------------
 include/tvm/ffi/function.h     |  4 ++
 include/tvm/ffi/type_traits.h  | 12 +++++-
 tests/cpp/test_expected.cc     | 35 +++++++++++++++++
 6 files changed, 133 insertions(+), 36 deletions(-)

diff --git a/include/tvm/ffi/base_details.h b/include/tvm/ffi/base_details.h
index c2ac2b8..27431b2 100644
--- a/include/tvm/ffi/base_details.h
+++ b/include/tvm/ffi/base_details.h
@@ -114,6 +114,30 @@
 #define TVM_FFI_PREDICT_TRUE(cond) (cond)
 #endif
 
+/*!
+ * \brief Translates into __builtin_assume / __assume / 
__attribute__((assume)).
+ *
+ * Use ONLY when the external invariant guarantees cond. The compiler
+ * will remove all paths inconsistent with cond. This is not an
+ * assertion or check -- using on a wrong cond will result in
+ * undefined behavior. cond must be side-effect-free.
+ */
+#if defined(__clang__)
+#define TVM_FFI_UNSAFE_ASSUME(cond) __builtin_assume(cond)
+#elif defined(__GNUC__)
+// GCC 13+ supports __attribute__((assume(...))); fall back to the void-cast
+// no-op for older GCC where __builtin_assume is absent.
+#if __GNUC__ >= 13
+#define TVM_FFI_UNSAFE_ASSUME(cond) __attribute__((assume(cond)))
+#else
+#define TVM_FFI_UNSAFE_ASSUME(cond) static_cast<void>(0)
+#endif
+#elif defined(_MSC_VER)
+#define TVM_FFI_UNSAFE_ASSUME(cond) __assume(cond)
+#else
+#define TVM_FFI_UNSAFE_ASSUME(cond) static_cast<void>(0)
+#endif
+
 #define TVM_FFI_STR_CONCAT_(__x, __y) __x##__y
 #define TVM_FFI_STR_CONCAT(__x, __y) TVM_FFI_STR_CONCAT_(__x, __y)
 
diff --git a/include/tvm/ffi/dtype.h b/include/tvm/ffi/dtype.h
index e1a44fb..6aa2aa6 100644
--- a/include/tvm/ffi/dtype.h
+++ b/include/tvm/ffi/dtype.h
@@ -164,9 +164,15 @@ struct TypeTraits<DLDataType> : public TypeTraitsBase {
   }
 
   TVM_FFI_INLINE static DLDataType CopyFromAnyViewAfterCheck(const TVMFFIAny* 
src) {
+    TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIDataType);
     return src->v_dtype;
   }
 
+  TVM_FFI_INLINE static DLDataType MoveFromAnyAfterCheck(TVMFFIAny* src) {
+    // POD type — move is just copy.
+    return CopyFromAnyViewAfterCheck(src);
+  }
+
   TVM_FFI_INLINE static std::optional<DLDataType> TryCastFromAnyView(const 
TVMFFIAny* src) {
     if (src->type_index == TypeIndex::kTVMFFIDataType) {
       return src->v_dtype;
diff --git a/include/tvm/ffi/expected.h b/include/tvm/ffi/expected.h
index de35c67..6e822df 100644
--- a/include/tvm/ffi/expected.h
+++ b/include/tvm/ffi/expected.h
@@ -113,62 +113,80 @@ class Expected {
   // NOLINTNEXTLINE(google-explicit-constructor,runtime/explicit)
   Expected(Unexpected<E> unexpected) : 
data_(Any(std::move(unexpected).error())) {}
 
-  /*!
-   * \brief Check if the Expected contains a success value.
-   * \return True if contains success value, false if contains error.
-   * \note Checks for Error first to handle cases where T is a base class of 
Error.
-   */
-  TVM_FFI_INLINE bool is_ok() const { return !data_.as<Error>().has_value(); }
+  /*! \brief Returns true if the Expected contains a success value. */
+  TVM_FFI_INLINE bool is_ok() const noexcept {
+    return data_.type_index() != TypeIndex::kTVMFFIError;
+  }
 
-  /*!
-   * \brief Check if the Expected contains an error.
-   * \return True if contains error, false if contains success value.
-   */
-  TVM_FFI_INLINE bool is_err() const { return !is_ok(); }
+  /*! \brief Returns true if the Expected contains an error. */
+  TVM_FFI_INLINE bool is_err() const noexcept {
+    return data_.type_index() == TypeIndex::kTVMFFIError;
+  }
 
-  /*!
-   * \brief Alias for is_ok().
-   * \return True if contains success value.
-   */
-  TVM_FFI_INLINE bool has_value() const { return is_ok(); }
+  /*! \brief Alias for is_ok(). */
+  TVM_FFI_INLINE bool has_value() const noexcept { return is_ok(); }
 
-  /*! \brief Access the success value. Throws the contained error if is_err(). 
*/
+  /*! \brief Returns the success value, or throws the contained error. */
   TVM_FFI_INLINE T value() const& {
-    if (is_err()) throw data_.cast<Error>();
-    return data_.cast<T>();
+    if (TVM_FFI_PREDICT_TRUE(is_ok())) {
+      return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
+    }
+    throw details::AnyUnsafe::CopyFromAnyViewAfterCheck<Error>(data_);
   }
-  /*! \brief Access the success value (rvalue). Throws the contained error if 
is_err(). */
+
+  /*! \brief Returns the success value (moved out), or throws the contained 
error. */
   TVM_FFI_INLINE T value() && {
-    if (is_err()) throw std::move(data_).template cast<Error>();
-    return std::move(data_).template cast<T>();
+    if (TVM_FFI_PREDICT_TRUE(is_ok())) {
+      return details::AnyUnsafe::MoveFromAnyAfterCheck<T>(std::move(data_));
+    }
+    throw details::AnyUnsafe::MoveFromAnyAfterCheck<Error>(std::move(data_));
   }
 
-  /*! \brief Access the error. Throws RuntimeError if is_ok(). */
+  /*! \brief Returns the contained error, or throws RuntimeError if is_ok(). */
   TVM_FFI_INLINE Error error() const& {
-    if (!is_err()) TVM_FFI_THROW(RuntimeError) << "Bad expected access: 
contains value, not error";
-    return data_.cast<Error>();
+    // No branch hint: error() is itself a cold path — callers only invoke it
+    // after observing !is_ok(), so the branch direction here doesn't matter.
+    if (is_ok()) {
+      TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains value, not 
error";
+    }
+    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<Error>(data_);
   }
-  /*! \brief Access the error (rvalue). Throws RuntimeError if is_ok(). */
+
+  /*! \brief Returns the contained error (moved out), or throws RuntimeError 
if is_ok(). */
   TVM_FFI_INLINE Error error() && {
-    if (!is_err()) TVM_FFI_THROW(RuntimeError) << "Bad expected access: 
contains value, not error";
-    return std::move(data_).template cast<Error>();
+    // No branch hint: error() is itself a cold path — callers only invoke it
+    // after observing !is_ok(), so the branch direction here doesn't matter.
+    if (is_ok()) {
+      TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains value, not 
error";
+    }
+    return details::AnyUnsafe::MoveFromAnyAfterCheck<Error>(std::move(data_));
   }
 
   /*!
-   * \brief Get the success value or a default value.
-   * \param default_value The value to return if Expected contains an error.
-   * \return The success value if present, otherwise the default value.
+   * \brief Returns the success value, or \p default_value if the Expected 
holds an error.
    */
   template <typename U = std::remove_cv_t<T>>
-  TVM_FFI_INLINE T value_or(U&& default_value) const {
-    if (is_ok()) {
-      return data_.cast<T>();
+  TVM_FFI_INLINE T value_or(U&& default_value) const& {
+    if (TVM_FFI_PREDICT_TRUE(is_ok())) {
+      return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
+    }
+    return T(std::forward<U>(default_value));
+  }
+
+  /*!
+   * \brief Returns the success value (moved out), or \p default_value if the 
Expected holds an
+   * error.
+   */
+  template <typename U = std::remove_cv_t<T>>
+  TVM_FFI_INLINE T value_or(U&& default_value) && {
+    if (TVM_FFI_PREDICT_TRUE(is_ok())) {
+      return details::AnyUnsafe::MoveFromAnyAfterCheck<T>(std::move(data_));
     }
     return T(std::forward<U>(default_value));
   }
 
  private:
-  Any data_;  // Holds either T or Error
+  Any data_;  // Invariant: holds a T (type_index != kTVMFFIError) or an Error.
 };
 
 // TypeTraits specialization for Expected<T>
diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h
index 1100603..6f2dec2 100644
--- a/include/tvm/ffi/function.h
+++ b/include/tvm/ffi/function.h
@@ -914,6 +914,10 @@ struct TypeTraits<TypedFunction<FType>> : public 
TypeTraitsBase {
     return 
TypedFunction<FType>(TypeTraits<Function>::CopyFromAnyViewAfterCheck(src));
   }
 
+  TVM_FFI_INLINE static TypedFunction<FType> MoveFromAnyAfterCheck(TVMFFIAny* 
src) {
+    return 
TypedFunction<FType>(TypeTraits<Function>::MoveFromAnyAfterCheck(src));
+  }
+
   TVM_FFI_INLINE static std::optional<TypedFunction<FType>> TryCastFromAnyView(
       const TVMFFIAny* src) {
     std::optional<Function> opt = 
TypeTraits<Function>::TryCastFromAnyView(src);
diff --git a/include/tvm/ffi/type_traits.h b/include/tvm/ffi/type_traits.h
index 21afde0..deb8247 100644
--- a/include/tvm/ffi/type_traits.h
+++ b/include/tvm/ffi/type_traits.h
@@ -215,6 +215,7 @@ struct TypeTraits<StrictBool> : public TypeTraitsBase {
   }
 
   TVM_FFI_INLINE static StrictBool CopyFromAnyViewAfterCheck(const TVMFFIAny* 
src) {
+    TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIBool);
     return static_cast<bool>(src->v_int64);
   }
 
@@ -254,6 +255,7 @@ struct TypeTraits<bool> : public TypeTraitsBase {
   }
 
   TVM_FFI_INLINE static bool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIBool);
     return static_cast<bool>(src->v_int64);
   }
 
@@ -301,6 +303,7 @@ struct TypeTraits<Int, 
std::enable_if_t<std::is_integral_v<Int>>> : public TypeT
   }
 
   TVM_FFI_INLINE static Int CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIInt);
     return static_cast<Int>(src->v_int64);
   }
 
@@ -356,6 +359,7 @@ struct TypeTraits<IntEnum, 
std::enable_if_t<is_integeral_enum_v<IntEnum>>> : pub
   }
 
   TVM_FFI_INLINE static IntEnum CopyFromAnyViewAfterCheck(const TVMFFIAny* 
src) {
+    TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIInt);
     return static_cast<IntEnum>(src->v_int64);
   }
 
@@ -397,6 +401,7 @@ struct TypeTraits<Float, 
std::enable_if_t<std::is_floating_point_v<Float>>>
   }
 
   TVM_FFI_INLINE static Float CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIFloat);
     return static_cast<Float>(src->v_float64);
   }
 
@@ -440,7 +445,10 @@ struct TypeTraits<void*> : public TypeTraitsBase {
     return src->type_index == TypeIndex::kTVMFFIOpaquePtr;
   }
 
-  TVM_FFI_INLINE static void* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) 
{ return src->v_ptr; }
+  TVM_FFI_INLINE static void* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+    TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIOpaquePtr);
+    return src->v_ptr;
+  }
 
   TVM_FFI_INLINE static void* MoveFromAnyAfterCheck(TVMFFIAny* src) {
     // POD type, we can just copy the value
@@ -485,6 +493,7 @@ struct TypeTraits<DLDevice> : public TypeTraitsBase {
   }
 
   TVM_FFI_INLINE static DLDevice CopyFromAnyViewAfterCheck(const TVMFFIAny* 
src) {
+    TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIDevice);
     return src->v_device;
   }
 
@@ -525,6 +534,7 @@ struct TypeTraits<DLTensor*> : public TypeTraitsBase {
   }
 
   TVM_FFI_INLINE static DLTensor* CopyFromAnyViewAfterCheck(const TVMFFIAny* 
src) {
+    TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIDLTensorPtr);
     return static_cast<DLTensor*>(src->v_ptr);
   }
 
diff --git a/tests/cpp/test_expected.cc b/tests/cpp/test_expected.cc
index 26f0595..a1cdfb1 100644
--- a/tests/cpp/test_expected.cc
+++ b/tests/cpp/test_expected.cc
@@ -19,6 +19,7 @@
 #include <gtest/gtest.h>
 #include <tvm/ffi/any.h>
 #include <tvm/ffi/container/array.h>
+#include <tvm/ffi/dtype.h>
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/expected.h>
 #include <tvm/ffi/function.h>
@@ -342,4 +343,38 @@ TEST(Expected, TryCastIncompatible) {
   EXPECT_FALSE(result.has_value());  // Cannot convert String to Expected<int>
 }
 
+// Test that Expected<DLDataType>::value() && compiles and runs correctly.
+// Requires TypeTraits<DLDataType>::MoveFromAnyAfterCheck to be defined.
+TEST(ExpectedRvalueMove, DLDataTypeMoveCompiles) {
+  Expected<DLDataType> e = DLDataType{kDLFloat, 32, 1};
+  DLDataType moved = std::move(e).value();
+  EXPECT_EQ(moved.code, kDLFloat);
+  EXPECT_EQ(moved.bits, 32);
+  EXPECT_EQ(moved.lanes, 1);
+}
+
+// Test that value_or() && moves rather than copies for Object types.
+TEST(ExpectedRvalueMove, ValueOrMovesNotCopies) {
+  Expected<String> e = String("hello");
+  String moved = std::move(e).value_or(String("default"));
+  EXPECT_EQ(moved, "hello");
+}
+
+// Test value_or() && on error path returns default.
+TEST(ExpectedRvalueMove, ValueOrRvalueErrorPath) {
+  Expected<String> e = Error("ValueError", "oops", "");
+  String result = std::move(e).value_or(String("fallback"));
+  EXPECT_EQ(result, "fallback");
+}
+
+// Test POD types compile and run correctly with rvalue value().
+TEST(ExpectedRvalueMove, PodTypesCompile) {
+  EXPECT_EQ(std::move(Expected<int64_t>(42)).value(), 42);
+  EXPECT_EQ(std::move(Expected<double>(3.14)).value(), 3.14);
+  EXPECT_EQ(std::move(Expected<bool>(true)).value(), true);
+  DLDataType dtype{kDLInt, 64, 1};
+  EXPECT_EQ(std::move(Expected<DLDataType>(dtype)).value().code, kDLInt);
+  EXPECT_EQ(std::move(Expected<DLDataType>(dtype)).value().bits, 64);
+}
+
 }  // namespace

Reply via email to