This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bcc48303c4 [REFACTOR][S-TIR] Move tvm/support/random_engine.h → tvm/s_tir/random_engine.h (#19475)
bcc48303c4 is described below

commit bcc48303c40b730bc9b8785653aad28a076a6df5
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Apr 29 17:29:54 2026 -0400

    [REFACTOR][S-TIR] Move tvm/support/random_engine.h → tvm/s_tir/random_engine.h (#19475)
    
    ## Summary
    
    `LinearCongruentialEngine` is used only by s_tir's meta-schedule and
    schedule primitives — no other consumer exists in `src/`, `include/`, or
    `apps/`. `tvm/support/` should stay reserved for genuinely cross-cutting
    utilities; an s_tir-only RNG belongs under s_tir.
    
    - Move `include/tvm/support/random_engine.h` →
    `include/tvm/s_tir/random_engine.h`
    - Rename namespace `tvm::support` → `tvm::s_tir` and update header guard
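    - Update the `#include` in the 5 s_tir headers that used it, and the
    `support::` namespace references across all s_tir consumers, plus the
    C++ unit test (`tests/cpp/random_engine_test.cc`)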
    
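    Class shape (members, behavior, `std::uniform_random_bit_generator`
    interface) is unchanged. No ABI impact. This is, however, a source-level
    break for out-of-tree code that includes `tvm/support/random_engine.h` or
    names `tvm::support::LinearCongruentialEngine`, because no forwarding
    header is left at the old path.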
---
 include/tvm/s_tir/meta_schedule/mutator.h          | 12 +++---
 include/tvm/s_tir/meta_schedule/task_scheduler.h   |  4 +-
 include/tvm/s_tir/meta_schedule/tune_context.h     |  6 +--
 include/tvm/{support => s_tir}/random_engine.h     | 10 ++---
 include/tvm/s_tir/schedule/schedule.h              | 10 ++---
 src/s_tir/meta_schedule/mutator/mutator.cc         |  7 ++--
 .../meta_schedule/task_scheduler/gradient_based.cc |  6 +--
 src/s_tir/meta_schedule/tune_context.cc            |  2 +-
 src/s_tir/meta_schedule/utils.h                    | 16 ++++----
 src/s_tir/schedule/concrete_schedule.cc            | 13 +++----
 src/s_tir/schedule/concrete_schedule.h             |  6 +--
 src/s_tir/schedule/primitive.h                     | 44 +++++++++++-----------
 src/s_tir/schedule/primitive/sampling.cc           | 44 +++++++++++-----------
 src/s_tir/schedule/schedule.cc                     |  4 +-
 src/s_tir/schedule/traced_schedule.cc              |  5 +--
 tests/cpp/random_engine_test.cc                    |  8 ++--
 16 files changed, 97 insertions(+), 100 deletions(-)

diff --git a/include/tvm/s_tir/meta_schedule/mutator.h 
b/include/tvm/s_tir/meta_schedule/mutator.h
index 42708dec57..9378035e07 100644
--- a/include/tvm/s_tir/meta_schedule/mutator.h
+++ b/include/tvm/s_tir/meta_schedule/mutator.h
@@ -24,9 +24,9 @@
 #include <tvm/ffi/optional.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/object.h>
+#include <tvm/s_tir/random_engine.h>
 #include <tvm/s_tir/schedule/schedule.h>
 #include <tvm/s_tir/schedule/trace.h>
-#include <tvm/support/random_engine.h>
 
 namespace tvm {
 namespace s_tir {
@@ -59,8 +59,8 @@ class MutatorNode : public runtime::Object {
    * \param rand_state The random state for mutation.
    * \return None if mutator failed, otherwise return the mutated trace.
    */
-  virtual ffi::Optional<s_tir::Trace> Apply(
-      const s_tir::Trace& trace, 
support::LinearCongruentialEngine::TRandState* rand_state) = 0;
+  virtual ffi::Optional<s_tir::Trace> Apply(const s_tir::Trace& trace,
+                                            
LinearCongruentialEngine::TRandState* rand_state) = 0;
 
   /*!
    * \brief Clone the mutator.
@@ -89,7 +89,7 @@ class Mutator : public runtime::ObjectRef {
    * \return None if mutator failed, otherwise return the mutated trace.
    */
   using FApply = ffi::TypedFunction<ffi::Optional<s_tir::Trace>(
-      const s_tir::Trace&, support::LinearCongruentialEngine::TRandState 
rand_state)>;
+      const s_tir::Trace&, LinearCongruentialEngine::TRandState rand_state)>;
   /*!
    * \brief Clone the mutator.
    * \return The cloned mutator.
@@ -171,8 +171,8 @@ class PyMutatorNode : public MutatorNode {
   }
 
   void InitializeWithTuneContext(const TuneContext& context) final;
-  ffi::Optional<s_tir::Trace> Apply(
-      const s_tir::Trace& trace, 
support::LinearCongruentialEngine::TRandState* rand_state) final;
+  ffi::Optional<s_tir::Trace> Apply(const s_tir::Trace& trace,
+                                    LinearCongruentialEngine::TRandState* 
rand_state) final;
   Mutator Clone() const final;
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.PyMutator", 
PyMutatorNode, MutatorNode);
 };
diff --git a/include/tvm/s_tir/meta_schedule/task_scheduler.h 
b/include/tvm/s_tir/meta_schedule/task_scheduler.h
index c349a450d1..c8461ad9c8 100644
--- a/include/tvm/s_tir/meta_schedule/task_scheduler.h
+++ b/include/tvm/s_tir/meta_schedule/task_scheduler.h
@@ -29,7 +29,7 @@
 #include <tvm/s_tir/meta_schedule/measure_callback.h>
 #include <tvm/s_tir/meta_schedule/runner.h>
 #include <tvm/s_tir/meta_schedule/tune_context.h>
-#include <tvm/support/random_engine.h>
+#include <tvm/s_tir/random_engine.h>
 
 #include <string>
 #include <vector>
@@ -279,7 +279,7 @@ class TaskScheduler : public runtime::ObjectRef {
    * \return The task scheduler created.
    */
   TVM_DLL static TaskScheduler GradientBased(ffi::Function logger, double 
alpha, int window_size,
-                                             
support::LinearCongruentialEngine::TRandState seed);
+                                             
LinearCongruentialEngine::TRandState seed);
   /*!
    * \brief Create a task scheduler with customized methods on the python-side.
    * \param logger The tuning task's logging function.
diff --git a/include/tvm/s_tir/meta_schedule/tune_context.h 
b/include/tvm/s_tir/meta_schedule/tune_context.h
index d35a4d2331..ace97480a9 100644
--- a/include/tvm/s_tir/meta_schedule/tune_context.h
+++ b/include/tvm/s_tir/meta_schedule/tune_context.h
@@ -32,7 +32,7 @@
 #include <tvm/s_tir/meta_schedule/runner.h>
 #include <tvm/s_tir/meta_schedule/search_strategy.h>
 #include <tvm/s_tir/meta_schedule/space_generator.h>
-#include <tvm/support/random_engine.h>
+#include <tvm/s_tir/random_engine.h>
 #include <tvm/target/target.h>
 
 namespace tvm {
@@ -46,7 +46,7 @@ class TuneContext;
 /*! \brief The auto tuning context. */
 class TuneContextNode : public runtime::Object {
  public:
-  using TRandState = support::LinearCongruentialEngine::TRandState;
+  using TRandState = LinearCongruentialEngine::TRandState;
 
   /*! \brief The workload to be tuned. */
   ffi::Optional<IRModule> mod;
@@ -98,7 +98,7 @@ class TuneContextNode : public runtime::Object {
  */
 class TuneContext : public runtime::ObjectRef {
  public:
-  using TRandState = support::LinearCongruentialEngine::TRandState;
+  using TRandState = LinearCongruentialEngine::TRandState;
   /*!
    * \brief Constructor from ObjectPtr<TuneContextNode>.
    * \param data The object pointer.
diff --git a/include/tvm/support/random_engine.h 
b/include/tvm/s_tir/random_engine.h
similarity index 96%
rename from include/tvm/support/random_engine.h
rename to include/tvm/s_tir/random_engine.h
index 9cd8cf7055..d594e1ba0c 100644
--- a/include/tvm/support/random_engine.h
+++ b/include/tvm/s_tir/random_engine.h
@@ -21,15 +21,15 @@
  * \brief Random number generator. It provides a generic interface consistent 
with
  * `std::uniform_random_bit_generator`
  */
-#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_
-#define TVM_SUPPORT_RANDOM_ENGINE_H_
+#ifndef TVM_S_TIR_RANDOM_ENGINE_H_
+#define TVM_S_TIR_RANDOM_ENGINE_H_
 #include <tvm/runtime/logging.h>
 
 #include <cstdint>
 #include <random>
 
 namespace tvm {
-namespace support {
+namespace s_tir {
 
 /*!
  * \brief This linear congruential engine is a drop-in replacement for 
std::minstd_rand. It strictly
@@ -130,7 +130,7 @@ class LinearCongruentialEngine {
   TRandState* rand_state_ptr_;
 };
 
-}  // namespace support
+}  // namespace s_tir
 }  // namespace tvm
 
-#endif  // TVM_SUPPORT_RANDOM_ENGINE_H_
+#endif  // TVM_S_TIR_RANDOM_ENGINE_H_
diff --git a/include/tvm/s_tir/schedule/schedule.h 
b/include/tvm/s_tir/schedule/schedule.h
index be903e10cd..881f0ce13f 100644
--- a/include/tvm/s_tir/schedule/schedule.h
+++ b/include/tvm/s_tir/schedule/schedule.h
@@ -19,9 +19,9 @@
 #ifndef TVM_S_TIR_SCHEDULE_SCHEDULE_H_
 #define TVM_S_TIR_SCHEDULE_SCHEDULE_H_
 
+#include <tvm/s_tir/random_engine.h>
 #include <tvm/s_tir/schedule/state.h>
 #include <tvm/s_tir/schedule/trace.h>
-#include <tvm/support/random_engine.h>
 #include <tvm/tirx/index_map.h>
 
 namespace tvm {
@@ -150,9 +150,9 @@ class ScheduleNode : public runtime::Object {
    * \brief Seed the randomness
    * \param seed The new random seed, -1 if use device random, otherwise 
non-negative
    */
-  virtual void Seed(support::LinearCongruentialEngine::TRandState seed) = 0;
+  virtual void Seed(LinearCongruentialEngine::TRandState seed) = 0;
   /*! \brief Fork the random state */
-  virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;
+  virtual LinearCongruentialEngine::TRandState ForkSeed() = 0;
 
  public:
   /******** Lookup/Remove random variables ********/
@@ -909,7 +909,7 @@ class Schedule : public runtime::ObjectRef {
    * \sa ScheduleDebugMask
    * \note The checks performed includes: 1) VerifySRefTree 2) 
VerifyCachedFlags
    */
-  TVM_DLL static Schedule Concrete(IRModule mod, 
support::LinearCongruentialEngine::TRandState seed,
+  TVM_DLL static Schedule Concrete(IRModule mod, 
LinearCongruentialEngine::TRandState seed,
                                    int debug_mask, ScheduleErrorRenderLevel 
error_render_level,
                                    bool enable_check = true);
   /*!
@@ -926,7 +926,7 @@ class Schedule : public runtime::ObjectRef {
    * 1) VerifySRefTree
    * 2) VerifyCachedFlags
    */
-  TVM_DLL static Schedule Traced(IRModule mod, 
support::LinearCongruentialEngine::TRandState seed,
+  TVM_DLL static Schedule Traced(IRModule mod, 
LinearCongruentialEngine::TRandState seed,
                                  int debug_mask, ScheduleErrorRenderLevel 
error_render_level,
                                  bool enable_check = true);
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Schedule, runtime::ObjectRef, 
ScheduleNode);
diff --git a/src/s_tir/meta_schedule/mutator/mutator.cc 
b/src/s_tir/meta_schedule/mutator/mutator.cc
index b4b46f8e2a..bbe50bad88 100644
--- a/src/s_tir/meta_schedule/mutator/mutator.cc
+++ b/src/s_tir/meta_schedule/mutator/mutator.cc
@@ -30,8 +30,8 @@ void PyMutatorNode::InitializeWithTuneContext(const 
TuneContext& context) {
   f_initialize_with_tune_context(context);
 }
 
-ffi::Optional<s_tir::Trace> PyMutatorNode::Apply(
-    const s_tir::Trace& trace, support::LinearCongruentialEngine::TRandState* 
rand_state) {
+ffi::Optional<s_tir::Trace> PyMutatorNode::Apply(const s_tir::Trace& trace,
+                                                 
LinearCongruentialEngine::TRandState* rand_state) {
   TVM_FFI_ICHECK(f_apply != nullptr) << "PyMutator's Apply method not 
implemented!";
   return f_apply(trace, *rand_state);
 }
@@ -93,8 +93,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
                   &MutatorNode::InitializeWithTuneContext)
       .def("s_tir.meta_schedule.MutatorApply",
            [](Mutator self, s_tir::Trace trace, TRandState seed) -> 
ffi::Optional<s_tir::Trace> {
-             TRandState seed_ =
-                 (seed != -1) ? seed : 
support::LinearCongruentialEngine::DeviceRandom();
+             TRandState seed_ = (seed != -1) ? seed : 
LinearCongruentialEngine::DeviceRandom();
              return self->Apply(trace, &seed_);
            })
       .def_method("s_tir.meta_schedule.MutatorClone", &MutatorNode::Clone)
diff --git a/src/s_tir/meta_schedule/task_scheduler/gradient_based.cc 
b/src/s_tir/meta_schedule/task_scheduler/gradient_based.cc
index bdc80f3455..fe21d9b6cc 100644
--- a/src/s_tir/meta_schedule/task_scheduler/gradient_based.cc
+++ b/src/s_tir/meta_schedule/task_scheduler/gradient_based.cc
@@ -29,7 +29,7 @@ class GradientBasedNode final : public TaskSchedulerNode {
  public:
   double alpha;
   int window_size;
-  support::LinearCongruentialEngine::TRandState rand_state;
+  LinearCongruentialEngine::TRandState rand_state;
 
   int round_robin_rounds_;
   std::vector<std::vector<double>> best_latency_history_;
@@ -135,12 +135,12 @@ class GradientBasedNode final : public TaskSchedulerNode {
 };
 
 TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha, 
int window_size,
-                                           
support::LinearCongruentialEngine::TRandState seed) {
+                                           
LinearCongruentialEngine::TRandState seed) {
   ObjectPtr<GradientBasedNode> n = ffi::make_object<GradientBasedNode>();
   n->logger = logger;
   n->alpha = alpha;
   n->window_size = window_size;
-  n->rand_state = support::LinearCongruentialEngine::NormalizeSeed(seed);
+  n->rand_state = LinearCongruentialEngine::NormalizeSeed(seed);
   return TaskScheduler(n);
 }
 
diff --git a/src/s_tir/meta_schedule/tune_context.cc 
b/src/s_tir/meta_schedule/tune_context.cc
index 9a935e931c..c8d7ab4bb0 100644
--- a/src/s_tir/meta_schedule/tune_context.cc
+++ b/src/s_tir/meta_schedule/tune_context.cc
@@ -40,7 +40,7 @@ TuneContext::TuneContext(ffi::Optional<IRModule> mod, 
ffi::Optional<Target> targ
   n->search_strategy = search_strategy;
   n->task_name = task_name;
   n->num_threads = num_threads;
-  n->rand_state = support::LinearCongruentialEngine::NormalizeSeed(rand_state);
+  n->rand_state = LinearCongruentialEngine::NormalizeSeed(rand_state);
   n->logger = logger;
   data_ = std::move(n);
 }
diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h
index 9bfbddd7f9..bf5576362c 100644
--- a/src/s_tir/meta_schedule/utils.h
+++ b/src/s_tir/meta_schedule/utils.h
@@ -170,7 +170,7 @@ inline void clear_logging(const char* file, int lineno, 
ffi::Function logging_fu
 }
 
 /*! \brief The type of the random state */
-using TRandState = support::LinearCongruentialEngine::TRandState;
+using TRandState = LinearCongruentialEngine::TRandState;
 
 /*!
  * \brief Get the base64 encoded result of a string.
@@ -242,9 +242,9 @@ inline ffi::String SHash2Hex(const ObjectRef& obj) {
  * \param rand_state The random state to be forked
  * \return The forked random state
  */
-inline support::LinearCongruentialEngine::TRandState ForkSeed(
-    support::LinearCongruentialEngine::TRandState* rand_state) {
-  return support::LinearCongruentialEngine(rand_state).ForkSeed();
+inline LinearCongruentialEngine::TRandState ForkSeed(
+    LinearCongruentialEngine::TRandState* rand_state) {
+  return LinearCongruentialEngine(rand_state).ForkSeed();
 }
 
 /*!
@@ -254,12 +254,12 @@ inline support::LinearCongruentialEngine::TRandState 
ForkSeed(
  * \param n The number of forks
  * \return The forked random states
  */
-inline std::vector<support::LinearCongruentialEngine::TRandState> ForkSeed(
-    support::LinearCongruentialEngine::TRandState* rand_state, int n) {
-  std::vector<support::LinearCongruentialEngine::TRandState> results;
+inline std::vector<LinearCongruentialEngine::TRandState> ForkSeed(
+    LinearCongruentialEngine::TRandState* rand_state, int n) {
+  std::vector<LinearCongruentialEngine::TRandState> results;
   results.reserve(n);
   for (int i = 0; i < n; ++i) {
-    
results.push_back(support::LinearCongruentialEngine(rand_state).ForkSeed());
+    results.push_back(LinearCongruentialEngine(rand_state).ForkSeed());
   }
   return results;
 }
diff --git a/src/s_tir/schedule/concrete_schedule.cc 
b/src/s_tir/schedule/concrete_schedule.cc
index 51189b7254..0e11a3afe3 100644
--- a/src/s_tir/schedule/concrete_schedule.cc
+++ b/src/s_tir/schedule/concrete_schedule.cc
@@ -24,9 +24,8 @@ namespace tvm {
 namespace s_tir {
 using namespace tvm::tirx;
 
-Schedule Schedule::Concrete(IRModule mod, 
support::LinearCongruentialEngine::TRandState seed,
-                            int debug_mask, ScheduleErrorRenderLevel 
error_render_level,
-                            bool enable_check) {
+Schedule Schedule::Concrete(IRModule mod, LinearCongruentialEngine::TRandState 
seed, int debug_mask,
+                            ScheduleErrorRenderLevel error_render_level, bool 
enable_check) {
   ObjectPtr<ConcreteScheduleNode> n = ffi::make_object<ConcreteScheduleNode>();
   n->state_ = ScheduleState(mod, debug_mask, enable_check);
   n->error_render_level_ = error_render_level;
@@ -226,12 +225,12 @@ Schedule ConcreteScheduleNode::Copy() {
 
 /******** Schedule: Schedule: Sampling ********/
 
-void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState 
seed) {
-  this->rand_state_ = support::LinearCongruentialEngine::NormalizeSeed(seed);
+void ConcreteScheduleNode::Seed(LinearCongruentialEngine::TRandState seed) {
+  this->rand_state_ = LinearCongruentialEngine::NormalizeSeed(seed);
 }
 
-support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() 
{
-  return support::LinearCongruentialEngine(&rand_state_).ForkSeed();
+LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() {
+  return LinearCongruentialEngine(&rand_state_).ForkSeed();
 }
 
 ExprRV ConcreteScheduleNode::SampleCategorical(const ffi::Array<Integer>& 
candidates,
diff --git a/src/s_tir/schedule/concrete_schedule.h 
b/src/s_tir/schedule/concrete_schedule.h
index 84b934ff72..29445da915 100644
--- a/src/s_tir/schedule/concrete_schedule.h
+++ b/src/s_tir/schedule/concrete_schedule.h
@@ -48,7 +48,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   /*! \brief A persistent stateless arithmetic analyzer. */
   std::unique_ptr<arith::Analyzer> analyzer_;
   /*! \brief The value of random state for sampling. */
-  support::LinearCongruentialEngine::TRandState rand_state_;
+  LinearCongruentialEngine::TRandState rand_state_;
 
  public:
   static void RegisterReflection() {
@@ -64,8 +64,8 @@ class ConcreteScheduleNode : public ScheduleNode {
   ffi::Optional<GlobalVar> func_working_on() const final { return 
func_working_on_; }
   void WorkOn(const ffi::String& func_name) final;
   Schedule Copy() override;
-  void Seed(support::LinearCongruentialEngine::TRandState seed) final;
-  support::LinearCongruentialEngine::TRandState ForkSeed() final;
+  void Seed(LinearCongruentialEngine::TRandState seed) final;
+  LinearCongruentialEngine::TRandState ForkSeed() final;
 
  public:
   /******** Lookup random variables ********/
diff --git a/src/s_tir/schedule/primitive.h b/src/s_tir/schedule/primitive.h
index 63213bb4ab..85ef871463 100644
--- a/src/s_tir/schedule/primitive.h
+++ b/src/s_tir/schedule/primitive.h
@@ -19,8 +19,8 @@
 #ifndef TVM_S_TIR_SCHEDULE_PRIMITIVE_H_
 #define TVM_S_TIR_SCHEDULE_PRIMITIVE_H_
 
+#include <tvm/s_tir/random_engine.h>
 #include <tvm/s_tir/schedule/state.h>
-#include <tvm/support/random_engine.h>
 
 #include <vector>
 
@@ -36,8 +36,8 @@ using namespace tvm::tirx;
  * \param max_exclusive The maximum value of the range, exclusive.
  * \return The random integer sampled in the given range.
  */
-TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* 
rand_state,
-                          int32_t min_inclusive, int32_t max_exclusive);
+TVM_DLL int32_t SampleInt(LinearCongruentialEngine::TRandState* rand_state, 
int32_t min_inclusive,
+                          int32_t max_exclusive);
 /*!
  * \brief Sample k random integers from given range without replacement, i.e, 
no duplication.
  * \param rand_state The pointer to schedule's random state
@@ -45,8 +45,8 @@ TVM_DLL int32_t 
SampleInt(support::LinearCongruentialEngine::TRandState* rand_st
  * \param k The total number of samples.
  * \return The randomly selected samples from the n candidates.
  */
-std::vector<int32_t> SampleWithoutReplacement(
-    support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, 
int32_t k);
+std::vector<int32_t> 
SampleWithoutReplacement(LinearCongruentialEngine::TRandState* rand_state,
+                                              int32_t n, int32_t k);
 /*!
  * \brief Sample once category from candidates according to the probability 
weights.
  * \param rand_state The pointer to schedule's random state
@@ -55,7 +55,7 @@ std::vector<int32_t> SampleWithoutReplacement(
  * \param decision The sampling decision, if any
  * \return The random variable sampled from candidates
  */
-TVM_DLL int64_t 
SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state,
+TVM_DLL int64_t SampleCategorical(LinearCongruentialEngine::TRandState* 
rand_state,
                                   const ffi::Array<Integer>& candidates,
                                   const ffi::Array<FloatImm>& probs,
                                   ffi::Optional<Integer>* decision);
@@ -66,7 +66,7 @@ TVM_DLL int64_t 
SampleCategorical(support::LinearCongruentialEngine::TRandState*
  * \return The multinomial sampling function.
  */
 TVM_DLL std::function<int32_t()> MakeMultinomialSampler(
-    support::LinearCongruentialEngine::TRandState* rand_state, const 
std::vector<double>& weights);
+    LinearCongruentialEngine::TRandState* rand_state, const 
std::vector<double>& weights);
 /*!
  * \brief Sample the factors to perfect tile a specific loop
  * \param rand_state The random state
@@ -74,9 +74,8 @@ TVM_DLL std::function<int32_t()> MakeMultinomialSampler(
  * \param n_split The number of tiles to be sampled
  * \return A list of length `n`, the random perfect tile sizes sampled
  */
-TVM_DLL std::vector<int64_t> SamplePerfectTile(
-    support::LinearCongruentialEngine::TRandState* rand_state,  //
-    int32_t extent, int32_t n_splits);
+TVM_DLL std::vector<int64_t> 
SamplePerfectTile(LinearCongruentialEngine::TRandState* rand_state,  //
+                                               int32_t extent, int32_t 
n_splits);
 /*!
  * \brief Sample the factors to perfect tile a specific loop
  * \param rand_state The random state
@@ -85,9 +84,9 @@ TVM_DLL std::vector<int64_t> SamplePerfectTile(
  * \param max_innermost_factor The maximum tile size allowed to be sampled in 
the innermost loop
  * \return A list of length `n`, the random perfect tile sizes sampled
  */
-TVM_DLL std::vector<int64_t> SamplePerfectTile(
-    support::LinearCongruentialEngine::TRandState* rand_state,  //
-    int32_t extent, int32_t n_split, int32_t max_innermost_factor);
+TVM_DLL std::vector<int64_t> 
SamplePerfectTile(LinearCongruentialEngine::TRandState* rand_state,  //
+                                               int32_t extent, int32_t n_split,
+                                               int32_t max_innermost_factor);
 /*!
  * \brief Sample the factors to perfect tile a specific loop
  * \param rand_state The random state
@@ -97,10 +96,10 @@ TVM_DLL std::vector<int64_t> SamplePerfectTile(
  * \param decision The sampling decision
  * \return A list of length `n`, the random perfect tile sizes sampled
  */
-TVM_DLL std::vector<int64_t> SamplePerfectTile(
-    support::LinearCongruentialEngine::TRandState* rand_state,  //
-    const tirx::StmtSRef& loop_sref, int32_t n_split, int32_t 
max_innermost_factor,
-    ffi::Optional<ffi::Array<Integer>>* decision);
+TVM_DLL std::vector<int64_t> 
SamplePerfectTile(LinearCongruentialEngine::TRandState* rand_state,  //
+                                               const tirx::StmtSRef& 
loop_sref, int32_t n_split,
+                                               int32_t max_innermost_factor,
+                                               
ffi::Optional<ffi::Array<Integer>>* decision);
 /*!
  * \brief Sample the factors to a partitioned tile for a specific loop
  *
@@ -117,7 +116,7 @@ TVM_DLL std::vector<int64_t> SamplePerfectTile(
  * \return A list of length `n`, the random partitioned tile sizes sampled
  */
 TVM_DLL std::vector<int64_t> SamplePartitionedTile(
-    support::LinearCongruentialEngine::TRandState* rand_state,  //
+    LinearCongruentialEngine::TRandState* rand_state,  //
     int32_t extent, int32_t n_split, int32_t partition_pos, int32_t 
innerpart_factor);
 /*!
  * \brief Sample the factors to a partitioned tile for a specific loop
@@ -136,7 +135,7 @@ TVM_DLL std::vector<int64_t> SamplePartitionedTile(
  * \return A list of length `n`, the random partitioned tile sizes sampled
  */
 TVM_DLL std::vector<int64_t> SamplePartitionedTile(
-    support::LinearCongruentialEngine::TRandState* rand_state,  //
+    LinearCongruentialEngine::TRandState* rand_state,  //
     const tirx::StmtSRef& loop_sref, int32_t n_split, int32_t partition_pos,
     int32_t innerpart_factor, ffi::Optional<ffi::Array<Integer>>* decision);
 /*!
@@ -147,9 +146,10 @@ TVM_DLL std::vector<int64_t> SamplePartitionedTile(
  * \param decision The sampling decision
  * \return The sampled loop where the input block is to be computed at
  */
-TVM_DLL tirx::StmtSRef SampleComputeLocation(
-    s_tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* 
rand_state,
-    const tirx::StmtSRef& block_sref, ffi::Optional<Integer>* decision);
+TVM_DLL tirx::StmtSRef SampleComputeLocation(s_tir::ScheduleState self,
+                                             
LinearCongruentialEngine::TRandState* rand_state,
+                                             const tirx::StmtSRef& block_sref,
+                                             ffi::Optional<Integer>* decision);
 
 /******** Schedule: Get blocks & loops ********/
 /*!
diff --git a/src/s_tir/schedule/primitive/sampling.cc 
b/src/s_tir/schedule/primitive/sampling.cc
index 273f0d8441..d7851d9130 100644
--- a/src/s_tir/schedule/primitive/sampling.cc
+++ b/src/s_tir/schedule/primitive/sampling.cc
@@ -125,20 +125,20 @@ struct PrimeTable {
   }
 };
 
-int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, 
int32_t min_inclusive,
+int32_t SampleInt(LinearCongruentialEngine::TRandState* rand_state, int32_t 
min_inclusive,
                   int32_t max_exclusive) {
   TVM_FFI_CHECK(min_inclusive < max_exclusive, ValueError)
       << "max_exclusive must be greater than min_inclusive.";
   if (min_inclusive + 1 == max_exclusive) {
     return min_inclusive;
   }
-  support::LinearCongruentialEngine rand_(rand_state);
+  LinearCongruentialEngine rand_(rand_state);
   std::uniform_int_distribution<int32_t> dist(min_inclusive, max_exclusive - 
1);
   return dist(rand_);
 }
 
-std::vector<int32_t> SampleWithoutReplacement(
-    support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, 
int32_t k) {
+std::vector<int32_t> 
SampleWithoutReplacement(LinearCongruentialEngine::TRandState* rand_state,
+                                              int32_t n, int32_t k) {
   if (k == 1) {
     return {SampleInt(rand_state, 0, n)};
   }
@@ -163,7 +163,7 @@ std::vector<int32_t> SampleWithoutReplacement(
   return {order.begin(), order.begin() + k};
 }
 
-int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* 
rand_state,
+int64_t SampleCategorical(LinearCongruentialEngine::TRandState* rand_state,
                           const ffi::Array<Integer>& candidates, const 
ffi::Array<FloatImm>& probs,
                           ffi::Optional<Integer>* decision) {
   TVM_FFI_CHECK(candidates.size() == probs.size(), ValueError)
@@ -177,7 +177,7 @@ int64_t 
SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st
   } else {
     std::vector<double> weights = support::AsVector<FloatImm, double>(probs);
     std::discrete_distribution<int32_t> dist(weights.begin(), weights.end());
-    support::LinearCongruentialEngine rand_(rand_state);
+    LinearCongruentialEngine rand_(rand_state);
     i = dist(rand_);
     TVM_FFI_CHECK(0 <= i && i < n, ValueError)
         << "Unexpected decision generated, where n = " << n << ", but decision 
is: " << i;
@@ -187,8 +187,8 @@ int64_t 
SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st
   return candidates[i]->value;
 }
 
-std::function<int32_t()> MakeMultinomialSampler(
-    support::LinearCongruentialEngine::TRandState* rand_state, const 
std::vector<double>& weights) {
+std::function<int32_t()> 
MakeMultinomialSampler(LinearCongruentialEngine::TRandState* rand_state,
+                                                const std::vector<double>& 
weights) {
   TVM_FFI_ICHECK(!weights.empty());
   std::vector<double> sums;
   sums.reserve(weights.size());
@@ -196,10 +196,10 @@ std::function<int32_t()> MakeMultinomialSampler(
   for (double w : weights) {
     sums.push_back(sum += w);
   }
-  return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(),
+  return [rng = LinearCongruentialEngine(rand_state).ForkSeed(),
           dist = std::uniform_real_distribution<double>(0.0, sum),
           sums = std::move(sums)]() mutable -> int32_t {
-    support::LinearCongruentialEngine rand_(&rng);
+    LinearCongruentialEngine rand_(&rng);
     double p = dist(rand_);
     int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin();
     int32_t n = sums.size();
@@ -209,7 +209,7 @@ std::function<int32_t()> MakeMultinomialSampler(
   };
 }
 
-std::vector<int64_t> 
SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state,
+std::vector<int64_t> SamplePerfectTile(LinearCongruentialEngine::TRandState* 
rand_state,
                                        int32_t extent, int32_t n_splits) {
   TVM_FFI_CHECK_GE(extent, 1, ValueError) << "Cannot tile a loop with 0 or 
negative extent";
   TVM_FFI_CHECK_GE(n_splits, 1, ValueError) << "Cannot tile a loop to 0 or 
negative splits";
@@ -292,7 +292,7 @@ std::vector<int64_t> 
SamplePerfectTile(support::LinearCongruentialEngine::TRandS
   return result;
 }
 
-std::vector<int64_t> 
SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state,
+std::vector<int64_t> SamplePerfectTile(LinearCongruentialEngine::TRandState* 
rand_state,
                                        int32_t extent, int32_t n_splits,
                                        int32_t max_innermost_factor) {
   if (max_innermost_factor == -1) {
@@ -307,10 +307,10 @@ std::vector<int64_t> 
SamplePerfectTile(support::LinearCongruentialEngine::TRandS
   }
 }
 
-std::vector<int64_t> SamplePerfectTile(
-    support::LinearCongruentialEngine::TRandState* rand_state,  //
-    const tirx::StmtSRef& loop_sref, int32_t n_splits, int32_t 
max_innermost_factor,
-    ffi::Optional<ffi::Array<Integer>>* decision) {
+std::vector<int64_t> SamplePerfectTile(LinearCongruentialEngine::TRandState* 
rand_state,  //
+                                       const tirx::StmtSRef& loop_sref, 
int32_t n_splits,
+                                       int32_t max_innermost_factor,
+                                       ffi::Optional<ffi::Array<Integer>>* 
decision) {
   const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   const int64_t* extent = GetLoopIntExtent(loop);
   std::vector<int64_t> result;
@@ -347,7 +347,7 @@ std::vector<int64_t> SamplePerfectTile(
 }
 
 TVM_DLL std::vector<int64_t> SamplePartitionedTile(
-    support::LinearCongruentialEngine::TRandState* rand_state,  //
+    LinearCongruentialEngine::TRandState* rand_state,  //
     int32_t extent, int32_t n_splits, int32_t partition_pos, int32_t 
innerpart_factor) {
   if (partition_pos == 0 && innerpart_factor == 1) {
     return SamplePerfectTile(rand_state, extent, n_splits);
@@ -368,10 +368,10 @@ TVM_DLL std::vector<int64_t> SamplePartitionedTile(
   }
 }
 
-std::vector<int64_t> SamplePartitionedTile(
-    support::LinearCongruentialEngine::TRandState* rand_state,  //
-    const tirx::StmtSRef& loop_sref, int32_t n_splits, int32_t partition_pos,
-    int32_t innerpart_factor, ffi::Optional<ffi::Array<Integer>>* decision) {
+std::vector<int64_t> 
SamplePartitionedTile(LinearCongruentialEngine::TRandState* rand_state,  //
+                                           const tirx::StmtSRef& loop_sref, 
int32_t n_splits,
+                                           int32_t partition_pos, int32_t 
innerpart_factor,
+                                           ffi::Optional<ffi::Array<Integer>>* 
decision) {
   const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   const int64_t* extent = GetLoopIntExtent(loop);
   std::vector<int64_t> result;
@@ -419,7 +419,7 @@ std::vector<int64_t> SamplePartitionedTile(
 }
 
 tirx::StmtSRef SampleComputeLocation(s_tir::ScheduleState self,
-                                     
support::LinearCongruentialEngine::TRandState* rand_state,
+                                     LinearCongruentialEngine::TRandState* 
rand_state,
                                      const StmtSRef& block_sref, 
ffi::Optional<Integer>* decision) {
   // Step 1. Collect all possible compute-at locations.
   auto [location_srefs, location_indices] = CollectComputeLocation(self, 
block_sref);
diff --git a/src/s_tir/schedule/schedule.cc b/src/s_tir/schedule/schedule.cc
index a14ff2aa8a..ba8ef72b57 100644
--- a/src/s_tir/schedule/schedule.cc
+++ b/src/s_tir/schedule/schedule.cc
@@ -70,14 +70,14 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def("s_tir.schedule.SBlockRV", []() { return SBlockRV(); })
       .def("s_tir.schedule.LoopRV", []() { return LoopRV(); })
       .def("s_tir.schedule.ConcreteSchedule",
-           [](IRModule mod, support::LinearCongruentialEngine::TRandState 
seed, int debug_mask,
+           [](IRModule mod, LinearCongruentialEngine::TRandState seed, int 
debug_mask,
               int error_render_level, bool enable_check) -> Schedule {
              return Schedule::Concrete(mod, debug_mask, seed,
                                        
static_cast<ScheduleErrorRenderLevel>(error_render_level),
                                        enable_check);
            })
       .def("s_tir.schedule.TracedSchedule",
-           [](IRModule mod, support::LinearCongruentialEngine::TRandState 
seed, int debug_mask,
+           [](IRModule mod, LinearCongruentialEngine::TRandState seed, int 
debug_mask,
               int error_render_level, bool enable_check) -> Schedule {
              return Schedule::Traced(mod, seed, debug_mask,
                                      
static_cast<ScheduleErrorRenderLevel>(error_render_level),
diff --git a/src/s_tir/schedule/traced_schedule.cc 
b/src/s_tir/schedule/traced_schedule.cc
index e43df0835c..b1d78312e1 100644
--- a/src/s_tir/schedule/traced_schedule.cc
+++ b/src/s_tir/schedule/traced_schedule.cc
@@ -22,9 +22,8 @@ namespace tvm {
 namespace s_tir {
 using namespace tvm::tirx;
 
-Schedule Schedule::Traced(IRModule mod, 
support::LinearCongruentialEngine::TRandState seed,
-                          int debug_mask, ScheduleErrorRenderLevel 
error_render_level,
-                          bool enable_check) {
+Schedule Schedule::Traced(IRModule mod, LinearCongruentialEngine::TRandState 
seed, int debug_mask,
+                          ScheduleErrorRenderLevel error_render_level, bool 
enable_check) {
   ObjectPtr<TracedScheduleNode> n = ffi::make_object<TracedScheduleNode>();
   n->state_ = ScheduleState(mod, debug_mask, enable_check);
   n->error_render_level_ = error_render_level;
diff --git a/tests/cpp/random_engine_test.cc b/tests/cpp/random_engine_test.cc
index 42d65aa402..afd02c09b2 100644
--- a/tests/cpp/random_engine_test.cc
+++ b/tests/cpp/random_engine_test.cc
@@ -19,12 +19,12 @@
 
 #include <gtest/gtest.h>
 #include <tvm/runtime/logging.h>
-#include <tvm/support/random_engine.h>
+#include <tvm/s_tir/random_engine.h>
 
 TEST(RandomEngine, Randomness) {
   int64_t rand_state = 0;
 
-  tvm::support::LinearCongruentialEngine rng(&rand_state);
+  tvm::s_tir::LinearCongruentialEngine rng(&rand_state);
   rng.Seed(0x114514);
 
   bool covered[100];
@@ -39,7 +39,7 @@ TEST(RandomEngine, Randomness) {
 
 TEST(RandomEngine, Reproducibility) {
   int64_t rand_state_a = 0, rand_state_b = 0;
-  tvm::support::LinearCongruentialEngine rng_a(&rand_state_a), 
rng_b(&rand_state_b);
+  tvm::s_tir::LinearCongruentialEngine rng_a(&rand_state_a), 
rng_b(&rand_state_b);
 
   rng_a.Seed(0x23456789);
   rng_b.Seed(0x23456789);
@@ -51,7 +51,7 @@ TEST(RandomEngine, Reproducibility) {
 
 TEST(RandomEngine, Serialization) {
   int64_t rand_state_a = 0, rand_state_b = 0;
-  tvm::support::LinearCongruentialEngine rng_a(&rand_state_a), 
rng_b(&rand_state_b);
+  tvm::s_tir::LinearCongruentialEngine rng_a(&rand_state_a), 
rng_b(&rand_state_b);
 
   rng_a.Seed(0x56728);
 


Reply via email to