Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions samples/legacy_samples/conv_sample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,20 +247,35 @@ create_operation_graph(common_conv_descriptors& descriptors, cudnnBackendDescrip
return cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build();
}

// Method for engine config generator based on heuristics
auto heurgen_method = [](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList {
auto get_engine_configs_from_heuristics = [](cudnn_frontend::OperationGraph& opGraph,
int64_t max_engine_config_count) -> cudnn_frontend::EngineConfigList {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(opGraph)
.setHeurMode(CUDNN_HEUR_MODE_INSTANT)
.build();
std::cout << "Heuristic has " << heuristics.getEngineConfigCount() << " configurations " << std::endl;
int64_t config_count = max_engine_config_count;
if (config_count <= 0) {
config_count = heuristics.getEngineConfigCount();
std::cout << "Heuristic has " << config_count << " configurations " << std::endl;
} else {
std::cout << "Heuristic requesting " << config_count << " configuration(s) " << std::endl;
}

auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount());
auto& engine_configs = heuristics.getEngineConfig(config_count);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(engine_configs, filtered_configs, ::allowAll);
return filtered_configs;
};

// Method for engine config generator based on heuristics
auto heurgen_method = [](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList {
return get_engine_configs_from_heuristics(opGraph, -1);
};

auto heurgen_method_first_config = [](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList {
return get_engine_configs_from_heuristics(opGraph, 1);
};

// Method for engine config generator based on fallback list
auto fallback_method = [](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
Expand Down Expand Up @@ -965,7 +980,7 @@ run_from_cudnn_get(int64_t* x_dim,
return false;
};

std::array<cudnn_frontend::GeneratorSource const, 1> sources = {heurgen_method};
std::array<cudnn_frontend::GeneratorSource const, 1> sources = {heurgen_method_first_config};
cudnn_frontend::EngineConfigGenerator generator(static_cast<int>(sources.size()), sources.data());

auto plans = generator.cudnnGetPlan(handle_, opGraph, sample_predicate_function);
Expand Down