This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 1e8c998 [FEAT] Optimize Expected<T> for minimal compiled code and
efficiency (#599)
1e8c998 is described below
commit 1e8c998269e8667b35c576bbaeb37fadb94593b1
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu May 28 13:23:27 2026 -0400
[FEAT] Optimize Expected<T> for minimal compiled code and efficiency (#599)
## Summary
Optimize `Expected<T>` so the happy paths compile down to a handful of
instructions. The previous implementation routed `is_ok` / `is_err`
through `as<Error>()` (which bumped the refcount of the held Error just
to test the type), and `value()` went through `cast<T>()`, which
performed a redundant runtime type check on every read.
## Changes
- `is_ok()` / `is_err()` read the type-index discriminator directly. No
indirection.
- `value()` and `value_or()` use `TVM_FFI_PREDICT_TRUE` to hint the happy
path; `error()` is left unhinted because it is itself a cold path. These
accessors read through the `*AfterCheck` type-traits variants since
the `Expected` invariant already guarantees the held type.
- `value_or()` gains an rvalue overload that moves from the held value
instead of copying when the `Expected` is consumed as a temporary.
- Add `TVM_FFI_UNSAFE_ASSUME` macro (alongside `TVM_FFI_PREDICT_*`)
expanding to `__builtin_assume` (Clang), `__attribute__((assume(...)))`
(GCC 13+), or `__assume` (MSVC), and to a no-op on older GCC and other
compilers. Applied inside non-Object `CopyFromAnyViewAfterCheck` to hint
the compiler about the type-index invariant.
- Add `MoveFromAnyAfterCheck` to `TypeTraits<DLDataType>`; the gap
prevented `Expected<DLDataType>::value() &&` from compiling. POD type,
so delegate to `CopyFromAnyViewAfterCheck`.
- Add `MoveFromAnyAfterCheck` to `TypeTraits<TypedFunction<FType>>`;
delegate to `TypeTraits<Function>::MoveFromAnyAfterCheck`.
## Codegen (clang++-17 -O2, lines of asm, lower is better)
| Probe | Before | After |
|---|---|---|
| is_ok / is_err (int) | 173 | 5 |
| value() lvalue (int) | 696 | 37 |
| value() rvalue (int) | 725 | 34 |
| value() rvalue (String) | 810 | 96 |
| value() rvalue (Array<int>) | 1763 | 107 |
| make_ok().value() | 738 | 50 |
| std::move(make_ok()).value() | 738 | 50 |
`make_ok().value()` compiles to code identical to
`std::move(make_ok()).value()`: the prvalue is materialized as a
temporary (an xvalue), so overload resolution selects the `&&`-qualified
overload directly. No `std::move` is needed at the call site.
---
include/tvm/ffi/base_details.h | 24 ++++++++++++
include/tvm/ffi/dtype.h | 6 +++
include/tvm/ffi/expected.h | 88 +++++++++++++++++++++++++-----------------
include/tvm/ffi/function.h | 4 ++
include/tvm/ffi/type_traits.h | 12 +++++-
tests/cpp/test_expected.cc | 35 +++++++++++++++++
6 files changed, 133 insertions(+), 36 deletions(-)
diff --git a/include/tvm/ffi/base_details.h b/include/tvm/ffi/base_details.h
index c2ac2b8..27431b2 100644
--- a/include/tvm/ffi/base_details.h
+++ b/include/tvm/ffi/base_details.h
@@ -114,6 +114,30 @@
#define TVM_FFI_PREDICT_TRUE(cond) (cond)
#endif
+/*!
+ * \brief Translates into __builtin_assume / __assume /
__attribute__((assume)).
+ *
+ * Use ONLY when the external invariant guarantees cond. The compiler
+ * will remove all paths inconsistent with cond. This is not an
+ * assertion or check -- using on a wrong cond will result in
+ * undefined behavior. cond must be side-effect-free.
+ */
+#if defined(__clang__)
+#define TVM_FFI_UNSAFE_ASSUME(cond) __builtin_assume(cond)
+#elif defined(__GNUC__)
+// GCC 13+ supports __attribute__((assume(...))); fall back to the void-cast
+// no-op for older GCC where __builtin_assume is absent.
+#if __GNUC__ >= 13
+#define TVM_FFI_UNSAFE_ASSUME(cond) __attribute__((assume(cond)))
+#else
+#define TVM_FFI_UNSAFE_ASSUME(cond) static_cast<void>(0)
+#endif
+#elif defined(_MSC_VER)
+#define TVM_FFI_UNSAFE_ASSUME(cond) __assume(cond)
+#else
+#define TVM_FFI_UNSAFE_ASSUME(cond) static_cast<void>(0)
+#endif
+
#define TVM_FFI_STR_CONCAT_(__x, __y) __x##__y
#define TVM_FFI_STR_CONCAT(__x, __y) TVM_FFI_STR_CONCAT_(__x, __y)
diff --git a/include/tvm/ffi/dtype.h b/include/tvm/ffi/dtype.h
index e1a44fb..6aa2aa6 100644
--- a/include/tvm/ffi/dtype.h
+++ b/include/tvm/ffi/dtype.h
@@ -164,9 +164,15 @@ struct TypeTraits<DLDataType> : public TypeTraitsBase {
}
TVM_FFI_INLINE static DLDataType CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
+ TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIDataType);
return src->v_dtype;
}
+ TVM_FFI_INLINE static DLDataType MoveFromAnyAfterCheck(TVMFFIAny* src) {
+ // POD type — move is just copy.
+ return CopyFromAnyViewAfterCheck(src);
+ }
+
TVM_FFI_INLINE static std::optional<DLDataType> TryCastFromAnyView(const
TVMFFIAny* src) {
if (src->type_index == TypeIndex::kTVMFFIDataType) {
return src->v_dtype;
diff --git a/include/tvm/ffi/expected.h b/include/tvm/ffi/expected.h
index de35c67..6e822df 100644
--- a/include/tvm/ffi/expected.h
+++ b/include/tvm/ffi/expected.h
@@ -113,62 +113,80 @@ class Expected {
// NOLINTNEXTLINE(google-explicit-constructor,runtime/explicit)
Expected(Unexpected<E> unexpected) :
data_(Any(std::move(unexpected).error())) {}
- /*!
- * \brief Check if the Expected contains a success value.
- * \return True if contains success value, false if contains error.
- * \note Checks for Error first to handle cases where T is a base class of
Error.
- */
- TVM_FFI_INLINE bool is_ok() const { return !data_.as<Error>().has_value(); }
+ /*! \brief Returns true if the Expected contains a success value. */
+ TVM_FFI_INLINE bool is_ok() const noexcept {
+ return data_.type_index() != TypeIndex::kTVMFFIError;
+ }
- /*!
- * \brief Check if the Expected contains an error.
- * \return True if contains error, false if contains success value.
- */
- TVM_FFI_INLINE bool is_err() const { return !is_ok(); }
+ /*! \brief Returns true if the Expected contains an error. */
+ TVM_FFI_INLINE bool is_err() const noexcept {
+ return data_.type_index() == TypeIndex::kTVMFFIError;
+ }
- /*!
- * \brief Alias for is_ok().
- * \return True if contains success value.
- */
- TVM_FFI_INLINE bool has_value() const { return is_ok(); }
+ /*! \brief Alias for is_ok(). */
+ TVM_FFI_INLINE bool has_value() const noexcept { return is_ok(); }
- /*! \brief Access the success value. Throws the contained error if is_err().
*/
+ /*! \brief Returns the success value, or throws the contained error. */
TVM_FFI_INLINE T value() const& {
- if (is_err()) throw data_.cast<Error>();
- return data_.cast<T>();
+ if (TVM_FFI_PREDICT_TRUE(is_ok())) {
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
+ }
+ throw details::AnyUnsafe::CopyFromAnyViewAfterCheck<Error>(data_);
}
- /*! \brief Access the success value (rvalue). Throws the contained error if
is_err(). */
+
+ /*! \brief Returns the success value (moved out), or throws the contained
error. */
TVM_FFI_INLINE T value() && {
- if (is_err()) throw std::move(data_).template cast<Error>();
- return std::move(data_).template cast<T>();
+ if (TVM_FFI_PREDICT_TRUE(is_ok())) {
+ return details::AnyUnsafe::MoveFromAnyAfterCheck<T>(std::move(data_));
+ }
+ throw details::AnyUnsafe::MoveFromAnyAfterCheck<Error>(std::move(data_));
}
- /*! \brief Access the error. Throws RuntimeError if is_ok(). */
+ /*! \brief Returns the contained error, or throws RuntimeError if is_ok(). */
TVM_FFI_INLINE Error error() const& {
- if (!is_err()) TVM_FFI_THROW(RuntimeError) << "Bad expected access:
contains value, not error";
- return data_.cast<Error>();
+ // No branch hint: error() is itself a cold path — callers only invoke it
+ // after observing !is_ok(), so the branch direction here doesn't matter.
+ if (is_ok()) {
+ TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains value, not
error";
+ }
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<Error>(data_);
}
- /*! \brief Access the error (rvalue). Throws RuntimeError if is_ok(). */
+
+ /*! \brief Returns the contained error (moved out), or throws RuntimeError
if is_ok(). */
TVM_FFI_INLINE Error error() && {
- if (!is_err()) TVM_FFI_THROW(RuntimeError) << "Bad expected access:
contains value, not error";
- return std::move(data_).template cast<Error>();
+ // No branch hint: error() is itself a cold path — callers only invoke it
+ // after observing !is_ok(), so the branch direction here doesn't matter.
+ if (is_ok()) {
+ TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains value, not
error";
+ }
+ return details::AnyUnsafe::MoveFromAnyAfterCheck<Error>(std::move(data_));
}
/*!
- * \brief Get the success value or a default value.
- * \param default_value The value to return if Expected contains an error.
- * \return The success value if present, otherwise the default value.
+ * \brief Returns the success value, or \p default_value if the Expected
holds an error.
*/
template <typename U = std::remove_cv_t<T>>
- TVM_FFI_INLINE T value_or(U&& default_value) const {
- if (is_ok()) {
- return data_.cast<T>();
+ TVM_FFI_INLINE T value_or(U&& default_value) const& {
+ if (TVM_FFI_PREDICT_TRUE(is_ok())) {
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(data_);
+ }
+ return T(std::forward<U>(default_value));
+ }
+
+ /*!
+ * \brief Returns the success value (moved out), or \p default_value if the
Expected holds an
+ * error.
+ */
+ template <typename U = std::remove_cv_t<T>>
+ TVM_FFI_INLINE T value_or(U&& default_value) && {
+ if (TVM_FFI_PREDICT_TRUE(is_ok())) {
+ return details::AnyUnsafe::MoveFromAnyAfterCheck<T>(std::move(data_));
}
return T(std::forward<U>(default_value));
}
private:
- Any data_; // Holds either T or Error
+ Any data_; // Invariant: holds a T (type_index != kTVMFFIError) or an Error.
};
// TypeTraits specialization for Expected<T>
diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h
index 1100603..6f2dec2 100644
--- a/include/tvm/ffi/function.h
+++ b/include/tvm/ffi/function.h
@@ -914,6 +914,10 @@ struct TypeTraits<TypedFunction<FType>> : public
TypeTraitsBase {
return
TypedFunction<FType>(TypeTraits<Function>::CopyFromAnyViewAfterCheck(src));
}
+ TVM_FFI_INLINE static TypedFunction<FType> MoveFromAnyAfterCheck(TVMFFIAny*
src) {
+ return
TypedFunction<FType>(TypeTraits<Function>::MoveFromAnyAfterCheck(src));
+ }
+
TVM_FFI_INLINE static std::optional<TypedFunction<FType>> TryCastFromAnyView(
const TVMFFIAny* src) {
std::optional<Function> opt =
TypeTraits<Function>::TryCastFromAnyView(src);
diff --git a/include/tvm/ffi/type_traits.h b/include/tvm/ffi/type_traits.h
index 21afde0..deb8247 100644
--- a/include/tvm/ffi/type_traits.h
+++ b/include/tvm/ffi/type_traits.h
@@ -215,6 +215,7 @@ struct TypeTraits<StrictBool> : public TypeTraitsBase {
}
TVM_FFI_INLINE static StrictBool CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
+ TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIBool);
return static_cast<bool>(src->v_int64);
}
@@ -254,6 +255,7 @@ struct TypeTraits<bool> : public TypeTraitsBase {
}
TVM_FFI_INLINE static bool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+ TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIBool);
return static_cast<bool>(src->v_int64);
}
@@ -301,6 +303,7 @@ struct TypeTraits<Int,
std::enable_if_t<std::is_integral_v<Int>>> : public TypeT
}
TVM_FFI_INLINE static Int CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+ TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIInt);
return static_cast<Int>(src->v_int64);
}
@@ -356,6 +359,7 @@ struct TypeTraits<IntEnum,
std::enable_if_t<is_integeral_enum_v<IntEnum>>> : pub
}
TVM_FFI_INLINE static IntEnum CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
+ TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIInt);
return static_cast<IntEnum>(src->v_int64);
}
@@ -397,6 +401,7 @@ struct TypeTraits<Float,
std::enable_if_t<std::is_floating_point_v<Float>>>
}
TVM_FFI_INLINE static Float CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+ TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIFloat);
return static_cast<Float>(src->v_float64);
}
@@ -440,7 +445,10 @@ struct TypeTraits<void*> : public TypeTraitsBase {
return src->type_index == TypeIndex::kTVMFFIOpaquePtr;
}
- TVM_FFI_INLINE static void* CopyFromAnyViewAfterCheck(const TVMFFIAny* src)
{ return src->v_ptr; }
+ TVM_FFI_INLINE static void* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
+ TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIOpaquePtr);
+ return src->v_ptr;
+ }
TVM_FFI_INLINE static void* MoveFromAnyAfterCheck(TVMFFIAny* src) {
// POD type, we can just copy the value
@@ -485,6 +493,7 @@ struct TypeTraits<DLDevice> : public TypeTraitsBase {
}
TVM_FFI_INLINE static DLDevice CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
+ TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIDevice);
return src->v_device;
}
@@ -525,6 +534,7 @@ struct TypeTraits<DLTensor*> : public TypeTraitsBase {
}
TVM_FFI_INLINE static DLTensor* CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
+ TVM_FFI_UNSAFE_ASSUME(src->type_index == TypeIndex::kTVMFFIDLTensorPtr);
return static_cast<DLTensor*>(src->v_ptr);
}
diff --git a/tests/cpp/test_expected.cc b/tests/cpp/test_expected.cc
index 26f0595..a1cdfb1 100644
--- a/tests/cpp/test_expected.cc
+++ b/tests/cpp/test_expected.cc
@@ -19,6 +19,7 @@
#include <gtest/gtest.h>
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/expected.h>
#include <tvm/ffi/function.h>
@@ -342,4 +343,38 @@ TEST(Expected, TryCastIncompatible) {
EXPECT_FALSE(result.has_value()); // Cannot convert String to Expected<int>
}
+// Test that Expected<DLDataType>::value() && compiles and runs correctly.
+// Requires TypeTraits<DLDataType>::MoveFromAnyAfterCheck to be defined.
+TEST(ExpectedRvalueMove, DLDataTypeMoveCompiles) {
+ Expected<DLDataType> e = DLDataType{kDLFloat, 32, 1};
+ DLDataType moved = std::move(e).value();
+ EXPECT_EQ(moved.code, kDLFloat);
+ EXPECT_EQ(moved.bits, 32);
+ EXPECT_EQ(moved.lanes, 1);
+}
+
+// Test that value_or() && moves rather than copies for Object types.
+TEST(ExpectedRvalueMove, ValueOrMovesNotCopies) {
+ Expected<String> e = String("hello");
+ String moved = std::move(e).value_or(String("default"));
+ EXPECT_EQ(moved, "hello");
+}
+
+// Test value_or() && on error path returns default.
+TEST(ExpectedRvalueMove, ValueOrRvalueErrorPath) {
+ Expected<String> e = Error("ValueError", "oops", "");
+ String result = std::move(e).value_or(String("fallback"));
+ EXPECT_EQ(result, "fallback");
+}
+
+// Test POD types compile and run correctly with rvalue value().
+TEST(ExpectedRvalueMove, PodTypesCompile) {
+ EXPECT_EQ(std::move(Expected<int64_t>(42)).value(), 42);
+ EXPECT_EQ(std::move(Expected<double>(3.14)).value(), 3.14);
+ EXPECT_EQ(std::move(Expected<bool>(true)).value(), true);
+ DLDataType dtype{kDLInt, 64, 1};
+ EXPECT_EQ(std::move(Expected<DLDataType>(dtype)).value().code, kDLInt);
+ EXPECT_EQ(std::move(Expected<DLDataType>(dtype)).value().bits, 64);
+}
+
} // namespace