Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions core/utils_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,31 @@ def _build_payload(
return payload


def write_assessment_payload(
raw_data_path: str,
*,
report_path: str,
name: str,
started_at: int,
exit_strategy: int,
cloud_service_provider: int,
assessment_type: int,
) -> str:
payload = _build_payload(
report_path=report_path,
name=name,
started_at=started_at,
exit_strategy=exit_strategy,
cloud_service_provider=cloud_service_provider,
assessment_type=assessment_type,
)
payload_path = os.path.join(raw_data_path, "payload.json")
with open(payload_path, "w", encoding="utf-8") as payload_file:
json.dump(payload, payload_file, indent=2)
payload_file.write("\n")
return payload_path


def post_assessment(
*,
name: str,
Expand Down
79 changes: 57 additions & 22 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
sync_assessment,
generate_report,
)
from core.utils_sync import write_assessment_payload
from utils.azure import (
select_subscription,
select_resource_group,
Expand Down Expand Up @@ -241,7 +242,7 @@ def handle_aws(args):
return

# Run the AWS assessment pipeline
run_assessment(config, "aws")
run_assessment(config, "aws", dry_run=args.dry_run)


def handle_azure(args):
Expand Down Expand Up @@ -464,10 +465,10 @@ def handle_azure(args):

# Run the Azure assessment pipeline
# logger.info("Starting Azure assessment pipeline.")
run_assessment(config, "azure")
run_assessment(config, "azure", dry_run=args.dry_run)


def run_assessment(config, provider_name):
def run_assessment(config, provider_name, *, dry_run=False):
# Record the assessment start time to propagate across stages
started_at = int(time.time())

Expand All @@ -484,7 +485,9 @@ def run_assessment(config, provider_name):

# Detect ExitCloud Integration
mode, jwt = resolve_mode()
if mode == "online":
if dry_run:
print_step("Dry run mode – no remote sync.", status="ok")
elif mode == "online":
print_step("ExitCloud integration configured.", status="ok")
else:
print_step("ExitCloud integration not configured.", status="warning")
Expand Down Expand Up @@ -629,8 +632,37 @@ def run_assessment(config, provider_name):
or f"Exit Assessment {datetime.now().strftime('%Y%m%d_%H%M%S')}"
)

payload_path = None
if dry_run:
payload_path = write_assessment_payload(
raw_data_path,
report_path=report_path,
name=name,
started_at=started_at,
exit_strategy=config["exitStrategy"],
cloud_service_provider=config["cloudServiceProvider"],
assessment_type=config["assessmentType"],
)

# Stage 5 – Online / Offline Risk Assessment
if mode == "online":
if dry_run or mode == "offline":
console.print("Stage #5 – Offline Risk Assessment", style="bold")

with console.status("Performing risk assessment...", spinner="dots"):
risk_result = perform_risk_assessment(
exit_strategy=config["exitStrategy"],
report_path=report_path,
mode="offline",
)

status = "ok" if risk_result["success"] else "error"
print_step(
"Performing risk assessment...", status=status, logs=risk_result["logs"]
)
if not risk_result["success"]:
sys.exit(codes.RISK_ASSESSMENT)

elif mode == "online":
console.print("Stage #5 – Online Risk Assessment", style="bold")

sync_result = sync_assessment(
Expand All @@ -651,23 +683,6 @@ def run_assessment(config, provider_name):
if not sync_result["success"]:
sys.exit(codes.RISK_ASSESSMENT)

elif mode == "offline":
console.print("Stage #5 – Offline Risk Assessment", style="bold")

with console.status("Performing risk assessment...", spinner="dots"):
risk_result = perform_risk_assessment(
exit_strategy=config["exitStrategy"],
report_path=report_path,
mode=mode,
)

status = "ok" if risk_result["success"] else "error"
print_step(
"Performing risk assessment...", status=status, logs=risk_result["logs"]
)
if not risk_result["success"]:
sys.exit(codes.RISK_ASSESSMENT)

console.print("-------------------------------------------")

# Stage 6: Generate Report
Expand Down Expand Up @@ -697,6 +712,8 @@ def run_assessment(config, provider_name):
# Output the report path after the separator
console.print("-------------------------------------------")
console.print("Outputs:", style="bold")
if payload_path:
console.print(f"Payload: {payload_path}", style="cyan")
html_report_path = report_status.get("reports", {}).get("HTML")
if html_report_path:
console.print(f"HTML Report: {html_report_path}", style="cyan")
Expand Down Expand Up @@ -726,6 +743,8 @@ def parse_arguments():
" python3 main.py azure --config config.json # Use a configuration file for Azure\n"
" python3 main.py azure --cli # Use Azure CLI credentials\n"
" python3 main.py azure --name 'DMS System' # Use a pre-defined assessment name\n"
" python3 main.py aws --config config.json --dry-run # Local report + payload.json, no remote sync\n"
" python3 main.py azure --config config.json --dry-run\n"
),
formatter_class=argparse.RawDescriptionHelpFormatter,
)
Expand Down Expand Up @@ -753,6 +772,14 @@ def parse_arguments():
action="store_true",
help="Run without prompts; read all inputs from environment variables (for CI use).",
)
aws_parser.add_argument(
"--dry-run",
action="store_true",
help=(
"Run a local assessment and also write raw_data/payload.json without "
"remote sync."
),
)

# Subparser for Azure
azure_parser = subparsers.add_parser("azure", help="Perform an Azure assessment.")
Expand All @@ -773,6 +800,14 @@ def parse_arguments():
action="store_true",
help="Run without prompts; read all inputs from environment variables (for CI use).",
)
azure_parser.add_argument(
"--dry-run",
action="store_true",
help=(
"Run a local assessment and also write raw_data/payload.json without "
"remote sync."
),
)

return parser.parse_args()

Expand Down
81 changes: 80 additions & 1 deletion tests/test_utils_and_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import unittest
from argparse import Namespace
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch

import main
from utils import codes
Expand Down Expand Up @@ -294,6 +294,83 @@ def test_full_success_exits_0(self):

self.assertIsNone(result)

def test_dry_run_writes_payload_generates_report_and_skips_sync(self):
with tempfile.TemporaryDirectory() as tmp_dir:
raw_data_path = os.path.join(tmp_dir, "raw_data")
os.makedirs(raw_data_path, exist_ok=True)
payload_path = os.path.join(raw_data_path, "payload.json")
config = VALID_CONFIG.copy()
config["assessmentType"] = 2

with (
patch("main.validate_config"),
patch("main.resolve_mode", return_value=("online", "jwt-token")),
patch(
"main.create_directory",
return_value=(tmp_dir, raw_data_path),
),
patch("main.verify_credentials", return_value=(True, "ok")),
patch("main.test_permissions", return_value=(True, True, True, "ok")),
patch(
"main.create_resource_inventory",
return_value={"success": True, "logs": ""},
),
patch(
"main.create_cost_inventory",
return_value={"success": True, "logs": ""},
),
patch(
"main.write_assessment_payload",
return_value=payload_path,
) as mock_write,
patch(
"main.perform_risk_assessment",
return_value={"success": True, "logs": ""},
) as mock_risk,
patch(
"main.generate_report",
return_value={
"success": True,
"reports": {
"HTML": f"{tmp_dir}/index.html",
"PDF": f"{tmp_dir}/report.pdf",
},
},
) as mock_report,
patch("main.sync_assessment") as mock_sync,
patch("main.print_step"),
patch("main.console.print"),
):
main.run_assessment(config, "aws", dry_run=True)

mock_write.assert_called_once_with(
raw_data_path,
report_path=tmp_dir,
name=config["name"],
started_at=ANY,
exit_strategy=config["exitStrategy"],
cloud_service_provider=config["cloudServiceProvider"],
assessment_type=2,
)
mock_risk.assert_called_once_with(
exit_strategy=config["exitStrategy"],
report_path=tmp_dir,
mode="offline",
)
mock_report.assert_called_once()
mock_sync.assert_not_called()

def test_dry_run_flag_passed_from_handle_aws(self):
with (
patch.dict(os.environ, NonInteractiveAWSTests._BASE_ENV, clear=False),
patch("main.validate_region"),
patch("main.run_assessment") as mock_run,
patch("main.console.print"),
):
main.handle_aws(_ni_aws_args(dry_run=True))

mock_run.assert_called_once_with(ANY, "aws", dry_run=True)


def _ni_aws_args(**kwargs):
"""Build a Namespace that looks like 'aws --non-interactive' with optional overrides."""
Expand All @@ -302,6 +379,7 @@ def _ni_aws_args(**kwargs):
profile=None,
name=None,
non_interactive=True,
dry_run=False,
)
defaults.update(kwargs)
return Namespace(**defaults)
Expand All @@ -314,6 +392,7 @@ def _ni_azure_args(**kwargs):
cli=False,
name=None,
non_interactive=True,
dry_run=False,
)
defaults.update(kwargs)
return Namespace(**defaults)
Expand Down
53 changes: 53 additions & 0 deletions tests/test_utils_sync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import os
import tempfile
import types
import unittest
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -65,5 +67,56 @@ def test_returns_clear_error_when_host_missing_everywhere(self):
self.assertEqual(result["logs"], "HOST missing in environment and config.py")


class WriteAssessmentPayloadTests(unittest.TestCase):
def test_writes_payload_json_to_raw_data(self):
with tempfile.TemporaryDirectory() as tmp_dir:
report_path = os.path.join(tmp_dir, "report")
raw_data_path = os.path.join(report_path, "raw_data")
data_path = os.path.join(report_path, "data")
os.makedirs(raw_data_path, exist_ok=True)
os.makedirs(data_path, exist_ok=True)

with patch("core.utils_sync.load_data") as mock_load:
mock_load.side_effect = lambda table, db_path=None: (
[
{
"resource_type": 10,
"location": "eu-central-1",
"count": 3,
}
]
if table == "resource_inventory"
else [
{
"month": "2025-01",
"cost": 42.5,
"currency": "USD",
}
]
)

from core.utils_sync import write_assessment_payload

payload_path = write_assessment_payload(
raw_data_path,
report_path=report_path,
name="Dry Run Demo",
started_at=1000,
exit_strategy=1,
cloud_service_provider=2,
assessment_type=2,
)

self.assertEqual(payload_path, os.path.join(raw_data_path, "payload.json"))
with open(payload_path, encoding="utf-8") as payload_file:
payload = json.load(payload_file)

self.assertEqual(payload["type"], "local.assessment.succeeded")
self.assertEqual(payload["data"]["name"], "Dry Run Demo")
self.assertEqual(payload["data"]["assessmentType"], 2)
self.assertEqual(len(payload["data"]["resource_inventory"]), 1)
self.assertEqual(len(payload["data"]["cost_inventory"]), 1)


if __name__ == "__main__":
unittest.main()