diff --git a/.agents/skills/e2e_to_bazel/SKILL.md b/.agents/skills/e2e_to_bazel/SKILL.md new file mode 100644 index 0000000000..dac8fdef3d --- /dev/null +++ b/.agents/skills/e2e_to_bazel/SKILL.md @@ -0,0 +1,51 @@ +--- +name: e2e-to-bazel +description: >- + Converts end-to-end (e2e) test targets or paths to `bazel run` commands for + `heir-opt`. Use when you need to convert an e2e test target (where flags are + defined in the BUILD file) to a shell command whose flags can be modified + for further analysis and debugging. +--- + +# E2E To Bazel + +## Overview + +This skill guides the agent in using the `e2e_to_bazel` tool to convert e2e test +targets or paths to executable commands. + +## Usage + +### Converting a Test Target or Path + +To convert an e2e test target, directory, or source file to a `bazel run` +command, use the following command recipes: + +```bash +bazel run //scripts:e2e_to_bazel -- {target_or_path} +``` + +Replace `{target_or_path}` with the test target (e.g., +`//tests/Examples/lattigo/ckks/mnist:mnist_test`), directory, or source file +path. + +## Gotchas + +- **Blaze Query Dependency**: The tool relies on `blaze query` to extract + attributes. This requires a working `blaze` environment and may be slow for + the first run. +- **Workspace State**: It requires that the workspace is in a state where + `blaze query` can evaluate the targets. +- **Source file mapping**: If a source file is provided, it searches for + `heir_opt` targets that depend on it. + +```markdown +Copy this checklist and track progress: + +- [ ] Step 1: Identify the e2e test target, directory, or file to convert. +- [ ] Step 2: Run the `e2e_to_bazel` tool on it. +- [ ] Step 3: Inspect the generated command or run it. +- [ ] Step 4: Verify execution results. +``` + + diff --git a/scripts/BUILD b/scripts/BUILD index ff9e111eeb..b6dbb9c797 100644 --- a/scripts/BUILD +++ b/scripts/BUILD @@ -22,6 +22,15 @@ py_test( ], ) +py_test( + name = "test_e2e_to_bazel", + srcs = ["test_e2e_to_bazel.py"], + deps = [ + ":e2e_to_bazel_lib", + "@abseil-py//absl/testing:absltest", + ], +) + py_library( name = "lit_to_bazel_lib", srcs = ["lit_to_bazel_lib.py"], @@ -39,6 +48,23 @@ py_binary( ], ) +py_library( + name = "e2e_to_bazel_lib", + srcs = ["e2e_to_bazel_lib.py"], + deps = [ + "@heir_pip_deps//fire", + ], +) + +py_binary( + name = "e2e_to_bazel", + srcs = ["e2e_to_bazel.py"], + deps = [ + ":e2e_to_bazel_lib", + "@heir_pip_deps//fire", + ], +) + py_library( name = "get_version", srcs = ["get_version.py"], diff --git a/scripts/e2e_to_bazel.py b/scripts/e2e_to_bazel.py new file mode 100755 index 0000000000..dfc00cf48b --- /dev/null +++ b/scripts/e2e_to_bazel.py @@ -0,0 +1,7 @@ +"""Binary entry point for e2e_to_bazel.""" + +import fire +from scripts import e2e_to_bazel_lib + +if __name__ == "__main__": + fire.Fire(e2e_to_bazel_lib.e2e_to_bazel) diff --git a/scripts/e2e_to_bazel_lib.py b/scripts/e2e_to_bazel_lib.py new file mode 100644 index 0000000000..6901fc4c1e --- /dev/null +++ b/scripts/e2e_to_bazel_lib.py @@ -0,0 +1,160 @@ +"""Library for converting e2e tests to bazel run commands.""" + +import os +import shlex +import subprocess +import xml.etree.ElementTree as ET + +DEFAULT_TOOL_PREFIX = "bazel run --noallow_analysis_cache_discard //tools" + + +def run_blaze_query(query_str, options=None): + """Runs a blaze query and returns the output.""" + cmd = ["bazel", "query"] + if options: + cmd.extend(options) + cmd.extend([query_str, "--keep_going"]) + cwd = os.environ.get("BUILD_WORKSPACE_DIRECTORY") + try: + result = subprocess.run( + cmd, capture_output=True, text=True, check=False, cwd=cwd + ) + # Bazel query returns 3 if it encountered errors but still produced results with --keep_going. + if result.returncode not in [0, 3]: + print(f"Warning: blaze query failed with exit code {result.returncode}") + print(result.stderr) + return result.stdout + except Exception as e: + print(f"Error running blaze query: {e}") + return "" + + +def path_to_label(path): + """Attempts to convert a file path to a bazel label.""" + if path.startswith("//"): + return path + if os.path.isfile(path): + dir_path = os.path.dirname(path) + file_name = os.path.basename(path) + if os.path.exists(os.path.join(dir_path, "BUILD")): + return f"//{dir_path}:{file_name}" + return path + + +def get_heir_opt_target(target_or_path): + """Finds the heir_opt target associated with the given target or path.""" + # Normalize path to label if it looks like a file + label = path_to_label(target_or_path) + + if label.endswith(".mlir"): + query = f"kind(heir_opt, rdeps(//tests/Examples/..., {label}))" # fmt: skip + output = run_blaze_query(query) + lines = output.strip().split("\n") + targets = [l for l in lines if l.startswith("//")] + if targets: + return targets[0] + return None + + if os.path.isdir(target_or_path): + label = target_or_path + if not label.startswith("//"): + label = "//" + label + query = f"kind(heir_opt, {label}:*)" + output = run_blaze_query(query) + lines = output.strip().split("\n") + targets = [l for l in lines if l.startswith("//")] + if targets: + return targets[0] + return None + + if target_or_path.startswith("//"): + # Try to find heir_opt targets in the same package first + pkg = target_or_path.split(":")[0] + query = f"kind(heir_opt, {pkg}:*)" + output = run_blaze_query(query) + lines = output.strip().split("\n") + targets = [l for l in lines if l.startswith("//")] + if targets: + return targets[0] + + # Fallback: search deps + query = f"kind(heir_opt, deps({target_or_path}))" + output = run_blaze_query(query) + lines = output.strip().split("\n") + targets = [l for l in lines if l.startswith("//")] + if targets: + return targets[0] + return None + + return None + + +def e2e_to_bazel(target_or_path, tool_prefix=DEFAULT_TOOL_PREFIX): + """Converts an e2e test target or path to a blaze run command for heir-opt. + + Args: + target_or_path: The test target, directory, or source file. + tool_prefix: The prefix for the heir-opt tool. + """ + heir_opt_target = get_heir_opt_target(target_or_path) + if not heir_opt_target: + print(f"Could not find heir_opt target for {target_or_path}") + return + + xml_output = run_blaze_query(heir_opt_target, options=["--output=xml"]) + + if not xml_output: + print("Failed to get XML output from blaze query") + return + + try: + root = ET.fromstring(xml_output) + rule = root.find("rule") + if rule is None or rule.get("class") != "heir_opt": + print(f"Target {heir_opt_target} is not an heir_opt rule") + return + + pass_flags = [] + src = "" + + for list_elem in rule.findall("list"): + if list_elem.get("name") == "pass_flags": + for str_elem in list_elem.findall("string"): + pass_flags.append(str_elem.get("value")) + + for label_elem in rule.findall("label"): + if label_elem.get("name") == "src": + src = label_elem.get("value") + + if not src: + print("Could not find src attribute") + return + + # Resolve src label to path + src_path = src + if src_path.startswith("//"): + src_path = src_path[2:].replace(":", "/") + + workspace_root = os.environ.get("BUILD_WORKSPACE_DIRECTORY") + if not workspace_root: + try: + result = subprocess.run( + ["bazel", "info", "workspace"], + capture_output=True, + text=True, + check=True, + ) + workspace_root = result.stdout.strip() + except Exception as e: + # Fallback to relative path if we can't get workspace root + workspace_root = "" + + if workspace_root: + src_path = os.path.join(workspace_root, src_path) + + flags_str = " ".join(shlex.quote(f) for f in pass_flags) + command = f"{tool_prefix}:heir-opt -- {flags_str} {src_path}" + print(command) + + except Exception as e: + print(f"Error parsing XML: {e}") diff --git a/scripts/test_e2e_to_bazel.py b/scripts/test_e2e_to_bazel.py new file mode 100644 index 0000000000..de6dd82ed2 --- /dev/null +++ b/scripts/test_e2e_to_bazel.py @@ -0,0 +1,99 @@ +"""Tests for e2e_to_bazel.py.""" + +from contextlib import redirect_stdout +import io +from absl.testing import absltest +from scripts import e2e_to_bazel_lib + +patch = absltest.mock.patch + + +class E2EToBazelTest(absltest.TestCase): + """Tests for e2e_to_bazel script.""" + + @patch("scripts.e2e_to_bazel_lib.run_blaze_query") + @absltest.mock.patch.dict( + "os.environ", {"BUILD_WORKSPACE_DIRECTORY": "/workspace"} + ) + def test_e2e_to_bazel_with_target(self, mock_run_blaze_query): + """Tests e2e_to_bazel with a target argument.""" + + def side_effect(query_str, options=None): + if options and "--output=xml" in options: + return """ + + + + + + + + +""" + elif ( + "kind(heir_opt, //tests/Examples/openfhe/ckks/dot_product_8f:*)" # fmt: skip + in query_str + ): + return "//tests/Examples/openfhe/ckks/dot_product_8f:dot_product_8f_test_heir_opt" + return "" + + mock_run_blaze_query.side_effect = side_effect + + f = io.StringIO() + with redirect_stdout(f): + e2e_to_bazel_lib.e2e_to_bazel( + "//tests/Examples/openfhe/ckks/dot_product_8f:dot_product_8f_test" + ) + output = f.getvalue().strip() + + expected_command = ( + "bazel run --noallow_analysis_cache_discard //tools:heir-opt --" + " '--annotate-module=backend=openfhe scheme=ckks'" + " --mlir-to-ckks=ciphertext-degree=1024 --scheme-to-openfhe" + " /workspace/tests/Examples/common/dot_product_8f.mlir" + ) + self.assertIn(expected_command, output) + + @patch("scripts.e2e_to_bazel_lib.run_blaze_query") + @absltest.mock.patch.dict( + "os.environ", {"BUILD_WORKSPACE_DIRECTORY": "/workspace"} + ) + def test_e2e_to_bazel_with_file(self, mock_run_blaze_query): + """Tests e2e_to_bazel with a file argument.""" + + def side_effect(query_str, options=None): + if options and "--output=xml" in options: + return """ + + + + + + + + +""" + elif "rdeps(" in query_str: + return "//tests/Examples/openfhe/ckks/dot_product_8f:dot_product_8f_test_heir_opt" + return "" + + mock_run_blaze_query.side_effect = side_effect + + f = io.StringIO() + with redirect_stdout(f): + e2e_to_bazel_lib.e2e_to_bazel("tests/Examples/common/dot_product_8f.mlir") # fmt: skip + output = f.getvalue().strip() + + expected_command = ( + "bazel run --noallow_analysis_cache_discard //tools:heir-opt --" + " '--annotate-module=backend=openfhe scheme=ckks'" + " --mlir-to-ckks=ciphertext-degree=1024 --scheme-to-openfhe" + " /workspace/tests/Examples/common/dot_product_8f.mlir" + ) + self.assertIn(expected_command, output) + + +if __name__ == "__main__": + absltest.main()