This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bcc48303c4 [REFACTOR][S-TIR] Move tvm/support/random_engine.h →
tvm/s_tir/random_engine.h (#19475)
bcc48303c4 is described below
commit bcc48303c40b730bc9b8785653aad28a076a6df5
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Apr 29 17:29:54 2026 -0400
[REFACTOR][S-TIR] Move tvm/support/random_engine.h →
tvm/s_tir/random_engine.h (#19475)
## Summary
`LinearCongruentialEngine` is used only by s_tir's meta-schedule and
schedule primitives — no other consumer exists in `src/`, `include/`, or
`apps/`. `tvm/support/` should stay reserved for genuinely cross-cutting
utilities; an s_tir-only RNG belongs under s_tir.
- Move `include/tvm/support/random_engine.h` →
`include/tvm/s_tir/random_engine.h`
- Rename namespace `tvm::support` → `tvm::s_tir` and update header guard
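- Update the `#include` in the 5 s_tir headers that pulled in the old path, and
drop the `support::` qualifier at every use site across the 14 s_tir files
touched, plus the cpptest (`tests/cpp/random_engine_test.cc`)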
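Class shape (members, behavior, `std::uniform_random_bit_generator`
interface) is unchanged. No ABI impact. This is, however, a source-level
break for any out-of-tree code. The old header is moved rather than kept as a
forwarding header, so such code must switch to
`#include <tvm/s_tir/random_engine.h>` and
`tvm::s_tir::LinearCongruentialEngine`.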
---
include/tvm/s_tir/meta_schedule/mutator.h | 12 +++---
include/tvm/s_tir/meta_schedule/task_scheduler.h | 4 +-
include/tvm/s_tir/meta_schedule/tune_context.h | 6 +--
include/tvm/{support => s_tir}/random_engine.h | 10 ++---
include/tvm/s_tir/schedule/schedule.h | 10 ++---
src/s_tir/meta_schedule/mutator/mutator.cc | 7 ++--
.../meta_schedule/task_scheduler/gradient_based.cc | 6 +--
src/s_tir/meta_schedule/tune_context.cc | 2 +-
src/s_tir/meta_schedule/utils.h | 16 ++++----
src/s_tir/schedule/concrete_schedule.cc | 13 +++----
src/s_tir/schedule/concrete_schedule.h | 6 +--
src/s_tir/schedule/primitive.h | 44 +++++++++++-----------
src/s_tir/schedule/primitive/sampling.cc | 44 +++++++++++-----------
src/s_tir/schedule/schedule.cc | 4 +-
src/s_tir/schedule/traced_schedule.cc | 5 +--
tests/cpp/random_engine_test.cc | 8 ++--
16 files changed, 97 insertions(+), 100 deletions(-)
diff --git a/include/tvm/s_tir/meta_schedule/mutator.h
b/include/tvm/s_tir/meta_schedule/mutator.h
index 42708dec57..9378035e07 100644
--- a/include/tvm/s_tir/meta_schedule/mutator.h
+++ b/include/tvm/s_tir/meta_schedule/mutator.h
@@ -24,9 +24,9 @@
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/object.h>
+#include <tvm/s_tir/random_engine.h>
#include <tvm/s_tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/trace.h>
-#include <tvm/support/random_engine.h>
namespace tvm {
namespace s_tir {
@@ -59,8 +59,8 @@ class MutatorNode : public runtime::Object {
* \param rand_state The random state for mutation.
* \return None if mutator failed, otherwise return the mutated trace.
*/
- virtual ffi::Optional<s_tir::Trace> Apply(
- const s_tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) = 0;
+ virtual ffi::Optional<s_tir::Trace> Apply(const s_tir::Trace& trace,
+
LinearCongruentialEngine::TRandState* rand_state) = 0;
/*!
* \brief Clone the mutator.
@@ -89,7 +89,7 @@ class Mutator : public runtime::ObjectRef {
* \return None if mutator failed, otherwise return the mutated trace.
*/
using FApply = ffi::TypedFunction<ffi::Optional<s_tir::Trace>(
- const s_tir::Trace&, support::LinearCongruentialEngine::TRandState
rand_state)>;
+ const s_tir::Trace&, LinearCongruentialEngine::TRandState rand_state)>;
/*!
* \brief Clone the mutator.
* \return The cloned mutator.
@@ -171,8 +171,8 @@ class PyMutatorNode : public MutatorNode {
}
void InitializeWithTuneContext(const TuneContext& context) final;
- ffi::Optional<s_tir::Trace> Apply(
- const s_tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) final;
+ ffi::Optional<s_tir::Trace> Apply(const s_tir::Trace& trace,
+ LinearCongruentialEngine::TRandState*
rand_state) final;
Mutator Clone() const final;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.PyMutator",
PyMutatorNode, MutatorNode);
};
diff --git a/include/tvm/s_tir/meta_schedule/task_scheduler.h
b/include/tvm/s_tir/meta_schedule/task_scheduler.h
index c349a450d1..c8461ad9c8 100644
--- a/include/tvm/s_tir/meta_schedule/task_scheduler.h
+++ b/include/tvm/s_tir/meta_schedule/task_scheduler.h
@@ -29,7 +29,7 @@
#include <tvm/s_tir/meta_schedule/measure_callback.h>
#include <tvm/s_tir/meta_schedule/runner.h>
#include <tvm/s_tir/meta_schedule/tune_context.h>
-#include <tvm/support/random_engine.h>
+#include <tvm/s_tir/random_engine.h>
#include <string>
#include <vector>
@@ -279,7 +279,7 @@ class TaskScheduler : public runtime::ObjectRef {
* \return The task scheduler created.
*/
TVM_DLL static TaskScheduler GradientBased(ffi::Function logger, double
alpha, int window_size,
-
support::LinearCongruentialEngine::TRandState seed);
+
LinearCongruentialEngine::TRandState seed);
/*!
* \brief Create a task scheduler with customized methods on the python-side.
* \param logger The tuning task's logging function.
diff --git a/include/tvm/s_tir/meta_schedule/tune_context.h
b/include/tvm/s_tir/meta_schedule/tune_context.h
index d35a4d2331..ace97480a9 100644
--- a/include/tvm/s_tir/meta_schedule/tune_context.h
+++ b/include/tvm/s_tir/meta_schedule/tune_context.h
@@ -32,7 +32,7 @@
#include <tvm/s_tir/meta_schedule/runner.h>
#include <tvm/s_tir/meta_schedule/search_strategy.h>
#include <tvm/s_tir/meta_schedule/space_generator.h>
-#include <tvm/support/random_engine.h>
+#include <tvm/s_tir/random_engine.h>
#include <tvm/target/target.h>
namespace tvm {
@@ -46,7 +46,7 @@ class TuneContext;
/*! \brief The auto tuning context. */
class TuneContextNode : public runtime::Object {
public:
- using TRandState = support::LinearCongruentialEngine::TRandState;
+ using TRandState = LinearCongruentialEngine::TRandState;
/*! \brief The workload to be tuned. */
ffi::Optional<IRModule> mod;
@@ -98,7 +98,7 @@ class TuneContextNode : public runtime::Object {
*/
class TuneContext : public runtime::ObjectRef {
public:
- using TRandState = support::LinearCongruentialEngine::TRandState;
+ using TRandState = LinearCongruentialEngine::TRandState;
/*!
* \brief Constructor from ObjectPtr<TuneContextNode>.
* \param data The object pointer.
diff --git a/include/tvm/support/random_engine.h
b/include/tvm/s_tir/random_engine.h
similarity index 96%
rename from include/tvm/support/random_engine.h
rename to include/tvm/s_tir/random_engine.h
index 9cd8cf7055..d594e1ba0c 100644
--- a/include/tvm/support/random_engine.h
+++ b/include/tvm/s_tir/random_engine.h
@@ -21,15 +21,15 @@
* \brief Random number generator. It provides a generic interface consistent
with
* `std::uniform_random_bit_generator`
*/
-#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_
-#define TVM_SUPPORT_RANDOM_ENGINE_H_
+#ifndef TVM_S_TIR_RANDOM_ENGINE_H_
+#define TVM_S_TIR_RANDOM_ENGINE_H_
#include <tvm/runtime/logging.h>
#include <cstdint>
#include <random>
namespace tvm {
-namespace support {
+namespace s_tir {
/*!
* \brief This linear congruential engine is a drop-in replacement for
std::minstd_rand. It strictly
@@ -130,7 +130,7 @@ class LinearCongruentialEngine {
TRandState* rand_state_ptr_;
};
-} // namespace support
+} // namespace s_tir
} // namespace tvm
-#endif // TVM_SUPPORT_RANDOM_ENGINE_H_
+#endif // TVM_S_TIR_RANDOM_ENGINE_H_
diff --git a/include/tvm/s_tir/schedule/schedule.h
b/include/tvm/s_tir/schedule/schedule.h
index be903e10cd..881f0ce13f 100644
--- a/include/tvm/s_tir/schedule/schedule.h
+++ b/include/tvm/s_tir/schedule/schedule.h
@@ -19,9 +19,9 @@
#ifndef TVM_S_TIR_SCHEDULE_SCHEDULE_H_
#define TVM_S_TIR_SCHEDULE_SCHEDULE_H_
+#include <tvm/s_tir/random_engine.h>
#include <tvm/s_tir/schedule/state.h>
#include <tvm/s_tir/schedule/trace.h>
-#include <tvm/support/random_engine.h>
#include <tvm/tirx/index_map.h>
namespace tvm {
@@ -150,9 +150,9 @@ class ScheduleNode : public runtime::Object {
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random, otherwise
non-negative
*/
- virtual void Seed(support::LinearCongruentialEngine::TRandState seed) = 0;
+ virtual void Seed(LinearCongruentialEngine::TRandState seed) = 0;
/*! \brief Fork the random state */
- virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;
+ virtual LinearCongruentialEngine::TRandState ForkSeed() = 0;
public:
/******** Lookup/Remove random variables ********/
@@ -909,7 +909,7 @@ class Schedule : public runtime::ObjectRef {
* \sa ScheduleDebugMask
* \note The checks performed includes: 1) VerifySRefTree 2)
VerifyCachedFlags
*/
- TVM_DLL static Schedule Concrete(IRModule mod,
support::LinearCongruentialEngine::TRandState seed,
+ TVM_DLL static Schedule Concrete(IRModule mod,
LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel
error_render_level,
bool enable_check = true);
/*!
@@ -926,7 +926,7 @@ class Schedule : public runtime::ObjectRef {
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
- TVM_DLL static Schedule Traced(IRModule mod,
support::LinearCongruentialEngine::TRandState seed,
+ TVM_DLL static Schedule Traced(IRModule mod,
LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel
error_render_level,
bool enable_check = true);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Schedule, runtime::ObjectRef,
ScheduleNode);
diff --git a/src/s_tir/meta_schedule/mutator/mutator.cc
b/src/s_tir/meta_schedule/mutator/mutator.cc
index b4b46f8e2a..bbe50bad88 100644
--- a/src/s_tir/meta_schedule/mutator/mutator.cc
+++ b/src/s_tir/meta_schedule/mutator/mutator.cc
@@ -30,8 +30,8 @@ void PyMutatorNode::InitializeWithTuneContext(const
TuneContext& context) {
f_initialize_with_tune_context(context);
}
-ffi::Optional<s_tir::Trace> PyMutatorNode::Apply(
- const s_tir::Trace& trace, support::LinearCongruentialEngine::TRandState*
rand_state) {
+ffi::Optional<s_tir::Trace> PyMutatorNode::Apply(const s_tir::Trace& trace,
+
LinearCongruentialEngine::TRandState* rand_state) {
TVM_FFI_ICHECK(f_apply != nullptr) << "PyMutator's Apply method not
implemented!";
return f_apply(trace, *rand_state);
}
@@ -93,8 +93,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
&MutatorNode::InitializeWithTuneContext)
.def("s_tir.meta_schedule.MutatorApply",
[](Mutator self, s_tir::Trace trace, TRandState seed) ->
ffi::Optional<s_tir::Trace> {
- TRandState seed_ =
- (seed != -1) ? seed :
support::LinearCongruentialEngine::DeviceRandom();
+ TRandState seed_ = (seed != -1) ? seed :
LinearCongruentialEngine::DeviceRandom();
return self->Apply(trace, &seed_);
})
.def_method("s_tir.meta_schedule.MutatorClone", &MutatorNode::Clone)
diff --git a/src/s_tir/meta_schedule/task_scheduler/gradient_based.cc
b/src/s_tir/meta_schedule/task_scheduler/gradient_based.cc
index bdc80f3455..fe21d9b6cc 100644
--- a/src/s_tir/meta_schedule/task_scheduler/gradient_based.cc
+++ b/src/s_tir/meta_schedule/task_scheduler/gradient_based.cc
@@ -29,7 +29,7 @@ class GradientBasedNode final : public TaskSchedulerNode {
public:
double alpha;
int window_size;
- support::LinearCongruentialEngine::TRandState rand_state;
+ LinearCongruentialEngine::TRandState rand_state;
int round_robin_rounds_;
std::vector<std::vector<double>> best_latency_history_;
@@ -135,12 +135,12 @@ class GradientBasedNode final : public TaskSchedulerNode {
};
TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha,
int window_size,
-
support::LinearCongruentialEngine::TRandState seed) {
+
LinearCongruentialEngine::TRandState seed) {
ObjectPtr<GradientBasedNode> n = ffi::make_object<GradientBasedNode>();
n->logger = logger;
n->alpha = alpha;
n->window_size = window_size;
- n->rand_state = support::LinearCongruentialEngine::NormalizeSeed(seed);
+ n->rand_state = LinearCongruentialEngine::NormalizeSeed(seed);
return TaskScheduler(n);
}
diff --git a/src/s_tir/meta_schedule/tune_context.cc
b/src/s_tir/meta_schedule/tune_context.cc
index 9a935e931c..c8d7ab4bb0 100644
--- a/src/s_tir/meta_schedule/tune_context.cc
+++ b/src/s_tir/meta_schedule/tune_context.cc
@@ -40,7 +40,7 @@ TuneContext::TuneContext(ffi::Optional<IRModule> mod,
ffi::Optional<Target> targ
n->search_strategy = search_strategy;
n->task_name = task_name;
n->num_threads = num_threads;
- n->rand_state = support::LinearCongruentialEngine::NormalizeSeed(rand_state);
+ n->rand_state = LinearCongruentialEngine::NormalizeSeed(rand_state);
n->logger = logger;
data_ = std::move(n);
}
diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h
index 9bfbddd7f9..bf5576362c 100644
--- a/src/s_tir/meta_schedule/utils.h
+++ b/src/s_tir/meta_schedule/utils.h
@@ -170,7 +170,7 @@ inline void clear_logging(const char* file, int lineno,
ffi::Function logging_fu
}
/*! \brief The type of the random state */
-using TRandState = support::LinearCongruentialEngine::TRandState;
+using TRandState = LinearCongruentialEngine::TRandState;
/*!
* \brief Get the base64 encoded result of a string.
@@ -242,9 +242,9 @@ inline ffi::String SHash2Hex(const ObjectRef& obj) {
* \param rand_state The random state to be forked
* \return The forked random state
*/
-inline support::LinearCongruentialEngine::TRandState ForkSeed(
- support::LinearCongruentialEngine::TRandState* rand_state) {
- return support::LinearCongruentialEngine(rand_state).ForkSeed();
+inline LinearCongruentialEngine::TRandState ForkSeed(
+ LinearCongruentialEngine::TRandState* rand_state) {
+ return LinearCongruentialEngine(rand_state).ForkSeed();
}
/*!
@@ -254,12 +254,12 @@ inline support::LinearCongruentialEngine::TRandState
ForkSeed(
* \param n The number of forks
* \return The forked random states
*/
-inline std::vector<support::LinearCongruentialEngine::TRandState> ForkSeed(
- support::LinearCongruentialEngine::TRandState* rand_state, int n) {
- std::vector<support::LinearCongruentialEngine::TRandState> results;
+inline std::vector<LinearCongruentialEngine::TRandState> ForkSeed(
+ LinearCongruentialEngine::TRandState* rand_state, int n) {
+ std::vector<LinearCongruentialEngine::TRandState> results;
results.reserve(n);
for (int i = 0; i < n; ++i) {
-
results.push_back(support::LinearCongruentialEngine(rand_state).ForkSeed());
+ results.push_back(LinearCongruentialEngine(rand_state).ForkSeed());
}
return results;
}
diff --git a/src/s_tir/schedule/concrete_schedule.cc
b/src/s_tir/schedule/concrete_schedule.cc
index 51189b7254..0e11a3afe3 100644
--- a/src/s_tir/schedule/concrete_schedule.cc
+++ b/src/s_tir/schedule/concrete_schedule.cc
@@ -24,9 +24,8 @@ namespace tvm {
namespace s_tir {
using namespace tvm::tirx;
-Schedule Schedule::Concrete(IRModule mod,
support::LinearCongruentialEngine::TRandState seed,
- int debug_mask, ScheduleErrorRenderLevel
error_render_level,
- bool enable_check) {
+Schedule Schedule::Concrete(IRModule mod, LinearCongruentialEngine::TRandState
seed, int debug_mask,
+ ScheduleErrorRenderLevel error_render_level, bool
enable_check) {
ObjectPtr<ConcreteScheduleNode> n = ffi::make_object<ConcreteScheduleNode>();
n->state_ = ScheduleState(mod, debug_mask, enable_check);
n->error_render_level_ = error_render_level;
@@ -226,12 +225,12 @@ Schedule ConcreteScheduleNode::Copy() {
/******** Schedule: Schedule: Sampling ********/
-void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState
seed) {
- this->rand_state_ = support::LinearCongruentialEngine::NormalizeSeed(seed);
+void ConcreteScheduleNode::Seed(LinearCongruentialEngine::TRandState seed) {
+ this->rand_state_ = LinearCongruentialEngine::NormalizeSeed(seed);
}
-support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed()
{
- return support::LinearCongruentialEngine(&rand_state_).ForkSeed();
+LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() {
+ return LinearCongruentialEngine(&rand_state_).ForkSeed();
}
ExprRV ConcreteScheduleNode::SampleCategorical(const ffi::Array<Integer>&
candidates,
diff --git a/src/s_tir/schedule/concrete_schedule.h
b/src/s_tir/schedule/concrete_schedule.h
index 84b934ff72..29445da915 100644
--- a/src/s_tir/schedule/concrete_schedule.h
+++ b/src/s_tir/schedule/concrete_schedule.h
@@ -48,7 +48,7 @@ class ConcreteScheduleNode : public ScheduleNode {
/*! \brief A persistent stateless arithmetic analyzer. */
std::unique_ptr<arith::Analyzer> analyzer_;
/*! \brief The value of random state for sampling. */
- support::LinearCongruentialEngine::TRandState rand_state_;
+ LinearCongruentialEngine::TRandState rand_state_;
public:
static void RegisterReflection() {
@@ -64,8 +64,8 @@ class ConcreteScheduleNode : public ScheduleNode {
ffi::Optional<GlobalVar> func_working_on() const final { return
func_working_on_; }
void WorkOn(const ffi::String& func_name) final;
Schedule Copy() override;
- void Seed(support::LinearCongruentialEngine::TRandState seed) final;
- support::LinearCongruentialEngine::TRandState ForkSeed() final;
+ void Seed(LinearCongruentialEngine::TRandState seed) final;
+ LinearCongruentialEngine::TRandState ForkSeed() final;
public:
/******** Lookup random variables ********/
diff --git a/src/s_tir/schedule/primitive.h b/src/s_tir/schedule/primitive.h
index 63213bb4ab..85ef871463 100644
--- a/src/s_tir/schedule/primitive.h
+++ b/src/s_tir/schedule/primitive.h
@@ -19,8 +19,8 @@
#ifndef TVM_S_TIR_SCHEDULE_PRIMITIVE_H_
#define TVM_S_TIR_SCHEDULE_PRIMITIVE_H_
+#include <tvm/s_tir/random_engine.h>
#include <tvm/s_tir/schedule/state.h>
-#include <tvm/support/random_engine.h>
#include <vector>
@@ -36,8 +36,8 @@ using namespace tvm::tirx;
* \param max_exclusive The maximum value of the range, exclusive.
* \return The random integer sampled in the given range.
*/
-TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState*
rand_state,
- int32_t min_inclusive, int32_t max_exclusive);
+TVM_DLL int32_t SampleInt(LinearCongruentialEngine::TRandState* rand_state,
int32_t min_inclusive,
+ int32_t max_exclusive);
/*!
* \brief Sample k random integers from given range without replacement, i.e,
no duplication.
* \param rand_state The pointer to schedule's random state
@@ -45,8 +45,8 @@ TVM_DLL int32_t
SampleInt(support::LinearCongruentialEngine::TRandState* rand_st
* \param k The total number of samples.
* \return The randomly selected samples from the n candidates.
*/
-std::vector<int32_t> SampleWithoutReplacement(
- support::LinearCongruentialEngine::TRandState* rand_state, int32_t n,
int32_t k);
+std::vector<int32_t>
SampleWithoutReplacement(LinearCongruentialEngine::TRandState* rand_state,
+ int32_t n, int32_t k);
/*!
* \brief Sample once category from candidates according to the probability
weights.
* \param rand_state The pointer to schedule's random state
@@ -55,7 +55,7 @@ std::vector<int32_t> SampleWithoutReplacement(
* \param decision The sampling decision, if any
* \return The random variable sampled from candidates
*/
-TVM_DLL int64_t
SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state,
+TVM_DLL int64_t SampleCategorical(LinearCongruentialEngine::TRandState*
rand_state,
const ffi::Array<Integer>& candidates,
const ffi::Array<FloatImm>& probs,
ffi::Optional<Integer>* decision);
@@ -66,7 +66,7 @@ TVM_DLL int64_t
SampleCategorical(support::LinearCongruentialEngine::TRandState*
* \return The multinomial sampling function.
*/
TVM_DLL std::function<int32_t()> MakeMultinomialSampler(
- support::LinearCongruentialEngine::TRandState* rand_state, const
std::vector<double>& weights);
+ LinearCongruentialEngine::TRandState* rand_state, const
std::vector<double>& weights);
/*!
* \brief Sample the factors to perfect tile a specific loop
* \param rand_state The random state
@@ -74,9 +74,8 @@ TVM_DLL std::function<int32_t()> MakeMultinomialSampler(
* \param n_split The number of tiles to be sampled
* \return A list of length `n`, the random perfect tile sizes sampled
*/
-TVM_DLL std::vector<int64_t> SamplePerfectTile(
- support::LinearCongruentialEngine::TRandState* rand_state, //
- int32_t extent, int32_t n_splits);
+TVM_DLL std::vector<int64_t>
SamplePerfectTile(LinearCongruentialEngine::TRandState* rand_state, //
+ int32_t extent, int32_t
n_splits);
/*!
* \brief Sample the factors to perfect tile a specific loop
* \param rand_state The random state
@@ -85,9 +84,9 @@ TVM_DLL std::vector<int64_t> SamplePerfectTile(
* \param max_innermost_factor The maximum tile size allowed to be sampled in
the innermost loop
* \return A list of length `n`, the random perfect tile sizes sampled
*/
-TVM_DLL std::vector<int64_t> SamplePerfectTile(
- support::LinearCongruentialEngine::TRandState* rand_state, //
- int32_t extent, int32_t n_split, int32_t max_innermost_factor);
+TVM_DLL std::vector<int64_t>
SamplePerfectTile(LinearCongruentialEngine::TRandState* rand_state, //
+ int32_t extent, int32_t n_split,
+ int32_t max_innermost_factor);
/*!
* \brief Sample the factors to perfect tile a specific loop
* \param rand_state The random state
@@ -97,10 +96,10 @@ TVM_DLL std::vector<int64_t> SamplePerfectTile(
* \param decision The sampling decision
* \return A list of length `n`, the random perfect tile sizes sampled
*/
-TVM_DLL std::vector<int64_t> SamplePerfectTile(
- support::LinearCongruentialEngine::TRandState* rand_state, //
- const tirx::StmtSRef& loop_sref, int32_t n_split, int32_t
max_innermost_factor,
- ffi::Optional<ffi::Array<Integer>>* decision);
+TVM_DLL std::vector<int64_t>
SamplePerfectTile(LinearCongruentialEngine::TRandState* rand_state, //
+ const tirx::StmtSRef&
loop_sref, int32_t n_split,
+ int32_t max_innermost_factor,
+
ffi::Optional<ffi::Array<Integer>>* decision);
/*!
* \brief Sample the factors to a partitioned tile for a specific loop
*
@@ -117,7 +116,7 @@ TVM_DLL std::vector<int64_t> SamplePerfectTile(
* \return A list of length `n`, the random partitioned tile sizes sampled
*/
TVM_DLL std::vector<int64_t> SamplePartitionedTile(
- support::LinearCongruentialEngine::TRandState* rand_state, //
+ LinearCongruentialEngine::TRandState* rand_state, //
int32_t extent, int32_t n_split, int32_t partition_pos, int32_t
innerpart_factor);
/*!
* \brief Sample the factors to a partitioned tile for a specific loop
@@ -136,7 +135,7 @@ TVM_DLL std::vector<int64_t> SamplePartitionedTile(
* \return A list of length `n`, the random partitioned tile sizes sampled
*/
TVM_DLL std::vector<int64_t> SamplePartitionedTile(
- support::LinearCongruentialEngine::TRandState* rand_state, //
+ LinearCongruentialEngine::TRandState* rand_state, //
const tirx::StmtSRef& loop_sref, int32_t n_split, int32_t partition_pos,
int32_t innerpart_factor, ffi::Optional<ffi::Array<Integer>>* decision);
/*!
@@ -147,9 +146,10 @@ TVM_DLL std::vector<int64_t> SamplePartitionedTile(
* \param decision The sampling decision
* \return The sampled loop where the input block is to be computed at
*/
-TVM_DLL tirx::StmtSRef SampleComputeLocation(
- s_tir::ScheduleState self, support::LinearCongruentialEngine::TRandState*
rand_state,
- const tirx::StmtSRef& block_sref, ffi::Optional<Integer>* decision);
+TVM_DLL tirx::StmtSRef SampleComputeLocation(s_tir::ScheduleState self,
+
LinearCongruentialEngine::TRandState* rand_state,
+ const tirx::StmtSRef& block_sref,
+ ffi::Optional<Integer>* decision);
/******** Schedule: Get blocks & loops ********/
/*!
diff --git a/src/s_tir/schedule/primitive/sampling.cc
b/src/s_tir/schedule/primitive/sampling.cc
index 273f0d8441..d7851d9130 100644
--- a/src/s_tir/schedule/primitive/sampling.cc
+++ b/src/s_tir/schedule/primitive/sampling.cc
@@ -125,20 +125,20 @@ struct PrimeTable {
}
};
-int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state,
int32_t min_inclusive,
+int32_t SampleInt(LinearCongruentialEngine::TRandState* rand_state, int32_t
min_inclusive,
int32_t max_exclusive) {
TVM_FFI_CHECK(min_inclusive < max_exclusive, ValueError)
<< "max_exclusive must be greater than min_inclusive.";
if (min_inclusive + 1 == max_exclusive) {
return min_inclusive;
}
- support::LinearCongruentialEngine rand_(rand_state);
+ LinearCongruentialEngine rand_(rand_state);
std::uniform_int_distribution<int32_t> dist(min_inclusive, max_exclusive -
1);
return dist(rand_);
}
-std::vector<int32_t> SampleWithoutReplacement(
- support::LinearCongruentialEngine::TRandState* rand_state, int32_t n,
int32_t k) {
+std::vector<int32_t>
SampleWithoutReplacement(LinearCongruentialEngine::TRandState* rand_state,
+ int32_t n, int32_t k) {
if (k == 1) {
return {SampleInt(rand_state, 0, n)};
}
@@ -163,7 +163,7 @@ std::vector<int32_t> SampleWithoutReplacement(
return {order.begin(), order.begin() + k};
}
-int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState*
rand_state,
+int64_t SampleCategorical(LinearCongruentialEngine::TRandState* rand_state,
const ffi::Array<Integer>& candidates, const
ffi::Array<FloatImm>& probs,
ffi::Optional<Integer>* decision) {
TVM_FFI_CHECK(candidates.size() == probs.size(), ValueError)
@@ -177,7 +177,7 @@ int64_t
SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st
} else {
std::vector<double> weights = support::AsVector<FloatImm, double>(probs);
std::discrete_distribution<int32_t> dist(weights.begin(), weights.end());
- support::LinearCongruentialEngine rand_(rand_state);
+ LinearCongruentialEngine rand_(rand_state);
i = dist(rand_);
TVM_FFI_CHECK(0 <= i && i < n, ValueError)
<< "Unexpected decision generated, where n = " << n << ", but decision
is: " << i;
@@ -187,8 +187,8 @@ int64_t
SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st
return candidates[i]->value;
}
-std::function<int32_t()> MakeMultinomialSampler(
- support::LinearCongruentialEngine::TRandState* rand_state, const
std::vector<double>& weights) {
+std::function<int32_t()>
MakeMultinomialSampler(LinearCongruentialEngine::TRandState* rand_state,
+ const std::vector<double>&
weights) {
TVM_FFI_ICHECK(!weights.empty());
std::vector<double> sums;
sums.reserve(weights.size());
@@ -196,10 +196,10 @@ std::function<int32_t()> MakeMultinomialSampler(
for (double w : weights) {
sums.push_back(sum += w);
}
- return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(),
+ return [rng = LinearCongruentialEngine(rand_state).ForkSeed(),
dist = std::uniform_real_distribution<double>(0.0, sum),
sums = std::move(sums)]() mutable -> int32_t {
- support::LinearCongruentialEngine rand_(&rng);
+ LinearCongruentialEngine rand_(&rng);
double p = dist(rand_);
int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin();
int32_t n = sums.size();
@@ -209,7 +209,7 @@ std::function<int32_t()> MakeMultinomialSampler(
};
}
-std::vector<int64_t>
SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state,
+std::vector<int64_t> SamplePerfectTile(LinearCongruentialEngine::TRandState*
rand_state,
int32_t extent, int32_t n_splits) {
TVM_FFI_CHECK_GE(extent, 1, ValueError) << "Cannot tile a loop with 0 or
negative extent";
TVM_FFI_CHECK_GE(n_splits, 1, ValueError) << "Cannot tile a loop to 0 or
negative splits";
@@ -292,7 +292,7 @@ std::vector<int64_t>
SamplePerfectTile(support::LinearCongruentialEngine::TRandS
return result;
}
-std::vector<int64_t>
SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state,
+std::vector<int64_t> SamplePerfectTile(LinearCongruentialEngine::TRandState*
rand_state,
int32_t extent, int32_t n_splits,
int32_t max_innermost_factor) {
if (max_innermost_factor == -1) {
@@ -307,10 +307,10 @@ std::vector<int64_t>
SamplePerfectTile(support::LinearCongruentialEngine::TRandS
}
}
-std::vector<int64_t> SamplePerfectTile(
- support::LinearCongruentialEngine::TRandState* rand_state, //
- const tirx::StmtSRef& loop_sref, int32_t n_splits, int32_t
max_innermost_factor,
- ffi::Optional<ffi::Array<Integer>>* decision) {
+std::vector<int64_t> SamplePerfectTile(LinearCongruentialEngine::TRandState*
rand_state, //
+ const tirx::StmtSRef& loop_sref,
int32_t n_splits,
+ int32_t max_innermost_factor,
+ ffi::Optional<ffi::Array<Integer>>*
decision) {
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
const int64_t* extent = GetLoopIntExtent(loop);
std::vector<int64_t> result;
@@ -347,7 +347,7 @@ std::vector<int64_t> SamplePerfectTile(
}
TVM_DLL std::vector<int64_t> SamplePartitionedTile(
- support::LinearCongruentialEngine::TRandState* rand_state, //
+ LinearCongruentialEngine::TRandState* rand_state, //
int32_t extent, int32_t n_splits, int32_t partition_pos, int32_t
innerpart_factor) {
if (partition_pos == 0 && innerpart_factor == 1) {
return SamplePerfectTile(rand_state, extent, n_splits);
@@ -368,10 +368,10 @@ TVM_DLL std::vector<int64_t> SamplePartitionedTile(
}
}
-std::vector<int64_t> SamplePartitionedTile(
- support::LinearCongruentialEngine::TRandState* rand_state, //
- const tirx::StmtSRef& loop_sref, int32_t n_splits, int32_t partition_pos,
- int32_t innerpart_factor, ffi::Optional<ffi::Array<Integer>>* decision) {
+std::vector<int64_t>
SamplePartitionedTile(LinearCongruentialEngine::TRandState* rand_state, //
+ const tirx::StmtSRef& loop_sref,
int32_t n_splits,
+ int32_t partition_pos, int32_t
innerpart_factor,
+ ffi::Optional<ffi::Array<Integer>>*
decision) {
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
const int64_t* extent = GetLoopIntExtent(loop);
std::vector<int64_t> result;
@@ -419,7 +419,7 @@ std::vector<int64_t> SamplePartitionedTile(
}
tirx::StmtSRef SampleComputeLocation(s_tir::ScheduleState self,
-
support::LinearCongruentialEngine::TRandState* rand_state,
+ LinearCongruentialEngine::TRandState*
rand_state,
const StmtSRef& block_sref,
ffi::Optional<Integer>* decision) {
// Step 1. Collect all possible compute-at locations.
auto [location_srefs, location_indices] = CollectComputeLocation(self,
block_sref);
diff --git a/src/s_tir/schedule/schedule.cc b/src/s_tir/schedule/schedule.cc
index a14ff2aa8a..ba8ef72b57 100644
--- a/src/s_tir/schedule/schedule.cc
+++ b/src/s_tir/schedule/schedule.cc
@@ -70,14 +70,14 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("s_tir.schedule.SBlockRV", []() { return SBlockRV(); })
.def("s_tir.schedule.LoopRV", []() { return LoopRV(); })
.def("s_tir.schedule.ConcreteSchedule",
- [](IRModule mod, support::LinearCongruentialEngine::TRandState
seed, int debug_mask,
+ [](IRModule mod, LinearCongruentialEngine::TRandState seed, int
debug_mask,
int error_render_level, bool enable_check) -> Schedule {
return Schedule::Concrete(mod, debug_mask, seed,
static_cast<ScheduleErrorRenderLevel>(error_render_level),
enable_check);
})
.def("s_tir.schedule.TracedSchedule",
- [](IRModule mod, support::LinearCongruentialEngine::TRandState
seed, int debug_mask,
+ [](IRModule mod, LinearCongruentialEngine::TRandState seed, int
debug_mask,
int error_render_level, bool enable_check) -> Schedule {
return Schedule::Traced(mod, seed, debug_mask,
static_cast<ScheduleErrorRenderLevel>(error_render_level),
diff --git a/src/s_tir/schedule/traced_schedule.cc
b/src/s_tir/schedule/traced_schedule.cc
index e43df0835c..b1d78312e1 100644
--- a/src/s_tir/schedule/traced_schedule.cc
+++ b/src/s_tir/schedule/traced_schedule.cc
@@ -22,9 +22,8 @@ namespace tvm {
namespace s_tir {
using namespace tvm::tirx;
-Schedule Schedule::Traced(IRModule mod,
support::LinearCongruentialEngine::TRandState seed,
- int debug_mask, ScheduleErrorRenderLevel
error_render_level,
- bool enable_check) {
+Schedule Schedule::Traced(IRModule mod, LinearCongruentialEngine::TRandState
seed, int debug_mask,
+ ScheduleErrorRenderLevel error_render_level, bool
enable_check) {
ObjectPtr<TracedScheduleNode> n = ffi::make_object<TracedScheduleNode>();
n->state_ = ScheduleState(mod, debug_mask, enable_check);
n->error_render_level_ = error_render_level;
diff --git a/tests/cpp/random_engine_test.cc b/tests/cpp/random_engine_test.cc
index 42d65aa402..afd02c09b2 100644
--- a/tests/cpp/random_engine_test.cc
+++ b/tests/cpp/random_engine_test.cc
@@ -19,12 +19,12 @@
#include <gtest/gtest.h>
#include <tvm/runtime/logging.h>
-#include <tvm/support/random_engine.h>
+#include <tvm/s_tir/random_engine.h>
TEST(RandomEngine, Randomness) {
int64_t rand_state = 0;
- tvm::support::LinearCongruentialEngine rng(&rand_state);
+ tvm::s_tir::LinearCongruentialEngine rng(&rand_state);
rng.Seed(0x114514);
bool covered[100];
@@ -39,7 +39,7 @@ TEST(RandomEngine, Randomness) {
TEST(RandomEngine, Reproducibility) {
int64_t rand_state_a = 0, rand_state_b = 0;
- tvm::support::LinearCongruentialEngine rng_a(&rand_state_a),
rng_b(&rand_state_b);
+ tvm::s_tir::LinearCongruentialEngine rng_a(&rand_state_a),
rng_b(&rand_state_b);
rng_a.Seed(0x23456789);
rng_b.Seed(0x23456789);
@@ -51,7 +51,7 @@ TEST(RandomEngine, Reproducibility) {
TEST(RandomEngine, Serialization) {
int64_t rand_state_a = 0, rand_state_b = 0;
- tvm::support::LinearCongruentialEngine rng_a(&rand_state_a),
rng_b(&rand_state_b);
+ tvm::s_tir::LinearCongruentialEngine rng_a(&rand_state_a),
rng_b(&rand_state_b);
rng_a.Seed(0x56728);