diff --git a/tests/utils/test_autotailor.py b/tests/utils/test_autotailor.py index a86d0ece85..fcbeaa2752 100644 --- a/tests/utils/test_autotailor.py +++ b/tests/utils/test_autotailor.py @@ -134,3 +134,98 @@ def test_get_datastream_uri(): uri = t._get_datastream_uri() assert uri.startswith("file://") assert "relative/path/to/ds.xml" in uri + + +def test_datastream_validator(): + """Test that DataStreamValidator properly validates IDs.""" + ds_path = pathlib.Path(__file__).parent.joinpath("data_stream.xml") + validator = autotailor.DataStreamValidator(str(ds_path)) + + # Test valid profile validation + validator.validate_profile("xccdf_com.example.www_profile_P1") + + # Test valid value validation + validator.validate_value("xccdf_com.example.www_value_V1") + validator.validate_value("xccdf_com.example.www_value_V2") + + # Test valid rule validation + validator.validate_rule("xccdf_com.example.www_rule_R1") + validator.validate_rule("xccdf_com.example.www_rule_R2") + validator.validate_rule("xccdf_com.example.www_rule_R3") + validator.validate_rule("xccdf_com.example.www_rule_R4") + + # Test valid group validation + validator.validate_group("xccdf_com.example.www_group_G34") + + # Test invalid profile + with pytest.raises(ValueError) as e: + validator.validate_profile("xccdf_com.example.www_profile_INVALID") + assert "Profile ID 'xccdf_com.example.www_profile_INVALID' does not exist" in str(e.value) + + # Test invalid value with suggestion + with pytest.raises(ValueError) as e: + validator.validate_value("xccdf_com.example.www_value_V3") + assert "Value ID 'xccdf_com.example.www_value_V3' does not exist" in str(e.value) + + # Test invalid rule with suggestion + with pytest.raises(ValueError) as e: + validator.validate_rule("xccdf_com.example.www_rule_R5") + assert "Rule ID 'xccdf_com.example.www_rule_R5' does not exist" in str(e.value) + + # Test invalid group + with pytest.raises(ValueError) as e: + validator.validate_group("xccdf_com.example.www_group_INVALID") + assert "Group ID 'xccdf_com.example.www_group_INVALID' does not exist" in str(e.value) + + +def test_profile_with_validator(): + """Test that Profile uses validator to check IDs.""" + ds_path = pathlib.Path(__file__).parent.joinpath("data_stream.xml") + validator = autotailor.DataStreamValidator(str(ds_path)) + + p = autotailor.Profile(validator=validator) + p.reverse_dns = "com.example.www" + + # Test valid variable change works + p.add_value_change("V1", "30") + + # Test invalid variable name fails + with pytest.raises(ValueError) as e: + p.add_value_change("INVALID_VAR", "test") + assert "Value ID 'xccdf_com.example.www_value_INVALID_VAR' does not exist" in str(e.value) + + # Test valid rule selection works + p.select_rule("R1") + + # Test invalid rule selection fails + with pytest.raises(ValueError) as e: + p.select_rule("INVALID_RULE") + assert "Rule ID 'xccdf_com.example.www_rule_INVALID_RULE' does not exist" in str(e.value) + + # Test valid base profile validation + p.validate_base_profile("P1") + + # Test invalid base profile fails + with pytest.raises(ValueError) as e: + p.validate_base_profile("INVALID_PROFILE") + assert "Profile ID 'xccdf_com.example.www_profile_INVALID_PROFILE' does not exist" in str(e.value) + + +def test_validator_suggestions(): + """Test that validator provides helpful suggestions for typos.""" + ds_path = pathlib.Path(__file__).parent.joinpath("data_stream.xml") + validator = autotailor.DataStreamValidator(str(ds_path)) + + # Test suggestion for value with typo (V11 instead of V1) + with pytest.raises(ValueError) as e: + validator.validate_value("xccdf_com.example.www_value_V11") + error_msg = str(e.value) + assert "Did you mean one of these?" in error_msg + assert "xccdf_com.example.www_value_V1" in error_msg + + # Test suggestion for rule with typo (R11 instead of R1) + with pytest.raises(ValueError) as e: + validator.validate_rule("xccdf_com.example.www_rule_R11") + error_msg = str(e.value) + assert "Did you mean one of these?" in error_msg + assert "xccdf_com.example.www_rule_R1" in error_msg diff --git a/utils/autotailor b/utils/autotailor index b2ef37ff65..61a8dfee32 100755 --- a/utils/autotailor +++ b/utils/autotailor @@ -26,10 +26,12 @@ import pathlib import xml.etree.ElementTree as ET import xml.dom.minidom import json +import difflib NS = "http://checklists.nist.gov/xccdf/1.2" NS_PREFIX = "xccdf-1.2" +DS_NS = "http://scap.nist.gov/schema/scap/source/1.2" DEFAULT_PROFILE_SUFFIX = "_customized" DEFAULT_REVERSE_DNS = "org.ssgproject.content" ROLES = ["full", "unscored", "unchecked"] @@ -53,12 +55,115 @@ def is_valid_xccdf_id(string): string) is not None +class DataStreamValidator: + """Validates IDs against the SCAP datastream.""" + + def __init__(self, datastream_path): + self.datastream_path = datastream_path + self.profile_ids = set() + self.value_ids = set() + self.rule_ids = set() + self.group_ids = set() + self._parse_datastream() + + def _parse_datastream(self): + """Parse the datastream to extract all valid IDs.""" + try: + tree = ET.parse(self.datastream_path) + root = tree.getroot() + + # Register namespaces + namespaces = { + 'ds': DS_NS, + 'xccdf': NS + } + + # Find all Benchmark elements (may be in data-stream-collection or standalone) + benchmarks = root.findall('.//xccdf:Benchmark', namespaces) + + for benchmark in benchmarks: + # Extract Profile IDs + for profile in benchmark.findall('.//xccdf:Profile', namespaces): + profile_id = profile.get('id') + if profile_id: + self.profile_ids.add(profile_id) + + # Extract Value IDs + for value in benchmark.findall('.//xccdf:Value', namespaces): + value_id = value.get('id') + if value_id: + self.value_ids.add(value_id) + + # Extract Rule IDs + for rule in benchmark.findall('.//xccdf:Rule', namespaces): + rule_id = rule.get('id') + if rule_id: + self.rule_ids.add(rule_id) + + # Extract Group IDs + for group in benchmark.findall('.//xccdf:Group', namespaces): + group_id = group.get('id') + if group_id: + self.group_ids.add(group_id) + + except ET.ParseError as e: + raise ValueError(f"Failed to parse datastream '{self.datastream_path}': {e}") + except FileNotFoundError: + raise ValueError(f"Datastream file not found: '{self.datastream_path}'") + + def _suggest_similar(self, invalid_id, valid_ids, n=3): + """Suggest similar valid IDs using fuzzy matching.""" + if not valid_ids: + return [] + # Get close matches + matches = difflib.get_close_matches(invalid_id, valid_ids, n=n, cutoff=0.6) + return matches + + def _create_validation_error(self, id_type, invalid_id, valid_ids): + """Create a detailed error message with suggestions.""" + msg = f"{id_type} ID '{invalid_id}' does not exist in the datastream." + + suggestions = self._suggest_similar(invalid_id, valid_ids) + if suggestions: + msg += "\n\nDid you mean one of these?" + for suggestion in suggestions: + msg += f"\n - {suggestion}" + + if not valid_ids: + msg += f"\n\nNo {id_type.lower()}s found in the datastream." + else: + msg += f"\n\nAvailable {id_type.lower()}s can be listed by examining the datastream file." + + return msg + + def validate_profile(self, profile_id): + """Validate a profile ID exists in the datastream.""" + if profile_id not in self.profile_ids: + raise ValueError(self._create_validation_error("Profile", profile_id, self.profile_ids)) + + def validate_value(self, value_id): + """Validate a value ID exists in the datastream.""" + if value_id not in self.value_ids: + raise ValueError(self._create_validation_error("Value", value_id, self.value_ids)) + + def validate_rule(self, rule_id): + """Validate a rule ID exists in the datastream.""" + if rule_id not in self.rule_ids: + raise ValueError(self._create_validation_error("Rule", rule_id, self.rule_ids)) + + def validate_group(self, group_id): + """Validate a group ID exists in the datastream.""" + if group_id not in self.group_ids: + raise ValueError(self._create_validation_error("Group", group_id, self.group_ids)) + + class Profile: - def __init__(self): + def __init__(self, validator=None): self.reverse_dns = DEFAULT_REVERSE_DNS self._profile_id = None self.extends = "" self.profile_title = "" + self.validator = validator self.value_changes = set() self.rules_to_select = set() @@ -137,15 +242,22 @@ class Profile: def refine_rule(self, rule_id, attribute, value): Profile._validate_rule_refinement_params(rule_id, attribute, value) + if self.validator: + self.validator.validate_rule(rule_id) self._prevent_duplicate_rule_refinement(attribute, rule_id, value) self._rule_refinements[rule_id][attribute] = value def refine_value(self, value_id, attribute, value): Profile._validate_value_refinement_params(value_id, attribute, value) + if self.validator: + self.validator.validate_value(value_id) self._prevent_duplicate_value_refinement(attribute, value_id, value) self._value_refinements[value_id][attribute] = value def add_value_change(self, varname, value): + full_var_id = self._full_var_id(varname) + if self.validator: + self.validator.validate_value(full_var_id) self.value_changes.add((varname, value)) def change_rule_attribute(self, rule_id, attribute, value): @@ -177,6 +289,41 @@ class Profile: varname, selector = assignment_to_tuple(change) self.change_value_attribute(varname, "selector", selector) + def select_rule(self, rule_id): + """Select a rule with validation.""" + full_rule_id = self._full_rule_id(rule_id) + if self.validator: + self.validator.validate_rule(full_rule_id) + self.rules_to_select.add(rule_id) + + def unselect_rule(self, rule_id): + """Unselect a rule with validation.""" + full_rule_id = self._full_rule_id(rule_id) + if self.validator: + self.validator.validate_rule(full_rule_id) + self.rules_to_unselect.add(rule_id) + + def select_group(self, group_id): + """Select a group with validation.""" + full_group_id = self._full_group_id(group_id) + if self.validator: + self.validator.validate_group(full_group_id) + self.groups_to_select.add(group_id) + + def unselect_group(self, group_id): + """Unselect a group with validation.""" + full_group_id = self._full_group_id(group_id) + if self.validator: + self.validator.validate_group(full_group_id) + self.groups_to_unselect.add(group_id) + + def validate_base_profile(self, profile_id): + """Validate the base profile ID.""" + if not profile_id or not self.validator: + return + full_profile_id = self._full_profile_id(profile_id) + self.validator.validate_profile(full_profile_id) + def _full_id(self, string, el_type): if is_valid_xccdf_id(string): return string @@ -241,9 +388,9 @@ class Profile: for group_id, props in tailoring["groups"].items(): if "evaluate" in props: if props["evaluate"]: - self.groups_to_select.add(group_id) + self.select_group(group_id) else: - self.groups_to_unselect.add(group_id) + self.unselect_group(group_id) def _import_variables_from_tailoring(self, tailoring): if "variables" in tailoring: @@ -258,9 +405,9 @@ class Profile: for rule_id, props in tailoring["rules"].items(): if "evaluate" in props: if props["evaluate"]: - self.rules_to_select.add(rule_id) + self.select_rule(rule_id) else: - self.rules_to_unselect.add(rule_id) + self.unselect_rule(rule_id) for attr in ATTRIBUTES: if attr in props: self.change_rule_attribute(rule_id, attr, props[attr]) @@ -298,12 +445,13 @@ class Profile: class Tailoring: - def __init__(self): + def __init__(self, validator=None): self.reverse_dns = DEFAULT_REVERSE_DNS self.id = "xccdf_auto_tailoring_default" self.version = 1 self.original_ds_filename = "" self.use_local_path = False + self.validator = validator self.profiles = [] @@ -312,7 +460,7 @@ class Tailoring: for profile in self.profiles: if profile.profile_id == profile_id: return profile - profile = Profile() + profile = Profile(validator=self.validator) if profile_id is not None: profile.profile_id = profile_id profile.reverse_dns = self.reverse_dns @@ -367,7 +515,7 @@ class Tailoring: if 'profiles' in tailoring_dict and tailoring_dict['profiles']: for profile_dict in tailoring_dict['profiles']: - profile = Profile() + profile = Profile(validator=self.validator) profile.reverse_dns = self.reverse_dns profile.import_json_tailoring_profile(profile_dict) self.profiles.append(profile) @@ -460,6 +608,11 @@ def get_parser(): "--local-path", action="store_true", help="Use local path for the benchmark href instead of absolute file:// URI. " "Absolute paths are converted to basename, relative paths are preserved.") + parser.add_argument( + "--no-validate", action="store_true", + help="Skip validation of IDs against the datastream. This significantly speeds up " + "execution on large datastreams but may produce invalid tailoring files if incorrect " + "IDs are provided. Use with caution.") return parser @@ -473,27 +626,55 @@ if __name__ == "__main__": parser.error("one of the following arguments has to be provided: " "BASE_PROFILE_ID or --json-tailoring JSON_TAILORING_FILENAME") - t = Tailoring() + # Create validator to check IDs against the datastream (unless --no-validate is specified) + validator = None + if not args.no_validate: + try: + validator = DataStreamValidator(args.datastream) + except (ValueError, FileNotFoundError) as e: + print(f"Error loading datastream: {e}", file=sys.stderr) + sys.exit(1) + + t = Tailoring(validator=validator) t.original_ds_filename = args.datastream t.reverse_dns = args.id_namespace t.use_local_path = args.local_path if args.json_tailoring: - t.import_json_tailoring(args.json_tailoring) + try: + t.import_json_tailoring(args.json_tailoring) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) if args.profile or (args.json_tailoring and args.tailored_profile_id): p = t.get_or_create_tailored_profile_with_id(args.tailored_profile_id) p.extends = args.profile + + # Validate base profile + try: + p.validate_base_profile(args.profile) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + if args.title: p.profile_title = args.title - p.rules_to_select.update(args.select) - p.rules_to_unselect.update(args.unselect) - p.rules_to_select.difference_update(p.rules_to_unselect) - - p.change_values(args.var_value) - p.change_selectors(args.var_select) - p.change_roles(args.rule_role) - p.change_severities(args.rule_severity) + # Select/unselect rules with validation + try: + for rule_id in args.select: + p.select_rule(rule_id) + for rule_id in args.unselect: + p.unselect_rule(rule_id) + p.rules_to_select.difference_update(p.rules_to_unselect) + + p.change_values(args.var_value) + p.change_selectors(args.var_select) + p.change_roles(args.rule_role) + p.change_severities(args.rule_severity) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) t.as_xml_string(args.output)