diff --git a/ymir/tools/unprivileged/tests/unit/test_upstream_tools.py b/ymir/tools/unprivileged/tests/unit/test_upstream_tools.py new file mode 100644 index 00000000..40d7bf27 --- /dev/null +++ b/ymir/tools/unprivileged/tests/unit/test_upstream_tools.py @@ -0,0 +1,884 @@ +import subprocess +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from beeai_framework.middleware.trajectory import GlobalTrajectoryMiddleware +from beeai_framework.tools import ToolError + +from ymir.tools.unprivileged.upstream_tools import ( + ApplyDownstreamPatchesTool, + ApplyDownstreamPatchesToolInput, + CherryPickCommitTool, + CherryPickCommitToolInput, + CherryPickContinueTool, + CherryPickContinueToolInput, + CloneUpstreamRepositoryTool, + CloneUpstreamRepositoryToolInput, + ExtractUpstreamRepositoryInput, + ExtractUpstreamRepositoryTool, + FindBaseCommitTool, + FindBaseCommitToolInput, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _mock_aiohttp_json(json_data, status=200): + """Create a mock aiohttp ClientSession that returns json_data from any GET. + + Supports both usage patterns in upstream_tools.py: + - `async with ClientSession() as s, s.get() as r:` + - `async with ClientSession() as s:` then `async with s.get() as r:` + + The key is that `session.get(...)` must be a synchronous call returning + an async context manager (not a coroutine), matching aiohttp's real API. + """ + response = AsyncMock() + response.status = status + response.json = AsyncMock(return_value=json_data) + response.raise_for_status = MagicMock() + if status >= 400: + response.raise_for_status.side_effect = Exception(f"HTTP {status}") + + # session.get(...) returns an async context manager (not a coroutine) + get_cm = MagicMock() + get_cm.__aenter__ = AsyncMock(return_value=response) + get_cm.__aexit__ = AsyncMock(return_value=False) + + # The session itself is an object with a .get() method + session = MagicMock() + session.get = MagicMock(return_value=get_cm) + + # ClientSession() is an async context manager that yields session + client_session = MagicMock() + client_session.__aenter__ = AsyncMock(return_value=session) + client_session.__aexit__ = AsyncMock(return_value=False) + return client_session + + +def _mock_aiohttp_error(error_msg="error"): + """Create a mock aiohttp ClientSession whose GET raises aiohttp.ClientError.""" + import aiohttp + + get_cm = MagicMock() + get_cm.__aenter__ = AsyncMock(side_effect=aiohttp.ClientError(error_msg)) + get_cm.__aexit__ = AsyncMock(return_value=False) + + session = MagicMock() + session.get = MagicMock(return_value=get_cm) + + client_session = MagicMock() + client_session.__aenter__ = AsyncMock(return_value=session) + client_session.__aexit__ = AsyncMock(return_value=False) + return client_session + + +# --------------------------------------------------------------------------- +# ExtractUpstreamRepositoryTool - URL parsing tests +# --------------------------------------------------------------------------- + + +class TestExtractUpstreamRepositoryTool: + @pytest.fixture + def tool(self): + return ExtractUpstreamRepositoryTool(options={"working_directory": None}) + + @pytest.mark.asyncio + async def test_github_commit_url(self, tool): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://github.com/libexpat/libexpat/commit/a93ef2756c88c4e3e6e7e8a9f42daa06e90e8e5b" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.repo_url == "https://github.com/libexpat/libexpat.git" + assert data.commit_hash == "a93ef2756c88c4e3e6e7e8a9f42daa06e90e8e5b" # pragma: allowlist secret + assert data.is_pr is False + assert data.is_compare is False + + @pytest.mark.asyncio + async def test_github_commit_url_with_patch_suffix(self, tool): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://github.com/libexpat/libexpat/commit/abc1234.patch" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.repo_url == "https://github.com/libexpat/libexpat.git" + assert data.commit_hash == "abc1234" + + @pytest.mark.asyncio + async def test_gitlab_commit_url(self, tool): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://gitlab.com/owner/repo/-/commit/deadbeef1234567" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.repo_url == "https://gitlab.com/owner/repo.git" + assert data.commit_hash == "deadbeef1234567" # pragma: allowlist secret + assert data.is_pr is False + + @pytest.mark.asyncio + async def test_cgit_query_param_url(self, tool): + """Test gitweb-style URL with p= and h= query params (p= not first).""" + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://git.example.org/gitweb?a=commitdiff&p=project.git&h=abcdef1234567890" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.repo_url == "https://git.example.org/project.git" + assert data.commit_hash == "abcdef1234567890" # pragma: allowlist secret + + @pytest.mark.asyncio + async def test_kernel_org_cgit_url(self, tool): + """kernel.org cgit URL: repo path in URL path, commit hash in ?id= query param.""" + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/commit/?id=abcdef1234567" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.commit_hash == "abcdef1234567" # pragma: allowlist secret + assert data.repo_url == "https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git" + + @pytest.mark.asyncio + async def test_cgit_commit_path_url(self, tool): + """Test cgit URL with commit hash embedded in the path.""" + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://git.savannah.gnu.org/cgit/grep.git/commit/abcdef1234567" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.commit_hash == "abcdef1234567" # pragma: allowlist secret + assert "grep.git" in data.repo_url + + @pytest.mark.asyncio + async def test_github_pr_url(self, tool): + mock_session = _mock_aiohttp_json({"head": {"sha": "pr_commit_sha_1234567890abcdef"}}) + + with patch("ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=mock_session): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://github.com/torvalds/linux/pull/42" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.repo_url == "https://github.com/torvalds/linux.git" + assert data.commit_hash == "pr_commit_sha_1234567890abcdef" # pragma: allowlist secret + assert data.is_pr is True + assert data.pr_number == "42" + + @pytest.mark.asyncio + async def test_github_pr_url_with_patch_suffix(self, tool): + mock_session = _mock_aiohttp_json({"head": {"sha": "abc123def456"}}) + + with patch("ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=mock_session): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://github.com/owner/repo/pull/99.patch" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.is_pr is True + assert data.pr_number == "99" + + @pytest.mark.asyncio + async def test_gitlab_mr_url(self, tool): + mock_session = _mock_aiohttp_json({"sha": "mr_commit_sha_abcdef"}) + + with patch("ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=mock_session): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://gitlab.com/owner/repo/-/merge_requests/15" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.repo_url == "https://gitlab.com/owner/repo.git" + assert data.commit_hash == "mr_commit_sha_abcdef" + assert data.is_pr is True + assert data.pr_number == "15" + + @pytest.mark.asyncio + async def test_github_compare_url(self, tool): + mock_session = _mock_aiohttp_json( + { + "commits": [ + {"sha": "aaa111"}, + {"sha": "bbb222"}, + {"sha": "ccc333"}, + ] + } + ) + + with patch("ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=mock_session): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://github.com/owner/repo/compare/v3.7.0...v3.7.1" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.is_compare is True + assert data.base_ref == "v3.7.0" + assert data.target_ref == "v3.7.1" + assert data.compare_commits == ["aaa111", "bbb222", "ccc333"] + assert data.commit_hash == "ccc333" + assert data.repo_url == "https://github.com/owner/repo.git" + + @pytest.mark.asyncio + async def test_gitlab_compare_url(self, tool): + mock_session = _mock_aiohttp_json( + { + "commits": [ + {"id": "newest"}, + {"id": "oldest"}, + ] + } + ) + + with patch("ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=mock_session): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://gitlab.com/owner/repo/-/compare/v1.0...v1.1" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.is_compare is True + # GitLab commits are reversed (newest first in API -> oldest first in output) + assert data.compare_commits == ["oldest", "newest"] + assert data.commit_hash == "newest" + + @pytest.mark.asyncio + async def test_compare_url_api_failure_falls_back_to_target_ref(self, tool): + """When API is unavailable, compare URL still returns target_ref as commit_hash.""" + client_session = _mock_aiohttp_error("timeout") + + with patch( + "ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=client_session + ): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://github.com/owner/repo/compare/v1.0...v1.1" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.is_compare is True + assert data.commit_hash == "v1.1" + assert data.compare_commits is None + + @pytest.mark.asyncio + async def test_pr_api_failure_raises_tool_error(self, tool): + client_session = _mock_aiohttp_error("404") + + with ( + patch( + "ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=client_session + ), + pytest.raises(ToolError, match="Failed to fetch PR/MR information"), + ): + await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://github.com/owner/repo/pull/999" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + @pytest.mark.asyncio + async def test_unparseable_url_raises_tool_error(self, tool): + with pytest.raises(ToolError, match="Could not extract commit hash"): + await tool.run( + input=ExtractUpstreamRepositoryInput(upstream_fix_url="https://example.com/not-a-commit-url") + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + @pytest.mark.asyncio + async def test_commit_url_without_repo_path_raises_tool_error(self, tool): + """cgit URL with commit hash but no repo path should error.""" + with pytest.raises(ToolError, match="Could not extract"): + await tool.run( + input=ExtractUpstreamRepositoryInput(upstream_fix_url="https://example.org/?id=abcdef1234567") + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + @pytest.mark.asyncio + async def test_double_dot_compare_separator(self, tool): + """Compare URLs with '..' separator should also work.""" + mock_session = _mock_aiohttp_json({"commits": [{"sha": "only1"}]}) + + with patch("ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=mock_session): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://github.com/owner/repo/compare/v1.0..v1.1" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.is_compare is True + assert data.base_ref == "v1.0" + assert data.target_ref == "v1.1" + + @pytest.mark.asyncio + async def test_compare_url_target_ref_ending_in_patch_chars(self, tool): + """Compare URL where target_ref ends in characters from the set '.patch'.""" + client_session = _mock_aiohttp_error("skip API") + + with patch( + "ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=client_session + ): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://github.com/owner/repo/compare/v1.0...some-branch-path" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.target_ref == "some-branch-path" + + @pytest.mark.asyncio + async def test_gitlab_nested_path_mr_url(self, tool): + """GitLab MR URL with deeply nested project path (more than owner/repo).""" + mock_session = _mock_aiohttp_json({"sha": "mr_head_commit"}) + + with patch("ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=mock_session): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://gitlab.com/redhat/centos-stream/rpms/bind/-/merge_requests/15" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.repo_url == "https://gitlab.com/redhat/centos-stream/rpms/bind.git" + # Verify the API URL was correctly encoded with the full nested path + session = mock_session.__aenter__.return_value + api_url = session.get.call_args[0][0] + assert "redhat%2Fcentos-stream%2Frpms%2Fbind" in api_url + assert "/merge_requests/15" in api_url + + @pytest.mark.asyncio + async def test_gitlab_nested_path_compare_url(self, tool): + """GitLab compare URL with deeply nested project path.""" + mock_session = _mock_aiohttp_json({"commits": [{"id": "abc123"}]}) + + with patch("ymir.tools.unprivileged.upstream_tools.aiohttp.ClientSession", return_value=mock_session): + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://gitlab.com/redhat/centos-stream/rpms/bind/-/compare/v9.18.0...v9.18.1" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.repo_url == "https://gitlab.com/redhat/centos-stream/rpms/bind.git" + # Verify the API URL was correctly encoded with the full nested path + session = mock_session.__aenter__.return_value + api_url = session.get.call_args[0][0] + assert "redhat%2Fcentos-stream%2Frpms%2Fbind" in api_url + assert "/repository/compare" in api_url + + @pytest.mark.asyncio + async def test_cgit_p_param_at_start_of_query(self, tool): + """cgit/gitweb URL where p= is the first query parameter.""" + result = await tool.run( + input=ExtractUpstreamRepositoryInput( + upstream_fix_url="https://git.example.org/gitweb?p=project.git&h=abcdef1234567" + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + data = result.result + assert data.repo_url == "https://git.example.org/project.git" + assert data.commit_hash == "abcdef1234567" # pragma: allowlist secret + + +# --------------------------------------------------------------------------- +# CloneUpstreamRepositoryTool +# --------------------------------------------------------------------------- + + +class TestCloneUpstreamRepositoryTool: + @pytest.fixture + def tool(self): + return CloneUpstreamRepositoryTool(options={"working_directory": None}) + + @pytest.mark.asyncio + async def test_clone_success(self, tool, tmp_path): + clone_dir = tmp_path / "mypackage" + expected_path = tmp_path / "mypackage-upstream" + + async def mock_run_subprocess(cmd, **kwargs): + # Simulate the git clone by creating the directory structure + expected_path.mkdir(parents=True, exist_ok=True) + (expected_path / ".git").mkdir(exist_ok=True) + return (0, "", "") + + with patch("ymir.tools.unprivileged.upstream_tools.run_subprocess", side_effect=mock_run_subprocess): + result = await tool.run( + input=CloneUpstreamRepositoryToolInput( + repo_url="https://github.com/owner/repo.git", + clone_directory=str(clone_dir), + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert "Successfully cloned" in result.result + assert "mypackage-upstream" in result.result + + @pytest.mark.asyncio + async def test_clone_directory_already_exists(self, tool, tmp_path): + clone_dir = tmp_path / "mypackage" + # Create the -upstream directory to trigger the error + (tmp_path / "mypackage-upstream").mkdir() + + with pytest.raises(ToolError, match="Clone directory already exists"): + await tool.run( + input=CloneUpstreamRepositoryToolInput( + repo_url="https://github.com/owner/repo.git", + clone_directory=str(clone_dir), + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + @pytest.mark.asyncio + async def test_clone_git_failure(self, tool, tmp_path): + clone_dir = tmp_path / "mypackage" + + with ( + patch( + "ymir.tools.unprivileged.upstream_tools.run_subprocess", new_callable=AsyncMock + ) as mock_run, + pytest.raises(ToolError, match="Git clone failed"), + ): + mock_run.return_value = (128, "", "fatal: repository not found") + await tool.run( + input=CloneUpstreamRepositoryToolInput( + repo_url="https://github.com/nonexistent/repo.git", + clone_directory=str(clone_dir), + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + +# --------------------------------------------------------------------------- +# FindBaseCommitTool +# --------------------------------------------------------------------------- + + +class TestFindBaseCommitTool: + @pytest.fixture + def upstream_repo(self, tmp_path): + repo = tmp_path / "upstream" + repo.mkdir() + subprocess.run(["git", "init"], cwd=repo, check=True) + (repo / "file.c").write_text("int main() {}\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo, check=True) + subprocess.run(["git", "tag", "v1.2.3"], cwd=repo, check=True) + subprocess.run(["git", "tag", "release-2.0.0"], cwd=repo, check=True) + return repo + + @pytest.fixture + def tool(self): + return FindBaseCommitTool(options={"working_directory": None}) + + @pytest.mark.asyncio + async def test_finds_v_prefixed_tag(self, tool, upstream_repo): + result = await tool.run( + input=FindBaseCommitToolInput( + repo_path=str(upstream_repo), + version="1.2.3", + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert "v1.2.3" in result.result + assert "base_tag_commit" in tool.options + + @pytest.mark.asyncio + async def test_finds_release_prefixed_tag(self, tool, upstream_repo): + result = await tool.run( + input=FindBaseCommitToolInput( + repo_path=str(upstream_repo), + version="2.0.0", + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert "release-2.0.0" in result.result + + @pytest.mark.asyncio + async def test_explicit_tag_override(self, tool, upstream_repo): + result = await tool.run( + input=FindBaseCommitToolInput( + repo_path=str(upstream_repo), + version="99.99.99", + tag="v1.2.3", + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert "v1.2.3" in result.result + + @pytest.mark.asyncio + async def test_explicit_commit_override(self, tool, upstream_repo): + head = subprocess.run( + ["git", "rev-parse", "HEAD"], cwd=upstream_repo, capture_output=True, text=True, check=True + ).stdout.strip() + + result = await tool.run( + input=FindBaseCommitToolInput( + repo_path=str(upstream_repo), + version="99.99.99", + commit=head, + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert head in result.result + assert tool.options["base_tag_commit"] == head + + @pytest.mark.asyncio + async def test_no_matching_tag_raises_with_available_tags(self, tool, upstream_repo): + with pytest.raises(ToolError, match=r"Could not find tag matching version 99\.99\.99") as exc_info: + await tool.run( + input=FindBaseCommitToolInput( + repo_path=str(upstream_repo), + version="99.99.99", + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert "v1.2.3" in exc_info.value.message + + @pytest.mark.asyncio + async def test_not_a_git_repo(self, tool, tmp_path): + with pytest.raises(ToolError, match="Not a git repository"): + await tool.run( + input=FindBaseCommitToolInput( + repo_path=str(tmp_path), + version="1.0.0", + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + +# --------------------------------------------------------------------------- +# ApplyDownstreamPatchesTool +# --------------------------------------------------------------------------- + + +class TestApplyDownstreamPatchesTool: + @pytest.fixture + def upstream_repo(self, tmp_path): + repo = tmp_path / "upstream" + repo.mkdir() + subprocess.run(["git", "init"], cwd=repo, check=True) + (repo / "main.c").write_text("int main() { return 0; }\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo, check=True) + return repo + + @pytest.fixture + def patches_dir(self, tmp_path): + d = tmp_path / "patches" + d.mkdir() + # A valid unified diff patch + (d / "fix-one.patch").write_text( + "--- a/main.c\n+++ b/main.c\n@@ -1 +1,2 @@\n int main() { return 0; }\n+/* fix one */\n" + ) + (d / "fix-two.patch").write_text( + "--- a/main.c\n" + "+++ b/main.c\n" + "@@ -1,2 +1,3 @@\n" + " int main() { return 0; }\n" + " /* fix one */\n" + "+/* fix two */\n" + ) + return d + + @pytest.fixture + def tool(self): + return ApplyDownstreamPatchesTool(options={"working_directory": None}) + + @pytest.mark.asyncio + async def test_apply_patches_success(self, tool, upstream_repo, patches_dir): + result = await tool.run( + input=ApplyDownstreamPatchesToolInput( + repo_path=str(upstream_repo), + patch_files=["fix-one.patch", "fix-two.patch"], + patches_directory=str(patches_dir), + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert "Successfully applied 2 patches" in result.result + assert "fix-one.patch" in result.result + assert "fix-two.patch" in result.result + assert "base_head_commit" in tool.options + + content = (upstream_repo / "main.c").read_text() + assert "/* fix one */" in content + assert "/* fix two */" in content + + @pytest.mark.asyncio + async def test_apply_empty_patch_list(self, tool, upstream_repo, patches_dir): + result = await tool.run( + input=ApplyDownstreamPatchesToolInput( + repo_path=str(upstream_repo), + patch_files=[], + patches_directory=str(patches_dir), + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert "No patches to apply" in result.result + assert "base_head_commit" in tool.options + + @pytest.mark.asyncio + async def test_missing_patch_file_raises(self, tool, upstream_repo, patches_dir): + with pytest.raises(ToolError, match="Patch file not found"): + await tool.run( + input=ApplyDownstreamPatchesToolInput( + repo_path=str(upstream_repo), + patch_files=["nonexistent.patch"], + patches_directory=str(patches_dir), + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + @pytest.mark.asyncio + async def test_conflicting_patch_raises(self, tool, upstream_repo, patches_dir): + # Write a patch that won't apply (wrong context) + (patches_dir / "bad.patch").write_text( + "--- a/main.c\n" + "+++ b/main.c\n" + "@@ -1,3 +1,4 @@\n" + " this context does not exist\n" + " neither does this\n" + " or this\n" + "+added line\n" + ) + + with pytest.raises(ToolError, match=r"Failed to apply existing patch 'bad\.patch'"): + await tool.run( + input=ApplyDownstreamPatchesToolInput( + repo_path=str(upstream_repo), + patch_files=["bad.patch"], + patches_directory=str(patches_dir), + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + @pytest.mark.asyncio + async def test_custom_strip_levels(self, tool, upstream_repo, patches_dir): + # Create a patch with a/subdir/main.c prefix requiring -p2 + (patches_dir / "strip2.patch").write_text( + "--- a/subdir/main.c\n" + "+++ b/subdir/main.c\n" + "@@ -1 +1,2 @@\n" + " int main() { return 0; }\n" + "+/* strip 2 applied */\n" + ) + + result = await tool.run( + input=ApplyDownstreamPatchesToolInput( + repo_path=str(upstream_repo), + patch_files=["strip2.patch"], + patches_directory=str(patches_dir), + patch_strip_levels={"strip2.patch": 2}, + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert "Successfully applied 1 patches" in result.result + assert "/* strip 2 applied */" in (upstream_repo / "main.c").read_text() + + @pytest.mark.asyncio + async def test_not_a_git_repo(self, tool, tmp_path, patches_dir): + with pytest.raises(ToolError, match="Not a git repository"): + await tool.run( + input=ApplyDownstreamPatchesToolInput( + repo_path=str(tmp_path), + patch_files=["fix-one.patch"], + patches_directory=str(patches_dir), + ) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + +# --------------------------------------------------------------------------- +# CherryPickCommitTool +# --------------------------------------------------------------------------- + + +class TestCherryPickCommitTool: + @pytest.fixture + def repo_with_branch(self, tmp_path): + """Create a repo with a diverged branch to test cherry-pick scenarios.""" + repo = tmp_path / "repo" + repo.mkdir() + subprocess.run(["git", "init"], cwd=repo, check=True) + (repo / "file.c").write_text("line 1\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo, check=True) + + # Create a branch and add a commit there + subprocess.run(["git", "checkout", "-b", "feature"], cwd=repo, check=True) + (repo / "file.c").write_text("line 1\nfeature line\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Add feature"], cwd=repo, check=True) + feature_commit = subprocess.run( + ["git", "rev-parse", "HEAD"], cwd=repo, capture_output=True, text=True, check=True + ).stdout.strip() + + # Go back to main and add a different commit (will conflict) + subprocess.run(["git", "checkout", "-"], cwd=repo, check=True) + (repo / "file.c").write_text("line 1\nmain line\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Main change"], cwd=repo, check=True) + + return repo, feature_commit + + @pytest.fixture + def tool(self): + return CherryPickCommitTool(options={"working_directory": None}) + + @pytest.mark.asyncio + async def test_cherry_pick_success_no_conflict(self, tool, tmp_path): + repo = tmp_path / "repo" + repo.mkdir() + subprocess.run(["git", "init"], cwd=repo, check=True) + (repo / "a.txt").write_text("hello\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo, check=True) + + subprocess.run(["git", "checkout", "-b", "other"], cwd=repo, check=True) + (repo / "b.txt").write_text("new file\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Add b.txt"], cwd=repo, check=True) + commit = subprocess.run( + ["git", "rev-parse", "HEAD"], cwd=repo, capture_output=True, text=True, check=True + ).stdout.strip() + + subprocess.run(["git", "checkout", "-"], cwd=repo, check=True) + + result = await tool.run( + input=CherryPickCommitToolInput(repo_path=str(repo), commit_hash=commit) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert "Successfully cherry-picked" in result.result + assert (repo / "b.txt").exists() + + @pytest.mark.asyncio + async def test_cherry_pick_with_conflict(self, tool, repo_with_branch): + repo, feature_commit = repo_with_branch + + result = await tool.run( + input=CherryPickCommitToolInput(repo_path=str(repo), commit_hash=feature_commit) + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + assert "conflicts" in result.result.lower() + assert "file.c" in result.result + assert tool.options.get("fix_commit") == feature_commit + + @pytest.mark.asyncio + async def test_cherry_pick_commit_not_found(self, tool, tmp_path): + repo = tmp_path / "repo" + repo.mkdir() + subprocess.run(["git", "init"], cwd=repo, check=True) + (repo / "a.txt").write_text("x\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Init"], cwd=repo, check=True) + + with pytest.raises(ToolError, match="not found"): + await tool.run( + input=CherryPickCommitToolInput(repo_path=str(repo), commit_hash="deadbeefdeadbeef") + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + @pytest.mark.asyncio + async def test_not_a_git_repo(self, tool, tmp_path): + with pytest.raises(ToolError, match="Not a git repository"): + await tool.run( + input=CherryPickCommitToolInput(repo_path=str(tmp_path), commit_hash="abc1234") + ).middleware(GlobalTrajectoryMiddleware(pretty=True)) + + +# --------------------------------------------------------------------------- +# CherryPickContinueTool +# --------------------------------------------------------------------------- + + +class TestCherryPickContinueTool: + @pytest.fixture + def conflicted_repo(self, tmp_path, monkeypatch): + """Set up a repo in a cherry-pick conflict state.""" + monkeypatch.setenv("GIT_EDITOR", "true") + monkeypatch.setenv("GIT_COMMITTER_NAME", "Test") + monkeypatch.setenv("GIT_COMMITTER_EMAIL", "test@test.com") + monkeypatch.setenv("GIT_AUTHOR_NAME", "Test") + monkeypatch.setenv("GIT_AUTHOR_EMAIL", "test@test.com") + + repo = tmp_path / "repo" + repo.mkdir() + subprocess.run(["git", "init"], cwd=repo, check=True) + (repo / "file.c").write_text("original\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Initial"], cwd=repo, check=True) + + subprocess.run(["git", "checkout", "-b", "feature"], cwd=repo, check=True) + (repo / "file.c").write_text("feature change\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Feature"], cwd=repo, check=True) + feature_commit = subprocess.run( + ["git", "rev-parse", "HEAD"], cwd=repo, capture_output=True, text=True, check=True + ).stdout.strip() + + subprocess.run(["git", "checkout", "-"], cwd=repo, check=True) + (repo / "file.c").write_text("main change\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Main"], cwd=repo, check=True) + + # Start cherry-pick that will conflict + subprocess.run(["git", "cherry-pick", feature_commit], cwd=repo) + # Repo is now in cherry-pick conflict state + return repo + + @pytest.fixture + def tool(self): + return CherryPickContinueTool(options={"working_directory": None}) + + @pytest.mark.asyncio + async def test_continue_after_resolving_conflicts(self, tool, conflicted_repo): + # Resolve the conflict by writing clean content and staging it + (conflicted_repo / "file.c").write_text("resolved content\n") + subprocess.run(["git", "add", "file.c"], cwd=conflicted_repo, check=True) + + result = await tool.run(input=CherryPickContinueToolInput(repo_path=str(conflicted_repo))).middleware( + GlobalTrajectoryMiddleware(pretty=True) + ) + + assert "Successfully completed cherry-pick" in result.result + + @pytest.mark.asyncio + async def test_continue_with_unresolved_conflicts(self, tool, conflicted_repo): + # Don't resolve — file still has conflict markers + with pytest.raises(ToolError, match="Unresolved conflicts"): + await tool.run(input=CherryPickContinueToolInput(repo_path=str(conflicted_repo))).middleware( + GlobalTrajectoryMiddleware(pretty=True) + ) + + @pytest.mark.asyncio + async def test_not_in_cherry_pick_state(self, tool, tmp_path): + repo = tmp_path / "repo" + repo.mkdir() + subprocess.run(["git", "init"], cwd=repo, check=True) + (repo / "a.txt").write_text("x\n") + subprocess.run(["git", "add", "."], cwd=repo, check=True) + subprocess.run(["git", "commit", "-m", "Init"], cwd=repo, check=True) + + with pytest.raises(ToolError, match="Not in a cherry-pick state"): + await tool.run(input=CherryPickContinueToolInput(repo_path=str(repo))).middleware( + GlobalTrajectoryMiddleware(pretty=True) + ) + + @pytest.mark.asyncio + async def test_not_a_git_repo(self, tool, tmp_path): + with pytest.raises(ToolError, match="Not a git repository"): + await tool.run(input=CherryPickContinueToolInput(repo_path=str(tmp_path))).middleware( + GlobalTrajectoryMiddleware(pretty=True) + ) diff --git a/ymir/tools/unprivileged/upstream_tools.py b/ymir/tools/unprivileged/upstream_tools.py index 387bf3a9..0b7935c5 100644 --- a/ymir/tools/unprivileged/upstream_tools.py +++ b/ymir/tools/unprivileged/upstream_tools.py @@ -89,25 +89,30 @@ async def _run( # Check if this is a pull request URL and extract owner/repo/PR number in one match. pr_match = re.search(r"/([\w\-\.]+)/([\w\-\.]+)/pull/(\d+)(?:\.patch)?", parsed.path) mr_match = re.search( - r"/([\w\-\.]+)/([\w\-\.]+)/-/merge_requests/(\d+)(?:\.patch)?", + r"/(.+?)/(?:-/)?merge_requests/(\d+)(?:\.patch)?", parsed.path, ) if pr_match or mr_match: # Handle GitHub Pull Request or GitLab Merge Request - match = pr_match if pr_match else mr_match - owner = match.group(1) - repo = match.group(2) - pr_number = match.group(3) + if pr_match: + owner = pr_match.group(1) + repo = pr_match.group(2) + pr_number = pr_match.group(3) + project_path = f"{owner}/{repo}" + else: + project_path = mr_match.group(1).removesuffix(".git") + pr_number = mr_match.group(2) # Fetch PR/MR information to get the head commit if pr_match: # GitHub API - api_url = f"https://api.github.com/repos/{owner}/{repo}/pulls/{pr_number}" + api_url = f"https://api.github.com/repos/{project_path}/pulls/{pr_number}" else: - # GitLab API + # GitLab API - URL-encode the full project path api_url = ( - f"https://{parsed.netloc}/api/v4/projects/{owner}%2F{repo}/merge_requests/{pr_number}" + f"https://{parsed.netloc}/api/v4/projects/" + f"{quote(project_path, safe='')}/merge_requests/{pr_number}" ) headers = { @@ -133,7 +138,7 @@ async def _run( ) from e # Construct repository URL - repo_url = f"https://{parsed.netloc}/{owner}/{repo}.git" + repo_url = f"https://{parsed.netloc}/{project_path}.git" # Return with PR information return ExtractUpstreamRepositoryOutput( @@ -148,18 +153,17 @@ async def _run( # Try to match compare URL compare_match = re.search( - r"/([\w\-\.]+)/([\w\-\.]+)/(?:-/)?compare/(.+?)(\.{2,3})([^\s\?#]+)", + r"/(.+?)/(?:-/)?compare/(.+?)(\.{2,3})([^\s\?#]+)", parsed.path, ) if compare_match: # Handle GitHub/GitLab Compare URLs - owner = compare_match.group(1) - repo = compare_match.group(2) - base_ref = compare_match.group(3) - # Group 4 is the separator (.. or ...) - not used, we always use ... for APIs - target_ref = compare_match.group(5).rstrip(".patch") # Remove .patch if present + project_path = compare_match.group(1).removesuffix(".git") + base_ref = compare_match.group(2) + # Group 3 is the separator (.. or ...) - not used, we always use ... for APIs + target_ref = compare_match.group(4).removesuffix(".patch") # Construct repository URL - repo_url = f"https://{parsed.netloc}/{owner}/{repo}.git" + repo_url = f"https://{parsed.netloc}/{project_path}.git" # Fetch compare information to get the list of commits headers = { "Accept": "application/json", @@ -169,11 +173,11 @@ async def _run( commit_hash = target_ref try: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - # Determine if this is GitHub or GitLab based on the URL pattern - if "/-/" not in parsed.path: + # Determine if this is GitHub or GitLab based on the domain + if "github" in parsed.netloc.lower(): # GitHub API - URL-encode refs to handle special characters like / in branch names api_url = ( - f"https://api.github.com/repos/{owner}/{repo}/compare/" + f"https://api.github.com/repos/{project_path}/compare/" f"{quote(base_ref, safe='')}...{quote(target_ref, safe='')}" ) async with session.get(api_url, headers=headers) as response: @@ -182,9 +186,10 @@ async def _run( # GitHub: commits are in 'commits' array (oldest first) commits = [commit["sha"] for commit in data.get("commits", [])] else: - # GitLab API - use params dict for automatic URL encoding + # GitLab API - URL-encode the full project path api_url = ( - f"https://{parsed.netloc}/api/v4/projects/{owner}%2F{repo}/repository/compare" + f"https://{parsed.netloc}/api/v4/projects/" + f"{quote(project_path, safe='')}/repository/compare" ) params = {"from": base_ref, "to": target_ref} async with session.get(api_url, params=params, headers=headers) as response: @@ -231,9 +236,14 @@ async def _run( query_match = re.search(r"(?:id|h)=([a-f0-9]{7,40})", parsed.query) if query_match: commit_hash = query_match.group(1) - repo_query_match = re.search(r"[?&]p=([^;&]+)", parsed.query) + repo_query_match = re.search(r"(?:^|[?&])p=([^;&]+)", parsed.query) if repo_query_match: repo_path = repo_query_match.group(1) + else: + # cgit-style: repo path in URL path (e.g. /pub/scm/.../linux.git/commit/) + path_repo_match = re.search(r"^/?(.+?\.git)(?:/|$)", parsed.path) + if path_repo_match: + repo_path = path_repo_match.group(1) if commit_hash: if not repo_path: raise ToolError( @@ -745,10 +755,6 @@ async def _run( if exit_code != 0: raise ToolError(f"Failed to check git status: {stderr}") - # Validate stdout is not None - if stdout is None: - raise ToolError("Git status command returned no output") - # Check if we're actually in a cherry-pick state by looking for .git/CHERRY_PICK_HEAD if not (tool_input.repo_path / ".git" / "CHERRY_PICK_HEAD").exists(): raise ToolError("Not in a cherry-pick state. Cannot continue cherry-pick.")