diff --git a/Samples/WindowsML/Shared/cpp/ArgumentParser.cpp b/Samples/WindowsML/Shared/cpp/ArgumentParser.cpp index bd33e9fe0..797f77e72 100644 --- a/Samples/WindowsML/Shared/cpp/ArgumentParser.cpp +++ b/Samples/WindowsML/Shared/cpp/ArgumentParser.cpp @@ -20,6 +20,8 @@ namespace Shared return false; } + bool disable_ep = false; // tracks --ep_policy DISABLE locally for validation + for (size_t i = 1; i < arguments.size(); ++i) { if (arguments[i] == L"--compile") @@ -67,11 +69,14 @@ namespace Shared } else if (policy_str == L"DISABLE") { - options.ep_policy = std::nullopt; + disable_ep = true; + options.ep_policy.reset(); } else { - std::wcout << L"Unknown EP policy: " << policy_str << L", using default (DISABLE)\n"; + std::wcout << L"Unknown EP policy: " << policy_str << L"\n"; + PrintUsage(); + return false; } } else if (arguments[i] == L"--perf_mode" && i + 1 < arguments.size()) @@ -117,8 +122,8 @@ namespace Shared return false; } - // Require one selection method - if (!options.ep_policy.has_value() && options.ep_name.empty()) + // Require one selection method (unless disabled) + if (!disable_ep && !options.ep_policy.has_value() && options.ep_name.empty()) { std::wcout << L"ERROR: You must specify one of --ep_policy or --ep_name.\n"; PrintUsage(); diff --git a/Samples/WindowsML/Shared/cpp/InferenceEngine.cpp b/Samples/WindowsML/Shared/cpp/InferenceEngine.cpp index 03a673779..19536671e 100644 --- a/Samples/WindowsML/Shared/cpp/InferenceEngine.cpp +++ b/Samples/WindowsML/Shared/cpp/InferenceEngine.cpp @@ -20,9 +20,9 @@ namespace Shared std::cout << "Using EP Selection Policy: " << ArgumentParser::ToString(options.ep_policy.value()) << std::endl; sessionOptions.SetEpSelectionPolicy(options.ep_policy.value()); } - else + else if (!options.ep_name.empty()) { - // Use explicit configuration + // Use explicit EP configuration std::cout << "Using explicit EP configuration" << std::endl; ExecutionProviderManager::ConfigureSelectedExecutionProvider( sessionOptions, @@ -31,6 +31,11 @@ namespace Shared options.device_type, options.perf_mode); } + else + { + // DISABLE: no EP policy or explicit EP — use ONNX Runtime defaults (CPU + DML) + std::cout << "EP selection disabled, using ONNX Runtime defaults" << std::endl; + } return sessionOptions; } diff --git a/Samples/WindowsML/Shared/cs/ArgumentParser.cs b/Samples/WindowsML/Shared/cs/ArgumentParser.cs index 3f7e79716..02b020d7b 100644 --- a/Samples/WindowsML/Shared/cs/ArgumentParser.cs +++ b/Samples/WindowsML/Shared/cs/ArgumentParser.cs @@ -52,6 +52,7 @@ public static class ArgumentParser public static Options ParseOptions(string[] args) { Options options = new(); + bool disableEp = false; // tracks --ep_policy DISABLE locally for validation for (int i = 0; i < args.Length; i++) { @@ -76,12 +77,13 @@ public static Options ParseOptions(string[] args) options.EpPolicy = ExecutionProviderDevicePolicy.DEFAULT; break; case "DISABLE": + disableEp = true; options.EpPolicy = null; break; default: - Console.WriteLine($"Unknown EP policy: {policyStr}, using default (DISABLE)"); - options.EpPolicy = null; - break; + Console.WriteLine($"Unknown EP policy: {policyStr}"); + PrintHelp(); + throw new ArgumentException($"Unknown EP policy: {policyStr}", "--ep_policy"); } } break; @@ -173,11 +175,11 @@ public static Options ParseOptions(string[] args) throw new Exception("Mutually exclusive EP options"); } - if (!options.EpPolicy.HasValue && string.IsNullOrEmpty(options.EpName)) + if (!disableEp && !options.EpPolicy.HasValue && string.IsNullOrEmpty(options.EpName)) { Console.WriteLine("ERROR: You must specify one of --ep_policy or --ep_name."); PrintHelp(); - throw new Exception("Missing EP selection"); + throw new ArgumentException("Missing EP selection"); } if (!string.IsNullOrEmpty(options.DeviceType)) diff --git a/Samples/WindowsML/Shared/cs/ModelManager.cs b/Samples/WindowsML/Shared/cs/ModelManager.cs index e63602be0..a4a702d02 100644 --- a/Samples/WindowsML/Shared/cs/ModelManager.cs +++ b/Samples/WindowsML/Shared/cs/ModelManager.cs @@ -271,7 +271,8 @@ public static InferenceSession CreateSession(string modelPath, Options options, } else { - throw new Exception("Could not find an EP selection policy or an explicit execution provider."); + // DISABLE: no EP policy or explicit EP — use ONNX Runtime defaults (CPU + DML) + Console.WriteLine("EP selection disabled, using ONNX Runtime defaults"); } return new InferenceSession(modelPath, sessionOptions); @@ -344,10 +345,7 @@ public static string ResolveActualModelPath(Options options, string modelPath, s throw new Exception("Failed to configure selected execution provider"); } } - else - { - throw new Exception("Could not find an EP selection policy or an explicit execution provider."); - } + // else: DISABLE — compile with ONNX Runtime defaults CompileModel(tempSessionOptions, modelPath, compiledModelPath);