diff --git a/samples/legacy_samples/conv_sample.cpp b/samples/legacy_samples/conv_sample.cpp index db50daa8..94230b1e 100644 --- a/samples/legacy_samples/conv_sample.cpp +++ b/samples/legacy_samples/conv_sample.cpp @@ -247,20 +247,35 @@ create_operation_graph(common_conv_descriptors& descriptors, cudnnBackendDescrip return cudnn_frontend::OperationGraphBuilder().setHandle(handle_).setOperationGraph(ops.size(), ops.data()).build(); } -// Method for engine config generator based on heuristics -auto heurgen_method = [](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList { +auto get_engine_configs_from_heuristics = [](cudnn_frontend::OperationGraph& opGraph, + int64_t max_engine_config_count) -> cudnn_frontend::EngineConfigList { auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() .setOperationGraph(opGraph) .setHeurMode(CUDNN_HEUR_MODE_INSTANT) .build(); - std::cout << "Heuristic has " << heuristics.getEngineConfigCount() << " configurations " << std::endl; + int64_t config_count = max_engine_config_count; + if (config_count <= 0) { + config_count = heuristics.getEngineConfigCount(); + std::cout << "Heuristic has " << config_count << " configurations " << std::endl; + } else { + std::cout << "Heuristic requesting " << config_count << " configuration(s) " << std::endl; + } - auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount()); + auto& engine_configs = heuristics.getEngineConfig(config_count); cudnn_frontend::EngineConfigList filtered_configs; cudnn_frontend::filter(engine_configs, filtered_configs, ::allowAll); return filtered_configs; }; +// Method for engine config generator based on heuristics +auto heurgen_method = [](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList { + return get_engine_configs_from_heuristics(opGraph, -1); +}; + +auto heurgen_method_first_config = [](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList { + return get_engine_configs_from_heuristics(opGraph, 1); +}; + // Method for engine config generator based on fallback list auto fallback_method = [](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList { auto fallback = cudnn_frontend::EngineFallbackListBuilder() @@ -965,7 +980,7 @@ run_from_cudnn_get(int64_t* x_dim, return false; }; - std::array sources = {heurgen_method}; + std::array sources = {heurgen_method_first_config}; cudnn_frontend::EngineConfigGenerator generator(static_cast(sources.size()), sources.data()); auto plans = generator.cudnnGetPlan(handle_, opGraph, sample_predicate_function);