// time_sorted_plan<samplingTechnique>(handle, plans, variantPack) -> executionPlans_t
//
// Benchmarks each ExecutionPlan in `plans` by executing it on `handle` with
// `variantPack`. It stores the measured time on the plan (setExecutionTime)
// and returns the plans moved into a new vector, in the iteration order of
// the `timed_execution_plans` set.
//
// NOTE(review): this is an extracted listing. The function signature, several
// initializers, branch conditions and closing braces are on lines not shown
// here. The comments below describe only the visible statements.
36 template <CudnnFindSamplingTechnique samplingTechnique>
// Ordered set of references to the timed plans. The ordering comes from
// `plan_cmp`, which is defined on a line not shown. It presumably compares
// getExecutionTime(); confirm.
// NOTE(review): std::set keeps at most one element per equivalence class.
// Suppose plan_cmp compares execution times and a plan's time ties one that
// was already inserted. The two plans are then equivalent, so the insert()
// below silently drops the new one and it never reaches the returned vector.
// std::multiset, or std::stable_sort on a vector, would keep every plan.
42 std::set<std::reference_wrapper<ExecutionPlan>, decltype(plan_cmp)> timed_execution_plans(plan_cmp);
// Number of timed runs per plan. The initializer is on lines not shown. It is
// presumably chosen from samplingTechnique (once / 3 times / until stable).
44 const int maxIterCount =
// Ratio threshold used in the timing loop. A new sample updates the running
// minimum only when time_ms / min_time_ms is below 0.95, i.e. the sample is
// more than 5% faster than the best one so far.
48 const float threshhold = 0.95f;
// Events used to time each execution; elapsed time is read in milliseconds.
// cudaDeviceSynchronize() waits for device work that is already queued, so
// that work is not counted against the first plan.
// NOTE(review): none of these CUDA runtime return codes are checked.
50 cudaEvent_t start, stop;
51 cudaEventCreate(&start);
52 cudaEventCreate(&stop);
53 cudaDeviceSynchronize();
// Benchmark each candidate plan in turn.
55 for (
auto &plan : plans) {
// final_time_ms: the time that is logged and stored on this plan.
// min_time_ms: the best sample so far. It starts at the largest float, so the
// first sample's ratio against it is ~0 and always counts as an improvement.
57 float final_time_ms = 0.0f;
58 float min_time_ms = std::numeric_limits<float>::max();
// Untimed warm-up run. If it fails, the plan's tag and status are logged.
// The lines after the log message are not shown; presumably the plan is
// skipped.
61 auto warmup_status = ::cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc());
62 if (warmup_status != CUDNN_STATUS_SUCCESS) {
63 getLogger() <<
"[cudnn_frontend] Plan " << plan.getTag() <<
" failed with " <<
to_string(warmup_status) << std::endl;
// Let the warm-up run finish before timing starts.
66 cudaDeviceSynchronize();
// Timed runs. Each iteration brackets one execution with start/stop events,
// waits for `stop`, and reads the elapsed milliseconds into time_ms (declared
// on a line not shown).
// NOTE(review): the status returned by the timed cudnnBackendExecute is
// discarded. The events are recorded without a stream argument (default
// stream). If `handle` is bound to another stream, confirm that the
// measurement covers the intended work.
68 for (
int i = 0; i < maxIterCount; i++) {
69 cudaEventRecord(start);
71 ::cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc());
73 cudaEventRecord(stop);
74 cudaEventSynchronize(stop);
75 cudaEventElapsedTime(&time_ms, start, stop);
// Running-minimum bookkeeping. This is presumably the sample-until-stable
// path; the enclosing condition and the loop-exit branch are on lines not
// shown. final_time_ms becomes min(previous minimum, this sample). The minimum
// itself is updated only when this sample is below threshhold * min_time_ms.
78 final_time_ms = std::min(min_time_ms, time_ms);
79 if (time_ms / min_time_ms < threshhold) {
80 min_time_ms = final_time_ms;
// Keeps the sample taken at iteration maxIterCount / 2. This is presumably the
// "sample 3 times and take median" path.
// NOTE(review): this is the chronologically middle run, not the median of the
// samples. For example, with three samples 1.0, 3.0, 2.0 (index 3 / 2 == 1) it
// reports 3.0, while the median is 2.0. To get a true median, collect the
// samples and use std::nth_element.
85 final_time_ms = i == (maxIterCount / 2) ? time_ms : final_time_ms;
// Log the chosen time, store it on the plan, and add the plan to the ordered
// set (see the std::set note above about plans that compare equivalent).
88 getLogger() <<
"[cudnn_frontend] Plan " << plan.getTag() <<
" took " << std::setw(10) << final_time_ms << std::endl;
89 plan.setExecutionTime(final_time_ms);
90 timed_execution_plans.insert(plan);
// Move each plan into the result vector in set order. The loop header over
// timed_execution_plans is presumably on a line not shown.
94 time_sorted_plans.emplace_back(std::move(plan));
// Release the timing events.
97 cudaEventDestroy(start);
98 cudaEventDestroy(stop);
// Report how many plans were returned (this can be fewer than the input).
100 getLogger() <<
"[cudnn_frontend] Auto-tuning returns " << time_sorted_plans.size() <<
" plans." << std::endl;
102 return time_sorted_plans;
// Presumably one of the cudnnFindPlan overloads. Its signature and the lines
// that build `plans` are not shown in this listing.
// It passes the plan vector to time_sorted_plan, using the same sampling
// technique, and returns the time-sorted result. std::move transfers the
// vector, which time_sorted_plan takes by value (executionPlans_t plans).
105 template <CudnnFindSamplingTechnique samplingTechnique>
112 return time_sorted_plan<samplingTechnique>(handle, std::move(plans), variantPack);
// Presumably another cudnnFindPlan overload. Its signature and the lines that
// build `plans` are not shown in this listing.
// Like the overload above, it moves the plan vector into time_sorted_plan with
// the same sampling technique and returns the time-sorted result.
115 template <CudnnFindSamplingTechnique samplingTechnique>
123 return time_sorted_plan<samplingTechnique>(handle, std::move(plans), variantPack);
// cudnnFindPlanAndCache<samplingTechnique>(handle, opGraph, variantPack, cache, pred)
// -> ExecutionPlan. The parameter list is on lines not shown.
//
// Finds and times the plans for `opGraph` through cudnnFindPlan, using the
// same sampling technique and predicate, and returns the first plan of the
// sorted list. That plan is added to `cache` only when
// cache.is_fastest_plan_stable() accepts its tag for this graph.
// NOTE(review): front() on an empty vector is undefined behavior. If no plan
// survives, both the stability check and the return below access a
// nonexistent element. That can happen if none are generated or, presumably,
// if every plan fails its warm-up run. Check sorted_plans.empty() first and
// report an error instead.
126 template <CudnnFindSamplingTechnique samplingTechnique>
134 auto sorted_plans = cudnnFindPlan<samplingTechnique>(handle, opGraph, variantPack, pred);
// Cache the leading plan only once the cache reports it as stable for opGraph.
136 if (cache.is_fastest_plan_stable(opGraph, sorted_plans.front().getTag())) {
137 cache.add_plan_to_cache(opGraph, sorted_plans.front());
139 return sorted_plans.front();
Sample 3 times and take median.
ConditionalStreamer & getLogger()
Sample multiple times till stable.
std::function< bool(cudnn_frontend::ExecutionPlan const &plan)> Predicate
std::vector< cudnn_frontend::ExecutionPlan > executionPlans_t
Variety of renames.
static std::string to_string(cudnnDataType_t type)
Sample once quick but may have unstable values.
auto cudnnFindPlanAndCache(cudnnHandle_t handle, cudnn_frontend::OperationGraph &opGraph, cudnn_frontend::VariantPack const &variantPack, cudnn_frontend::ExecutionPlanCache &cache, Predicate pred=[](const cudnn_frontend::ExecutionPlan &) {return false;}) -> cudnn_frontend::ExecutionPlan
float getExecutionTime() const
auto cudnnFindPlan(cudnnHandle_t handle, cudnn_frontend::OperationGraph &opGraph, cudnn_frontend::VariantPack const &variantPack, Predicate pred) -> executionPlans_t
auto time_sorted_plan(cudnnHandle_t handle, executionPlans_t plans, VariantPack const &variantPack) -> executionPlans_t