Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/agent_scan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ def add_common_arguments(parser):
default=False,
help="Exit with a non-zero code when there are analysis findings or runtime failures",
)
parser.add_argument(
"--enable-oauth",
action="store_true",
default=False,
help="Enable interactive OAuth authentication for remote MCP servers",
)


def add_server_arguments(parser):
Expand Down Expand Up @@ -567,12 +573,15 @@ async def run_scan(args, mode: Literal["scan", "inspect"] = "scan") -> list[Scan
with open(args.mcp_oauth_tokens_path) as f:
tokens = TokenAndClientInfoList.model_validate_json(f.read()).root

enable_oauth: bool = hasattr(args, "enable_oauth") and args.enable_oauth

inspect_args = InspectArgs(
timeout=server_timeout,
tokens=tokens,
paths=files,
all_users=scan_all_users,
scan_skills=scan_skills,
enable_oauth=enable_oauth,
)

if mode == "scan":
Expand Down
12 changes: 9 additions & 3 deletions src/agent_scan/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,15 @@ async def inspect_extension(
config: StdioServer | RemoteServer | SkillServer,
timeout: int,
token: TokenAndClientInfo | None = None,
enable_oauth: bool = False,
) -> InspectedExtensions:
"""
Scan an extension (MCP server or skill) and return a InspectedExtensions object.
"""
traffic_capture = TrafficCapture()
if isinstance(config, StdioServer):
try:
signature, _ = await check_server(config, timeout, traffic_capture, token)
signature, _ = await check_server(config, timeout, traffic_capture, token, enable_oauth=enable_oauth)
return InspectedExtensions(name=name, config=config, signature_or_error=signature)
except Exception as e:
return InspectedExtensions(
Expand All @@ -180,7 +181,9 @@ async def inspect_extension(

if isinstance(config, RemoteServer):
try:
signature, fixed_config = await check_server(config.model_copy(deep=True), timeout, traffic_capture, token)
signature, fixed_config = await check_server(
config.model_copy(deep=True), timeout, traffic_capture, token, enable_oauth=enable_oauth
)
assert isinstance(fixed_config, RemoteServer), f"Fixed config is not a RemoteServer: {fixed_config}"
return InspectedExtensions(name=name, config=fixed_config, signature_or_error=signature)
except HTTPStatusError as e:
Expand Down Expand Up @@ -234,6 +237,7 @@ async def inspect_client(
timeout: int,
tokens: list[TokenAndClientInfo],
scan_skills: bool,
enable_oauth: bool = False,
) -> InspectedClient:
"""
Scan a client (Cursor, VSCode, etc.) and return a InspectedClient object.
Expand All @@ -248,7 +252,9 @@ async def inspect_client(
continue
extensions_for_mcp_config: list[InspectedExtensions] = []
for name, server in mcp_configs:
extension = await inspect_extension(name, server, timeout, find_relevant_token(tokens, name))
extension = await inspect_extension(
name, server, timeout, find_relevant_token(tokens, name), enable_oauth=enable_oauth
)
extensions_for_mcp_config.append(extension)
extensions[mcp_config_path] = extensions_for_mcp_config

Expand Down
85 changes: 53 additions & 32 deletions src/agent_scan/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ClaudeConfigFile,
ConfigWithoutMCP,
FileTokenStorage,
InteractiveTokenStorage,
MCPConfig,
RemoteServer,
ServerSignature,
Expand All @@ -30,69 +31,82 @@
VSCodeConfigFile,
VSCodeMCPConfig,
)
from agent_scan.oauth import build_oauth_client_provider, make_callback_handler, make_redirect_handler
from agent_scan.traffic_capture import PipeStderrCapture, TrafficCapture, capturing_client
from agent_scan.utils import resolve_command_and_args

# Set up logger for this module
logger = logging.getLogger(__name__)


# [REVIEW][BEFORE] streamablehttp_client_without_session accepted a `token` param and
# built OAuthClientProvider internally, duplicating provider construction logic.
# [REVIEW][AFTER] Accepts a pre-built `oauth_client_provider` so construction is
# centralised in get_client, making it easier to swap storage strategies.
@asynccontextmanager
async def streamablehttp_client_without_session(
url: str,
headers: dict[str, str],
timeout: int,
token: TokenAndClientInfo | None = None,
oauth_client_provider: OAuthClientProvider | None = None,
):
async def handle_redirect(auth_url: str) -> None:
raise NotImplementedError(f"handle_redirect is not implemented {auth_url}")

async def handle_callback(auth_code: str, state: str | None) -> tuple[str, str | None]:
raise NotImplementedError(f"handle_callback is not implemented {auth_code} {state}")

if token:
oauth_client_provider = OAuthClientProvider(
server_url=token.mcp_server_url,
client_metadata=OAuthClientMetadata(
client_name="mcp-scan",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
redirect_uris=["http://localhost:3030/callback"],
),
storage=FileTokenStorage(data=token),
redirect_handler=handle_redirect,
callback_handler=handle_callback,
)
else:
oauth_client_provider = None
async with httpx.AsyncClient(
auth=oauth_client_provider, follow_redirects=True, headers=headers, timeout=timeout
) as custom_client:
async with streamable_http_client(url=url, http_client=custom_client) as (read, write, _):
yield read, write


# [REVIEW][BEFORE] get_client did not support interactive OAuth — only token-based auth
# [REVIEW][AFTER] Added enable_oauth param; constructs OAuthClientProvider centrally
# using either FileTokenStorage (when token provided) or InteractiveTokenStorage
# (when enable_oauth=True for remote servers)
@asynccontextmanager
async def get_client(
server_config: StdioServer | RemoteServer,
timeout: int | None = None,
traffic_capture: TrafficCapture | None = None,
token: TokenAndClientInfo | None = None,
enable_oauth: bool = False,
) -> AsyncIterator[tuple]:
"""
Create an MCP client for the given server config.

If traffic_capture is provided, all MCP protocol traffic will be captured
for debugging purposes.
"""
# Construct the OAuthClientProvider centrally
oauth_client_provider: OAuthClientProvider | None = None
if token and isinstance(server_config, RemoteServer):
oauth_client_provider = OAuthClientProvider(
server_url=token.mcp_server_url,
client_metadata=OAuthClientMetadata(
client_name="mcp-scan",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
redirect_uris=["http://localhost:3030/callback"],
),
storage=FileTokenStorage(data=token),
redirect_handler=make_redirect_handler(),
callback_handler=make_callback_handler(),
)
elif enable_oauth and isinstance(server_config, RemoteServer):
storage = InteractiveTokenStorage(server_url=server_config.url)
oauth_client_provider = build_oauth_client_provider(
server_url=server_config.url,
storage=storage,
)
Comment on lines +88 to +100
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Action required

4. OAuth storage keyed by path 🐞 Bug ≡ Correctness

InteractiveTokenStorage uses server_config.url as its persistence key, but check_server()
mutates server_config.url while probing /mcp and /sse variants. This causes tokens to be stored
under different directories for the same logical server, leading to repeated authorization prompts
and fragmented token state.
Agent Prompt
### Issue description
Interactive OAuth token persistence is keyed off `server_config.url`, but `check_server()` mutates that URL while trying different protocol/path combinations. This makes the token storage directory vary per attempt, so previously-authorized tokens often won’t be found.

### Issue Context
Users may be forced to re-authorize multiple times for the same server, and multiple per-variant token directories may be created.

### Suggested fixes
- Use a stable key for token storage (e.g., `scheme://host:port` from `urlparse(url)`), not the full mutable path.
- Alternatively, capture the original URL before probing and pass a separate stable `oauth_server_id`/`storage_key` into `get_client()`.

### Fix Focus Areas (code locations)
- src/agent_scan/mcp_client.py[80-98]
- src/agent_scan/mcp_client.py[236-276]

ⓘ Copy this prompt and use it to remediate the issue with your preferred AI generation tools


if isinstance(server_config, RemoteServer) and server_config.type == "sse":
logger.debug("Creating SSE client with URL: %s", server_config.url)
client_cm = sse_client(
url=server_config.url,
headers=server_config.headers,
# env=server_config.env, #Not supported by MCP yet, but present in vscode
timeout=timeout,
)
sse_kwargs: dict = {
"url": server_config.url,
"headers": server_config.headers,
"timeout": timeout,
}
if oauth_client_provider is not None:
sse_kwargs["auth"] = oauth_client_provider
client_cm = sse_client(**sse_kwargs)
elif isinstance(server_config, RemoteServer) and server_config.type == "http":
logger.debug(
"Creating Streamable HTTP client with URL: %s with headers %s", server_config.url, server_config.headers
Expand All @@ -101,7 +115,7 @@ async def get_client(
url=server_config.url,
headers=server_config.headers,
timeout=timeout or 60,
token=token,
oauth_client_provider=oauth_client_provider,
)
elif isinstance(server_config, StdioServer):
logger.debug("Creating stdio client")
Expand Down Expand Up @@ -140,9 +154,12 @@ async def _check_server_pass(
timeout: int,
traffic_capture: TrafficCapture | None = None,
token: TokenAndClientInfo | None = None,
enable_oauth: bool = False,
) -> ServerSignature:
async def _check_server() -> ServerSignature:
async with get_client(server_config, timeout=timeout, traffic_capture=traffic_capture, token=token) as (
async with get_client(
server_config, timeout=timeout, traffic_capture=traffic_capture, token=token, enable_oauth=enable_oauth
) as (
read,
write,
):
Expand Down Expand Up @@ -205,11 +222,14 @@ async def check_server(
timeout: int,
traffic_capture: TrafficCapture | None = None,
token: TokenAndClientInfo | None = None,
enable_oauth: bool = False,
) -> tuple[ServerSignature, StdioServer | RemoteServer]:
logger.debug("Checking server with timeout: %s seconds", timeout)

if not isinstance(server_config, RemoteServer):
result = await asyncio.wait_for(_check_server_pass(server_config, timeout, traffic_capture), timeout)
result = await asyncio.wait_for(
_check_server_pass(server_config, timeout, traffic_capture, enable_oauth=enable_oauth), timeout
)
logger.debug("Server check completed within timeout")
return result, server_config
else:
Expand Down Expand Up @@ -251,7 +271,8 @@ async def check_server(
server_config.url = url
logger.debug(f"Trying {protocol} with url: {url}")
result = await asyncio.wait_for(
_check_server_pass(server_config, timeout, traffic_capture, token), timeout
_check_server_pass(server_config, timeout, traffic_capture, token, enable_oauth=enable_oauth),
timeout,
)
logger.debug("Server check completed within timeout")
return result, server_config
Expand Down
65 changes: 63 additions & 2 deletions src/agent_scan/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import base64
import json
import logging
import os
import re
from itertools import chain
from pathlib import Path
from typing import Any, Literal, TypeAlias

from lark import Lark
Expand Down Expand Up @@ -454,7 +457,9 @@ async def get_tokens(self) -> OAuthToken | None:
return self.data.token

async def set_tokens(self, tokens: OAuthToken) -> None:
raise NotImplementedError("set_tokens is not supported for FileTokenStorage")
# [REVIEW][BEFORE] Raised NotImplementedError, preventing token refresh flows
# [REVIEW][AFTER] Update in-memory token so refresh cycles work with FileTokenStorage
self.data.token = tokens

async def get_client_info(self) -> OAuthClientInformationFull | None:
return OAuthClientInformationFull(
Expand All @@ -464,7 +469,63 @@ async def get_client_info(self) -> OAuthClientInformationFull | None:

async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
"""Store client information."""
raise NotImplementedError("set_client_info is not supported for FileTokenStorage")
# [REVIEW][BEFORE] Raised NotImplementedError, breaking dynamic client registration
# [REVIEW][AFTER] Update in-memory client_id so the provider can track registrations
self.data.client_id = client_info.client_id


class InteractiveTokenStorage(TokenStorage):
    """Persistent file-based token storage for interactive OAuth flows.

    Tokens and client information are stored as JSON files
    (``tokens.json`` and ``client_info.json``) in a per-server directory
    under a configurable base directory.

    The per-server directory name is the base64url encoding of the exact
    ``server_url`` string passed to the constructor, so two URLs that differ
    in any character (for example in their path) map to different
    directories.

    Because the files contain OAuth credentials, the per-server directory is
    created with mode ``0o700`` and the files with mode ``0o600``.
    """

    def __init__(self, server_url: str, base_dir: str = "~/.mcp-scan-oauth") -> None:
        """Remember the server URL and the user-expanded base directory.

        Args:
            server_url: URL used to derive the per-server directory name.
            base_dir: Root directory for all stored credentials; a leading
                ``~`` is expanded to the current user's home directory.
        """
        self._server_url = server_url
        self._base_dir = os.path.expanduser(base_dir)

    async def get_tokens(self) -> OAuthToken | None:
        """Read tokens from {storage_dir}/tokens.json, returning None if absent."""
        return self._load_model("tokens.json", OAuthToken)

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Write tokens to {storage_dir}/tokens.json."""
        self._store_model("tokens.json", tokens)

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Read client info from {storage_dir}/client_info.json, returning None if absent."""
        return self._load_model("client_info.json", OAuthClientInformationFull)

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Write client info to {storage_dir}/client_info.json."""
        self._store_model("client_info.json", client_info)

    def _load_model(self, filename: str, model: Any) -> Any:
        """Load ``filename`` from the storage directory and validate it as ``model``.

        Returns None when the file does not exist. A file that cannot be
        decoded, parsed as JSON, or validated is also reported as None (with
        a warning) so that a damaged cache does not fail every later run.
        """
        path = self._get_storage_dir() / filename
        try:
            # EAFP: open directly instead of exists()+open(), which can race
            # with a concurrent delete.
            with open(path, encoding="utf-8") as f:
                return model.model_validate(json.load(f))
        except FileNotFoundError:
            return None
        except ValueError:
            # json.JSONDecodeError, UnicodeDecodeError and pydantic's
            # ValidationError all derive from ValueError.
            logging.getLogger(__name__).warning("Ignoring unreadable OAuth cache file %s", path)
            return None

    def _store_model(self, filename: str, model: Any) -> None:
        """Serialize ``model`` as JSON into ``filename`` with owner-only permissions."""
        path = self._get_storage_dir() / filename
        # Create the file as 0o600 from the start so the credentials are never
        # briefly readable by other users.
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            # The mode passed to os.open() only applies to newly created
            # files; tighten a pre-existing file before writing secrets.
            os.chmod(path, 0o600)
            json.dump(model.model_dump(mode="json"), f)

    def _get_storage_dir(self) -> Path:
        """Return the per-server storage directory, creating it if necessary."""
        storage_dir = Path(self._base_dir) / self._url_safe_filename(self._server_url)
        # The mode applies to the leaf directory only; missing parents are
        # created with the default permissions.
        storage_dir.mkdir(mode=0o700, parents=True, exist_ok=True)
        return storage_dir

    @staticmethod
    def _url_safe_filename(url: str) -> str:
        """Convert a URL to a filesystem-safe directory name using base64url encoding."""
        return base64.urlsafe_b64encode(url.encode()).decode().rstrip("=")


class SerializedException(BaseModel):
Expand Down
Loading
Loading