diff --git a/core/utils_sync.py b/core/utils_sync.py index 59514de..a8ecac0 100644 --- a/core/utils_sync.py +++ b/core/utils_sync.py @@ -86,6 +86,31 @@ def _build_payload( return payload +def write_assessment_payload( + raw_data_path: str, + *, + report_path: str, + name: str, + started_at: int, + exit_strategy: int, + cloud_service_provider: int, + assessment_type: int, +) -> str: + payload = _build_payload( + report_path=report_path, + name=name, + started_at=started_at, + exit_strategy=exit_strategy, + cloud_service_provider=cloud_service_provider, + assessment_type=assessment_type, + ) + payload_path = os.path.join(raw_data_path, "payload.json") + with open(payload_path, "w", encoding="utf-8") as payload_file: + json.dump(payload, payload_file, indent=2) + payload_file.write("\n") + return payload_path + + def post_assessment( *, name: str, diff --git a/main.py b/main.py index e658229..76a3a2c 100644 --- a/main.py +++ b/main.py @@ -20,6 +20,7 @@ sync_assessment, generate_report, ) +from core.utils_sync import write_assessment_payload from utils.azure import ( select_subscription, select_resource_group, @@ -241,7 +242,7 @@ def handle_aws(args): return # Run the AWS assessment pipeline - run_assessment(config, "aws") + run_assessment(config, "aws", dry_run=args.dry_run) def handle_azure(args): @@ -464,10 +465,10 @@ def handle_azure(args): # Run the Azure assessment pipeline # logger.info("Starting Azure assessment pipeline.") - run_assessment(config, "azure") + run_assessment(config, "azure", dry_run=args.dry_run) -def run_assessment(config, provider_name): +def run_assessment(config, provider_name, *, dry_run=False): # Record the assessment start time to propagate across stages started_at = int(time.time()) @@ -484,7 +485,9 @@ def run_assessment(config, provider_name): # Detect ExitCloud Integration mode, jwt = resolve_mode() - if mode == "online": + if dry_run: + print_step("Dry run mode – no remote sync.", status="ok") + elif mode == "online": print_step("ExitCloud integration configured.", status="ok") else: print_step("ExitCloud integration not configured.", status="warning") @@ -629,8 +632,37 @@ def run_assessment(config, provider_name): or f"Exit Assessment {datetime.now().strftime('%Y%m%d_%H%M%S')}" ) + payload_path = None + if dry_run: + payload_path = write_assessment_payload( + raw_data_path, + report_path=report_path, + name=name, + started_at=started_at, + exit_strategy=config["exitStrategy"], + cloud_service_provider=config["cloudServiceProvider"], + assessment_type=config["assessmentType"], + ) + # Stage 5 – Online / Offline Risk Assessment - if mode == "online": + if dry_run or mode == "offline": + console.print("Stage #5 – Offline Risk Assessment", style="bold") + + with console.status("Performing risk assessment...", spinner="dots"): + risk_result = perform_risk_assessment( + exit_strategy=config["exitStrategy"], + report_path=report_path, + mode="offline", + ) + + status = "ok" if risk_result["success"] else "error" + print_step( + "Performing risk assessment...", status=status, logs=risk_result["logs"] + ) + if not risk_result["success"]: + sys.exit(codes.RISK_ASSESSMENT) + + elif mode == "online": console.print("Stage #5 – Online Risk Assessment", style="bold") sync_result = sync_assessment( @@ -651,23 +683,6 @@ def run_assessment(config, provider_name): if not sync_result["success"]: sys.exit(codes.RISK_ASSESSMENT) - elif mode == "offline": - console.print("Stage #5 – Offline Risk Assessment", style="bold") - - with console.status("Performing risk assessment...", spinner="dots"): - risk_result = perform_risk_assessment( - exit_strategy=config["exitStrategy"], - report_path=report_path, - mode=mode, - ) - - status = "ok" if risk_result["success"] else "error" - print_step( - "Performing risk assessment...", status=status, logs=risk_result["logs"] - ) - if not risk_result["success"]: - sys.exit(codes.RISK_ASSESSMENT) - console.print("-------------------------------------------") # Stage 6: Generate Report @@ -697,6 +712,8 @@ def run_assessment(config, provider_name): # Output the report path after the separator console.print("-------------------------------------------") console.print("Outputs:", style="bold") + if payload_path: + console.print(f"Payload: {payload_path}", style="cyan") html_report_path = report_status.get("reports", {}).get("HTML") if html_report_path: console.print(f"HTML Report: {html_report_path}", style="cyan") @@ -726,6 +743,8 @@ def parse_arguments(): " python3 main.py azure --config config.json # Use a configuration file for Azure\n" " python3 main.py azure --cli # Use Azure CLI credentials\n" " python3 main.py azure --name 'DMS System' # Use a pre-defined assessment name\n" + " python3 main.py aws --config config.json --dry-run # Local report + payload.json, no remote sync\n" + " python3 main.py azure --config config.json --dry-run\n" ), formatter_class=argparse.RawDescriptionHelpFormatter, ) @@ -753,6 +772,14 @@ def parse_arguments(): action="store_true", help="Run without prompts; read all inputs from environment variables (for CI use).", ) + aws_parser.add_argument( + "--dry-run", + action="store_true", + help=( + "Run a local assessment and also write raw_data/payload.json without " + "remote sync." + ), + ) # Subparser for Azure azure_parser = subparsers.add_parser("azure", help="Perform an Azure assessment.") @@ -773,6 +800,14 @@ def parse_arguments(): action="store_true", help="Run without prompts; read all inputs from environment variables (for CI use).", ) + azure_parser.add_argument( + "--dry-run", + action="store_true", + help=( + "Run a local assessment and also write raw_data/payload.json without " + "remote sync." + ), + ) return parser.parse_args() diff --git a/tests/test_utils_and_main.py b/tests/test_utils_and_main.py index eb7fe03..5c6df3b 100644 --- a/tests/test_utils_and_main.py +++ b/tests/test_utils_and_main.py @@ -5,7 +5,7 @@ import unittest from argparse import Namespace from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import main from utils import codes @@ -294,6 +294,83 @@ def test_full_success_exits_0(self): self.assertIsNone(result) + def test_dry_run_writes_payload_generates_report_and_skips_sync(self): + with tempfile.TemporaryDirectory() as tmp_dir: + raw_data_path = os.path.join(tmp_dir, "raw_data") + os.makedirs(raw_data_path, exist_ok=True) + payload_path = os.path.join(raw_data_path, "payload.json") + config = VALID_CONFIG.copy() + config["assessmentType"] = 2 + + with ( + patch("main.validate_config"), + patch("main.resolve_mode", return_value=("online", "jwt-token")), + patch( + "main.create_directory", + return_value=(tmp_dir, raw_data_path), + ), + patch("main.verify_credentials", return_value=(True, "ok")), + patch("main.test_permissions", return_value=(True, True, True, "ok")), + patch( + "main.create_resource_inventory", + return_value={"success": True, "logs": ""}, + ), + patch( + "main.create_cost_inventory", + return_value={"success": True, "logs": ""}, + ), + patch( + "main.write_assessment_payload", + return_value=payload_path, + ) as mock_write, + patch( + "main.perform_risk_assessment", + return_value={"success": True, "logs": ""}, + ) as mock_risk, + patch( + "main.generate_report", + return_value={ + "success": True, + "reports": { + "HTML": f"{tmp_dir}/index.html", + "PDF": f"{tmp_dir}/report.pdf", + }, + }, + ) as mock_report, + patch("main.sync_assessment") as mock_sync, + patch("main.print_step"), + patch("main.console.print"), + ): + main.run_assessment(config, "aws", dry_run=True) + + mock_write.assert_called_once_with( + raw_data_path, + report_path=tmp_dir, + name=config["name"], + started_at=ANY, + exit_strategy=config["exitStrategy"], + cloud_service_provider=config["cloudServiceProvider"], + assessment_type=2, + ) + mock_risk.assert_called_once_with( + exit_strategy=config["exitStrategy"], + report_path=tmp_dir, + mode="offline", + ) + mock_report.assert_called_once() + mock_sync.assert_not_called() + + def test_dry_run_flag_passed_from_handle_aws(self): + with ( + patch.dict(os.environ, NonInteractiveAWSTests._BASE_ENV, clear=False), + patch("main.validate_region"), + patch("main.run_assessment") as mock_run, + patch("main.console.print"), + ): + main.handle_aws(_ni_aws_args(dry_run=True)) + + mock_run.assert_called_once_with(ANY, "aws", dry_run=True) + def _ni_aws_args(**kwargs): """Build a Namespace that looks like 'aws --non-interactive' with optional overrides.""" @@ -302,6 +379,7 @@ def _ni_aws_args(**kwargs): profile=None, name=None, non_interactive=True, + dry_run=False, ) defaults.update(kwargs) return Namespace(**defaults) @@ -314,6 +392,7 @@ def _ni_azure_args(**kwargs): cli=False, name=None, non_interactive=True, + dry_run=False, ) defaults.update(kwargs) return Namespace(**defaults) diff --git a/tests/test_utils_sync.py b/tests/test_utils_sync.py index 88c2b1c..b26f266 100644 --- a/tests/test_utils_sync.py +++ b/tests/test_utils_sync.py @@ -1,4 +1,6 @@ +import json import os +import tempfile import types import unittest from unittest.mock import MagicMock, patch @@ -65,5 +67,56 @@ def test_returns_clear_error_when_host_missing_everywhere(self): self.assertEqual(result["logs"], "HOST missing in environment and config.py") +class WriteAssessmentPayloadTests(unittest.TestCase): + def test_writes_payload_json_to_raw_data(self): + with tempfile.TemporaryDirectory() as tmp_dir: + report_path = os.path.join(tmp_dir, "report") + raw_data_path = os.path.join(report_path, "raw_data") + data_path = os.path.join(report_path, "data") + os.makedirs(raw_data_path, exist_ok=True) + os.makedirs(data_path, exist_ok=True) + + with patch("core.utils_sync.load_data") as mock_load: + mock_load.side_effect = lambda table, db_path=None: ( + [ + { + "resource_type": 10, + "location": "eu-central-1", + "count": 3, + } + ] + if table == "resource_inventory" + else [ + { + "month": "2025-01", + "cost": 42.5, + "currency": "USD", + } + ] + ) + + from core.utils_sync import write_assessment_payload + + payload_path = write_assessment_payload( + raw_data_path, + report_path=report_path, + name="Dry Run Demo", + started_at=1000, + exit_strategy=1, + cloud_service_provider=2, + assessment_type=2, + ) + + self.assertEqual(payload_path, os.path.join(raw_data_path, "payload.json")) + with open(payload_path, encoding="utf-8") as payload_file: + payload = json.load(payload_file) + + self.assertEqual(payload["type"], "local.assessment.succeeded") + self.assertEqual(payload["data"]["name"], "Dry Run Demo") + self.assertEqual(payload["data"]["assessmentType"], 2) + self.assertEqual(len(payload["data"]["resource_inventory"]), 1) + self.assertEqual(len(payload["data"]["cost_inventory"]), 1) + + if __name__ == "__main__": unittest.main()