diff --git a/python/01-learn/18-input-output-guardrails/01_input_guardrail.ipynb b/python/01-learn/18-input-output-guardrails/01_input_guardrail.ipynb new file mode 100644 index 00000000..7aa18303 --- /dev/null +++ b/python/01-learn/18-input-output-guardrails/01_input_guardrail.ipynb @@ -0,0 +1,544 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Input Guardrails\n", + "\n", + "This notebook demonstrates how to implement **input guardrails** using the Strands Agents SDK's `BeforeInvocationEvent` hook.\n", + "\n", + "Input guardrails inspect user messages **before** they reach the model, allowing you to:\n", + "- Block harmful, off-topic, or non-compliant requests\n", + "- Detect and reject PII (emails, phone numbers, SSNs)\n", + "- Compose multiple content filters for layered protection\n", + "\n", + "**Key concept:** Use a `HookProvider` class with `BeforeInvocationEvent` to intercept messages before model inference. Access messages via `event.agent.messages`. When a violation is detected, replace the messages with a rejection prompt.\n", + "\n", + "## Architecture\n", + "\n", + "The following diagram shows where input guardrails sit in the agent lifecycle:\n", + "\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "Make sure you have the required packages installed. The content filter classes are defined inline in a code cell below, so no external `content_filters.py` module is required." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "!pip install strands-agents strands-agents-tools --upgrade -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Content filter classes \u2014 inline definitions (no external file needed)\n", + "from dataclasses import dataclass\n", + "from enum import Enum\n", + "from typing import Optional\n", + "import re\n", + "\n", + "class Severity(Enum):\n", + " BLOCK = 'block'\n", + " WARN = 'warn'\n", + " REDACT = 'redact'\n", + "\n", + "@dataclass\n", + "class FilterResult:\n", + " passed: bool\n", + " filter_name: str\n", + " severity: Severity\n", + " message: Optional[str] = None\n", + " redacted_text: Optional[str] = None\n", + "\n", + "class ContentFilter:\n", + " def __init__(self, name, severity=Severity.BLOCK):\n", + " self.name = name\n", + " self.severity = severity\n", + " def evaluate(self, text):\n", + " raise NotImplementedError\n", + "\n", + "class RegexContentFilter(ContentFilter):\n", + " def __init__(self, name, patterns, severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " self.patterns = [re.compile(p) for p in patterns]\n", + " def evaluate(self, text):\n", + " for pattern in self.patterns:\n", + " if pattern.search(text):\n", + " if self.severity == Severity.REDACT:\n", + " redacted = text\n", + " for p in self.patterns:\n", + " redacted = p.sub('[REDACTED]', redacted)\n", + " return FilterResult(False, self.name, self.severity,\n", + " f'Pattern matched: {pattern.pattern}', redacted)\n", + " return FilterResult(False, self.name, self.severity,\n", + " f'Pattern matched:
{pattern.pattern}')\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "class KeywordContentFilter(ContentFilter):\n", + " def __init__(self, name, keywords, severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " self.keywords = [kw.lower() for kw in keywords]\n", + " def evaluate(self, text):\n", + " text_lower = text.lower()\n", + " for keyword in self.keywords:\n", + " if keyword in text_lower:\n", + " return FilterResult(False, self.name, self.severity,\n", + " f\"Prohibited keyword: '{keyword}'\")\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "class FormatComplianceFilter(ContentFilter):\n", + " EXECUTION_PATTERNS = [\n", + " re.compile(r'\\b(run|execute|eval)\\s*\\(', re.IGNORECASE),\n", + " re.compile(r'```\\s*(bash|shell|sh)\\b', re.IGNORECASE),\n", + " re.compile(r'sudo\\s+\\w+', re.IGNORECASE),\n", + " ]\n", + " def __init__(self, name='format_compliance', severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " def evaluate(self, text):\n", + " for pattern in self.EXECUTION_PATTERNS:\n", + " if pattern.search(text):\n", + " return FilterResult(False, self.name, self.severity,\n", + " 'Code execution instruction detected')\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "def run_filters(text, filters):\n", + " for f in filters:\n", + " result = f.evaluate(text)\n", + " if not result.passed:\n", + " return result\n", + " return None\n", + "\n", + "print('Content filter classes loaded.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "logger = logging.getLogger(__name__)\n", + "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s [%(levelname)s] %(message)s\", datefmt=\"%H:%M:%S\")\n", + "from strands.hooks import HookProvider, HookRegistry, BeforeInvocationEvent\n", + "\n", + "# Content filters are defined inline in the previous cell (no shared module import needed)\n" + ] + }, + { +
"cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helper: Extract Text from a Message\n", + "\n", + "Messages follow the Bedrock Converse API format:\n", + "```python\n", + "{\"role\": \"user\", \"content\": [{\"text\": \"user message here\"}]}\n", + "```\n", + "\n", + "This helper extracts and concatenates all text content blocks from a message." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def _extract_text_from_message(message: dict) -> str:\n", + " \"\"\"Extract text content from a Strands SDK message.\n", + "\n", + " Args:\n", + " message: A message dictionary with 'role' and 'content' keys.\n", + "\n", + " Returns:\n", + " Concatenated text from all text content blocks in the message.\n", + " \"\"\"\n", + " text_parts = []\n", + " for block in message.get(\"content\", []):\n", + " if \"text\" in block:\n", + " text_parts.append(block[\"text\"])\n", + " return \" \".join(text_parts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Input Guardrail: Keyword-Based Content Blocking\n", + "\n", + "The simplest guardrail pattern: define a list of prohibited keywords and block any request that contains them. This uses the `KeywordContentFilter` class defined inline in the content filter cell above.\n", + "\n", + "We define the guardrail logic as a standalone function first (for easy testing with mocks), then wrap it in a `HookProvider` class for agent registration."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define prohibited keywords for topic blocking\n", + "PROHIBITED_KEYWORDS = [\"hack\", \"exploit\", \"bypass security\", \"illegal\", \"steal credentials\"]\n", + "\n", + "# Create a keyword filter instance\n", + "keyword_filter = KeywordContentFilter(\n", + " name=\"prohibited_topics\",\n", + " keywords=PROHIBITED_KEYWORDS,\n", + " severity=Severity.BLOCK,\n", + ")\n", + "\n", + "\n", + "def input_guardrail_logic(messages: list) -> None:\n", + " \"\"\"Simple input guardrail that blocks prohibited topics.\n", + "\n", + " Inspects the last user message and blocks requests containing\n", + " prohibited keywords by replacing messages with a rejection prompt.\n", + " \"\"\"\n", + " if not messages:\n", + " return\n", + "\n", + " last_message = messages[-1]\n", + " if last_message.get(\"role\") != \"user\":\n", + " return\n", + "\n", + " text = _extract_text_from_message(last_message)\n", + " if not text:\n", + " return\n", + "\n", + " result = keyword_filter.evaluate(text)\n", + "\n", + " if not result.passed:\n", + " logger.warning(\n", + " f\"[INPUT GUARDRAIL] Blocked request. \"\n", + " f\"Filter: {result.filter_name}, Reason: {result.message}\"\n", + " )\n", + " messages.clear()\n", + " messages.append(\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [\n", + " {\n", + " \"text\": (\n", + " \"Respond only with: I cannot process that request. \"\n", + " \"The input was blocked by a content safety filter.\"\n", + " )\n", + " }\n", + " ],\n", + " }\n", + " )\n", + " else:\n", + " logger.debug(\"[INPUT GUARDRAIL] Request passed keyword filter.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Input Guardrail: PII Detection\n", + "\n", + "This guardrail detects personally identifiable information (emails, phone numbers, SSNs) in user messages and blocks the request to prevent PII from being sent to the model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define PII patterns\n", + "PII_PATTERNS = [\n", + " r\"\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b\", # Email\n", + " r\"\\b\\d{3}[-.]?\\d{3}[-.]?\\d{4}\\b\", # Phone number\n", + " r\"\\b\\d{3}-\\d{2}-\\d{4}\\b\", # SSN\n", + "]\n", + "\n", + "# Create a PII filter instance that blocks requests containing PII\n", + "pii_filter = RegexContentFilter(\n", + " name=\"pii_detector\",\n", + " patterns=PII_PATTERNS,\n", + " severity=Severity.BLOCK,\n", + ")\n", + "\n", + "\n", + "def pii_input_guardrail_logic(messages: list) -> None:\n", + " \"\"\"Input guardrail that blocks requests containing PII.\"\"\"\n", + " if not messages:\n", + " return\n", + "\n", + " last_message = messages[-1]\n", + " if last_message.get(\"role\") != \"user\":\n", + " return\n", + "\n", + " text = _extract_text_from_message(last_message)\n", + " if not text:\n", + " return\n", + "\n", + " result = pii_filter.evaluate(text)\n", + "\n", + " if not result.passed:\n", + " logger.warning(\n", + " f\"[PII GUARDRAIL] Blocked request containing PII. \"\n", + " f\"Filter: {result.filter_name}, Reason: {result.message}\"\n", + " )\n", + " messages.clear()\n", + " messages.append(\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [\n", + " {\n", + " \"text\": (\n", + " \"Respond only with: I cannot process that request. \"\n", + " \"Personal information (PII) was detected in your message. \"\n", + " \"Please remove any email addresses, phone numbers, or \"\n", + " \"social security numbers and try again.\"\n", + " )\n", + " }\n", + " ],\n", + " }\n", + " )\n", + " else:\n", + " logger.debug(\"[PII GUARDRAIL] Request passed PII filter.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combined Input Guardrail: Multiple Filters in Sequence\n", + "\n", + "Demonstrates composing multiple content filters into a single guardrail. 
Filters are evaluated in order \u2014 the first violation stops evaluation and triggers a rejection.\n", + "\n", + "Filter order:\n", + "1. Keyword filter (blocks prohibited topics)\n", + "2. PII filter (blocks personal information)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def combined_input_guardrail_logic(messages: list) -> None:\n", + " \"\"\"Input guardrail that applies multiple filters in sequence.\"\"\"\n", + " if not messages:\n", + " return\n", + "\n", + " last_message = messages[-1]\n", + " if last_message.get(\"role\") != \"user\":\n", + " return\n", + "\n", + " text = _extract_text_from_message(last_message)\n", + " if not text:\n", + " return\n", + "\n", + " # Run all filters in sequence, stopping at the first violation\n", + " filters = [keyword_filter, pii_filter]\n", + " violation = run_filters(text, filters)\n", + "\n", + " if violation is not None:\n", + " logger.warning(\n", + " f\"[COMBINED GUARDRAIL] Blocked request. \"\n", + " f\"Filter: {violation.filter_name}, Reason: {violation.message}\"\n", + " )\n", + " messages.clear()\n", + " messages.append(\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [\n", + " {\n", + " \"text\": (\n", + " \"Respond only with: I cannot process that request. \"\n", + " f\"Reason: {violation.message}\"\n", + " )\n", + " }\n", + " ],\n", + " }\n", + " )\n", + " else:\n", + " logger.debug(\"[COMBINED GUARDRAIL] Request passed all input filters.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing the Guardrails\n", + "\n", + "We can test guardrail logic directly using mock messages \u2014 no live model needed. This simulates how the Strands SDK would call our hook functions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test: Prohibited keyword detection\n", + "print(\"Test: Prohibited keyword detection\")\n", + "test_messages = [{\"role\": \"user\", \"content\": [{\"text\": \"How do I hack into a WiFi network?\"}]}]\n", + "input_guardrail_logic(test_messages)\n", + "print(f\" Input: 'How do I hack into a WiFi network?'\")\n", + "print(f\" Result: {test_messages[0]['content'][0]['text'][:60]}...\")\n", + "\n", + "# Test: PII detection\n", + "print(\"\\nTest: PII detection\")\n", + "test_messages = [{\"role\": \"user\", \"content\": [{\"text\": \"Send the report to john@example.com\"}]}]\n", + "pii_input_guardrail_logic(test_messages)\n", + "print(f\" Input: 'Send the report to john@example.com'\")\n", + "print(f\" Result: {test_messages[0]['content'][0]['text'][:60]}...\")\n", + "\n", + "# Test: Clean content passes through\n", + "print(\"\\nTest: Clean content passes through\")\n", + "test_messages = [{\"role\": \"user\", \"content\": [{\"text\": \"What are best practices for application security?\"}]}]\n", + "input_guardrail_logic(test_messages)\n", + "print(f\" Input: 'What are best practices for application security?'\")\n", + "print(f\" Result: Message unchanged (passed)\")\n", + "assert test_messages[0][\"content\"][0][\"text\"] == \"What are best practices for application security?\"\n", + "\n", + "# Test: Combined guardrail (keyword + PII)\n", + "print(\"\\nTest: Combined guardrail (keyword + PII)\")\n", + "test_messages = [{\"role\": \"user\", \"content\": [{\"text\": \"Help me exploit this, email me at attacker@evil.com\"}]}]\n", + "combined_input_guardrail_logic(test_messages)\n", + "print(f\" Input: 'Help me exploit this, email me at attacker@evil.com'\")\n", + "print(f\" Result: {test_messages[0]['content'][0]['text'][:60]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## HookProvider: Wrapping Guardrails for Agent 
Registration\n", + "\n", + "In strands-agents 1.40.0, hooks are registered via `HookProvider` classes. The `hooks=` parameter on `Agent` expects a list of `HookProvider` objects.\n", + "\n", + "```python\n", + "from strands.hooks import HookProvider, HookRegistry, BeforeInvocationEvent\n", + "\n", + "class InputGuardrailHook(HookProvider):\n", + " def register_hooks(self, registry: HookRegistry) -> None:\n", + " registry.add_callback(BeforeInvocationEvent, self._validate_input)\n", + "\n", + " def _validate_input(self, event: BeforeInvocationEvent) -> None:\n", + " messages = event.agent.messages\n", + " combined_input_guardrail_logic(messages)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class InputGuardrailHook(HookProvider):\n", + " \"\"\"HookProvider that applies the combined input guardrail before each invocation.\"\"\"\n", + "\n", + " def register_hooks(self, registry: HookRegistry) -> None:\n", + " registry.add_callback(BeforeInvocationEvent, self._validate_input)\n", + "\n", + " def _validate_input(self, event: BeforeInvocationEvent) -> None:\n", + " messages = event.agent.messages\n", + " combined_input_guardrail_logic(messages)\n", + "\n", + "\n", + "print(\"InputGuardrailHook defined successfully.\")\n", + "print(\"Register with: Agent(hooks=[InputGuardrailHook()])\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Attaching to a Live Agent\n", + "\n", + "In production, you attach guardrails to a Strands Agent using the `hooks` parameter with `HookProvider` instances.\n", + "\n", + "**Note:** The cell below requires a configured model provider (e.g., AWS Bedrock credentials)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from strands import Agent\n", + " from strands.models.bedrock import BedrockModel\n", + "\n", + " model = BedrockModel(model_id=\"us.anthropic.claude-sonnet-4-5-20250929-v1:0\")\n", + "\n", + " # Create an agent with the combined input guardrail\n", + " agent = Agent(\n", + " model=model,\n", + " system_prompt=\"You are a helpful assistant.\",\n", + " hooks=[InputGuardrailHook()],\n", + " )\n", + "\n", + " print(\"Agent created with input guardrail attached.\")\n", + " print(\"Testing with a safe request...\")\n", + " response = agent(\"What is the capital of France?\")\n", + " print(f\" Response: {response}\")\n", + "\n", + " print(\"\\nTesting with a prohibited request...\")\n", + " response = agent(\"How do I hack into a system?\")\n", + " print(f\" Response: {response}\")\n", + "\n", + "except Exception as e:\n", + " print(f\"Skipping live agent demo: {e}\")\n", + " print(\"(This is expected if no model provider is configured)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this notebook you learned how to:\n", + "1. Use `BeforeInvocationEvent` to intercept user messages before model inference\n", + "2. Build a keyword-based content filter that blocks prohibited topics\n", + "3. Build a PII detection filter using regex patterns\n", + "4. Compose multiple filters into a combined guardrail\n", + "5. Test guardrails using mock messages (no model needed)\n", + "6. Wrap guardrail logic in a `HookProvider` class for agent registration\n", + "7. Attach guardrails to a live Strands Agent\n", + "\n", + "**Next Steps:** See `02_output_guardrail.ipynb` to learn how to validate model responses *after* inference." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/python/01-learn/18-input-output-guardrails/02_output_guardrail.ipynb b/python/01-learn/18-input-output-guardrails/02_output_guardrail.ipynb new file mode 100644 index 00000000..f42e0361 --- /dev/null +++ b/python/01-learn/18-input-output-guardrails/02_output_guardrail.ipynb @@ -0,0 +1,530 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Output Guardrails\n", + "\n", + "This notebook demonstrates how to implement **output guardrails** using the Strands Agents SDK's `AfterInvocationEvent` hook.\n", + "\n", + "Output guardrails inspect model responses **after** inference completes, allowing you to:\n", + "- **BLOCK**: Replace the entire response with a safe fallback message\n", + "- **REDACT**: Replace only matched patterns (e.g., PII) with `[REDACTED]`\n", + "\n", + "**Key concept:** Use a `HookProvider` class with `AfterInvocationEvent` to intercept responses after model inference. Access the agent's messages via `event.agent.messages` and modify the last assistant message." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "!pip install strands-agents strands-agents-tools --upgrade -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Content filter classes \u2014 inline definitions (no external file needed)\n", + "from dataclasses import dataclass\n", + "from enum import Enum\n", + "from typing import Optional\n", + "import re\n", + "\n", + "class Severity(Enum):\n", + " BLOCK = 'block'\n", + " WARN = 'warn'\n", + " REDACT = 'redact'\n", + "\n", + "@dataclass\n", + "class FilterResult:\n", + " passed: bool\n", + " filter_name: str\n", + " severity: Severity\n", + " message: Optional[str] = None\n", + " redacted_text: Optional[str] = None\n", + "\n", + "class ContentFilter:\n", + " def __init__(self, name, severity=Severity.BLOCK):\n", + " self.name = name\n", + " self.severity = severity\n", + " def evaluate(self, text):\n", + " raise NotImplementedError\n", + "\n", + "class RegexContentFilter(ContentFilter):\n", + " def __init__(self, name, patterns, severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " self.patterns = [re.compile(p) for p in patterns]\n", + " def evaluate(self, text):\n", + " for pattern in self.patterns:\n", + " if pattern.search(text):\n", + " if self.severity == Severity.REDACT:\n", + " redacted = text\n", + " for p in self.patterns:\n", + " redacted = p.sub('[REDACTED]', redacted)\n", + " return FilterResult(False, self.name, self.severity,\n", + " f'Pattern matched: {pattern.pattern}', redacted)\n", + " return FilterResult(False, self.name, self.severity,\n", + " f'Pattern matched: {pattern.pattern}')\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "class KeywordContentFilter(ContentFilter):\n", + " def __init__(self, 
name, keywords, severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " self.keywords = [kw.lower() for kw in keywords]\n", + " def evaluate(self, text):\n", + " text_lower = text.lower()\n", + " for keyword in self.keywords:\n", + " if keyword in text_lower:\n", + " return FilterResult(False, self.name, self.severity,\n", + " f\"Prohibited keyword: '{keyword}'\")\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "class FormatComplianceFilter(ContentFilter):\n", + " EXECUTION_PATTERNS = [\n", + " re.compile(r'\\b(run|execute|eval)\\s*\\(', re.IGNORECASE),\n", + " re.compile(r'```\\s*(bash|shell|sh)\\b', re.IGNORECASE),\n", + " re.compile(r'sudo\\s+\\w+', re.IGNORECASE),\n", + " ]\n", + " def __init__(self, name='format_compliance', severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " def evaluate(self, text):\n", + " for pattern in self.EXECUTION_PATTERNS:\n", + " if pattern.search(text):\n", + " return FilterResult(False, self.name, self.severity,\n", + " 'Code execution instruction detected')\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "def run_filters(text, filters):\n", + " for f in filters:\n", + " result = f.evaluate(text)\n", + " if not result.passed:\n", + " return result\n", + " return None\n", + "\n", + "print('Content filter classes loaded.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "logger = logging.getLogger(__name__)\n", + "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s [%(levelname)s] %(message)s\", datefmt=\"%H:%M:%S\")\n", + "from strands.hooks import HookProvider, HookRegistry, AfterInvocationEvent\n", + "\n", + "# Content filters are defined inline in the previous cell (no shared module import needed)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helper: Modify the Last Assistant Message\n", + "\n", + "Output guardrails work by modifying the
conversation history. These helpers replace or redact the last assistant message content." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def _replace_assistant_response(messages: list, new_text: str) -> None:\n", + " \"\"\"Replace the last assistant message content with new text.\"\"\"\n", + " if not messages:\n", + " return\n", + " for message in reversed(messages):\n", + " if message.get(\"role\") == \"assistant\":\n", + " message[\"content\"] = [{\"text\": new_text}]\n", + " return\n", + "\n", + "\n", + "def _redact_assistant_response(messages: list, redacted_text: str) -> None:\n", + " \"\"\"Redact specific patterns in the last assistant message.\n", + " \n", + " Unlike full replacement, redaction preserves the overall response structure\n", + " but replaces sensitive patterns with [REDACTED].\n", + " \"\"\"\n", + " if not messages:\n", + " return\n", + " for message in reversed(messages):\n", + " if message.get(\"role\") == \"assistant\":\n", + " new_content = []\n", + " for block in message.get(\"content\", []):\n", + " if \"text\" in block:\n", + " new_content.append({\"text\": redacted_text})\n", + " else:\n", + " new_content.append(block)\n", + " message[\"content\"] = new_content\n", + " return" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Output Guardrail: BLOCK Behavior\n", + "\n", + "The BLOCK behavior replaces the **entire** response with a safe fallback message when prohibited content is detected in the model's output." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define prohibited keywords that should never appear in model output\n", + "OUTPUT_PROHIBITED_KEYWORDS = [\n", + " \"confidential internal\",\n", + " \"classified information\",\n", + " \"trade secret\",\n", + " \"proprietary algorithm\",\n", + "]\n", + "\n", + "output_keyword_filter = KeywordContentFilter(\n", + " name=\"output_topic_blocker\",\n", + " keywords=OUTPUT_PROHIBITED_KEYWORDS,\n", + " severity=Severity.BLOCK,\n", + ")\n", + "\n", + "BLOCKED_RESPONSE_FALLBACK = (\n", + " \"I'm sorry, but I cannot provide that information. \"\n", + " \"The response was blocked by a content safety filter.\"\n", + ")\n", + "\n", + "\n", + "def output_guardrail_logic(messages: list) -> None:\n", + " \"\"\"Output guardrail that blocks responses containing prohibited content.\"\"\"\n", + " if not messages:\n", + " return\n", + "\n", + " # Find the last assistant message\n", + " last_assistant_text = None\n", + " for message in reversed(messages):\n", + " if message.get(\"role\") == \"assistant\":\n", + " for block in message.get(\"content\", []):\n", + " if \"text\" in block:\n", + " last_assistant_text = block[\"text\"]\n", + " break\n", + " break\n", + "\n", + " if not last_assistant_text:\n", + " return\n", + "\n", + " result = output_keyword_filter.evaluate(last_assistant_text)\n", + "\n", + " if not result.passed:\n", + " logger.warning(\n", + " f\"[OUTPUT GUARDRAIL] Blocked response. 
\"\n", + " f\"Filter: {result.filter_name}, Reason: {result.message}\"\n", + " )\n", + " _replace_assistant_response(messages, BLOCKED_RESPONSE_FALLBACK)\n", + " else:\n", + " logger.debug(\"[OUTPUT GUARDRAIL] Response passed keyword filter.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Output Guardrail: REDACT Behavior\n", + "\n", + "The REDACT behavior replaces only the **matched patterns** (like PII) with `[REDACTED]`, preserving the rest of the response. This is less disruptive than blocking the entire response." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define PII patterns for output redaction\n", + "OUTPUT_PII_PATTERNS = [\n", + " r\"\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b\", # Email\n", + " r\"\\b\\d{3}[-.]?\\d{3}[-.]?\\d{4}\\b\", # Phone number\n", + " r\"\\b\\d{3}-\\d{2}-\\d{4}\\b\", # SSN\n", + "]\n", + "\n", + "pii_redaction_filter = RegexContentFilter(\n", + " name=\"output_pii_redactor\",\n", + " patterns=OUTPUT_PII_PATTERNS,\n", + " severity=Severity.REDACT,\n", + ")\n", + "\n", + "\n", + "def pii_output_guardrail_logic(messages: list) -> None:\n", + " \"\"\"Output guardrail that redacts PII from model responses.\"\"\"\n", + " if not messages:\n", + " return\n", + "\n", + " # Find the last assistant message text\n", + " last_assistant_text = None\n", + " for message in reversed(messages):\n", + " if message.get(\"role\") == \"assistant\":\n", + " for block in message.get(\"content\", []):\n", + " if \"text\" in block:\n", + " last_assistant_text = block[\"text\"]\n", + " break\n", + " break\n", + "\n", + " if not last_assistant_text:\n", + " return\n", + "\n", + " result = pii_redaction_filter.evaluate(last_assistant_text)\n", + "\n", + " if not result.passed:\n", + " logger.warning(\n", + " f\"[PII OUTPUT GUARDRAIL] Redacted PII from response. 
\"\n", + " f\"Filter: {result.filter_name}, Reason: {result.message}\"\n", + " )\n", + " _redact_assistant_response(messages, result.redacted_text)\n", + " else:\n", + " logger.debug(\"[PII OUTPUT GUARDRAIL] Response passed PII filter.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combined Output Guardrail: Mixed Behaviors\n", + "\n", + "Compose multiple output filters with different behaviors:\n", + "1. **Keyword filter (BLOCK)** \u2014 replaces entire response if triggered\n", + "2. **Format compliance filter (BLOCK)** \u2014 blocks code execution instructions\n", + "3. **PII filter (REDACT)** \u2014 redacts matched patterns only\n", + "\n", + "BLOCK-severity violations are checked first. If none are found, REDACT filters clean up the response." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "format_filter = FormatComplianceFilter(\n", + " name=\"output_format_compliance\",\n", + " severity=Severity.BLOCK,\n", + ")\n", + "\n", + "\n", + "def combined_output_guardrail_logic(messages: list) -> None:\n", + " \"\"\"Output guardrail that applies multiple filters with different behaviors.\"\"\"\n", + " if not messages:\n", + " return\n", + "\n", + " # Find the last assistant message text\n", + " response_text = None\n", + " for message in reversed(messages):\n", + " if message.get(\"role\") == \"assistant\":\n", + " for block in message.get(\"content\", []):\n", + " if \"text\" in block:\n", + " response_text = block[\"text\"]\n", + " break\n", + " break\n", + "\n", + " if not response_text:\n", + " return\n", + "\n", + " # Phase 1: Check BLOCK-severity filters first\n", + " block_filters = [output_keyword_filter, format_filter]\n", + " violation = run_filters(response_text, block_filters)\n", + "\n", + " if violation is not None:\n", + " logger.warning(\n", + " f\"[COMBINED OUTPUT GUARDRAIL] Blocked response. 
\"\n", + " f\"Filter: {violation.filter_name}, Reason: {violation.message}\"\n", + " )\n", + " _replace_assistant_response(messages, BLOCKED_RESPONSE_FALLBACK)\n", + " return\n", + "\n", + " # Phase 2: Apply REDACT-severity filters (PII redaction)\n", + " redact_result = pii_redaction_filter.evaluate(response_text)\n", + "\n", + " if not redact_result.passed:\n", + " logger.info(\n", + " f\"[COMBINED OUTPUT GUARDRAIL] Redacted content from response. \"\n", + " f\"Filter: {redact_result.filter_name}, Reason: {redact_result.message}\"\n", + " )\n", + " _redact_assistant_response(messages, redact_result.redacted_text)\n", + " else:\n", + " logger.debug(\"[COMBINED OUTPUT GUARDRAIL] Response passed all output filters.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing the Output Guardrails\n", + "\n", + "We test output guardrails using mock messages that simulate the conversation state after model inference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test BLOCK behavior\n", + "print(\"Test: BLOCK \u2014 Prohibited keyword in output\")\n", + "mock_messages = [{\"role\": \"assistant\", \"content\": [{\"text\": \"Here is the confidential internal document you requested...\"}]}]\n", + "output_guardrail_logic(mock_messages)\n", + "print(f\" Result: '{mock_messages[0]['content'][0]['text'][:60]}...'\")\n", + "assert \"cannot provide\" in mock_messages[0][\"content\"][0][\"text\"]\n", + "\n", + "# Test REDACT behavior\n", + "print(\"\\nTest: REDACT \u2014 PII redaction in output\")\n", + "mock_messages = [{\"role\": \"assistant\", \"content\": [{\"text\": \"The user's email is john.doe@example.com and phone is 555-123-4567.\"}]}]\n", + "pii_output_guardrail_logic(mock_messages)\n", + "redacted = mock_messages[0][\"content\"][0][\"text\"]\n", + "print(f\" Result: '{redacted}'\")\n", + "assert \"[REDACTED]\" in redacted\n", + "assert \"john.doe@example.com\" not in 
redacted\n", + "\n", + "# Test clean content passes\n", + "print(\"\\nTest: Clean content passes through\")\n", + "mock_messages = [{\"role\": \"assistant\", \"content\": [{\"text\": \"The capital of France is Paris.\"}]}]\n", + "output_guardrail_logic(mock_messages)\n", + "print(f\" Result: Message unchanged (passed)\")\n", + "assert mock_messages[0][\"content\"][0][\"text\"] == \"The capital of France is Paris.\"\n", + "\n", + "# Test combined: BLOCK takes priority\n", + "print(\"\\nTest: Combined \u2014 BLOCK takes priority over REDACT\")\n", + "mock_messages = [{\"role\": \"assistant\", \"content\": [{\"text\": \"Run this command: sudo rm -rf /tmp/cache to fix the issue.\"}]}]\n", + "combined_output_guardrail_logic(mock_messages)\n", + "print(f\" Result: '{mock_messages[0]['content'][0]['text'][:60]}...'\")\n", + "assert \"cannot provide\" in mock_messages[0][\"content\"][0][\"text\"]\n", + "\n", + "# Test combined: REDACT when no BLOCK violation\n", + "print(\"\\nTest: Combined \u2014 REDACT when no BLOCK violation\")\n", + "mock_messages = [{\"role\": \"assistant\", \"content\": [{\"text\": \"Please contact support at help@company.com for assistance.\"}]}]\n", + "combined_output_guardrail_logic(mock_messages)\n", + "redacted = mock_messages[0][\"content\"][0][\"text\"]\n", + "print(f\" Result: '{redacted}'\")\n", + "assert \"[REDACTED]\" in redacted\n", + "assert \"help@company.com\" not in redacted" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## HookProvider: Wrapping Output Guardrails for Agent Registration\n", + "\n", + "In strands-agents 1.40.0, hooks are registered via `HookProvider` classes. The `AfterInvocationEvent` gives access to `event.agent.messages` which contains the full conversation including the assistant's response." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class OutputGuardrailHook(HookProvider):\n", + " \"\"\"HookProvider that applies the combined output guardrail after each invocation.\"\"\"\n", + "\n", + " def register_hooks(self, registry: HookRegistry) -> None:\n", + " registry.add_callback(AfterInvocationEvent, self._validate_output)\n", + "\n", + " def _validate_output(self, event: AfterInvocationEvent) -> None:\n", + " messages = event.agent.messages\n", + " combined_output_guardrail_logic(messages)\n", + "\n", + "\n", + "print(\"OutputGuardrailHook defined successfully.\")\n", + "print(\"Register with: Agent(hooks=[OutputGuardrailHook()])\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Attaching to a Live Agent\n", + "\n", + "Register output guardrails using the `hooks` parameter with `HookProvider` instances:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from strands import Agent\n", + " from strands.models.bedrock import BedrockModel\n", + "\n", + " model = BedrockModel(model_id=\"us.anthropic.claude-sonnet-4-5-20250929-v1:0\")\n", + "\n", + " agent = Agent(\n", + " model=model,\n", + " system_prompt=\"You are a helpful assistant.\",\n", + " hooks=[OutputGuardrailHook()],\n", + " )\n", + "\n", + " print(\"Agent created with output guardrail attached.\")\n", + " print(\"Testing with a safe request...\")\n", + " response = agent(\"What is the capital of France?\")\n", + " print(f\" Response: {response}\")\n", + "\n", + " print(\"\\nTesting with a request that might produce PII...\")\n", + " response = agent(\"Generate a fake contact card with name, email, and phone number.\")\n", + " print(f\" Response: {response}\")\n", + "\n", + "except Exception as e:\n", + " print(f\"Skipping live agent demo: {e}\")\n", + " print(\"(This is expected if no model provider is configured)\")" + ] 
+ }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this notebook you learned:\n", + "1. **BLOCK behavior** \u2014 replace the entire response when prohibited content is detected\n", + "2. **REDACT behavior** \u2014 replace only matched patterns (PII) while preserving the rest\n", + "3. How to compose multiple output filters with mixed severity levels\n", + "4. How to test output guardrails with mock messages\n", + "5. How to wrap guardrail logic in a `HookProvider` for agent registration\n", + "6. How to attach output guardrails to a live agent\n", + "\n", + "**Next Steps:** See `03_content_filters.ipynb` to learn how to build custom content filter classes." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/python/01-learn/18-input-output-guardrails/03_content_filters.ipynb b/python/01-learn/18-input-output-guardrails/03_content_filters.ipynb new file mode 100644 index 00000000..2bd75278 --- /dev/null +++ b/python/01-learn/18-input-output-guardrails/03_content_filters.ipynb @@ -0,0 +1,426 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Custom Content Filters\n", + "\n", + "This notebook demonstrates how to build **reusable content filters** that can be composed into guardrail pipelines.\n", + "\n", + "Key concepts:\n", + "- **Severity levels** control what happens when a filter matches: BLOCK, WARN, or REDACT\n", + "- **Filters are composable** \u2014 run multiple filters in sequence and stop at the first violation\n", + "- The **base class pattern** makes it easy to add new filter types\n", + "\n", + "Notebooks `01_input_guardrail.ipynb` and `04_guardrail_plugin.ipynb` redefine these classes inline, so each of them runs on its own without importing this notebook.\n", + "\n", + "## Filter Pipeline Architecture\n", + "\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "!pip install strands-agents strands-agents-tools --upgrade -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dataclasses import dataclass\n", + "from enum import Enum\n", + "from typing import Optional\n", + "import re" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Models: Severity and FilterResult\n", + "\n", + "Every filter evaluation returns a `FilterResult` that tells the guardrail what action to take:\n", + "- `Severity.BLOCK` \u2014 reject the entire request/response\n", + "- `Severity.WARN` \u2014 log a warning but allow through\n", + "- `Severity.REDACT` \u2014 remove/replace the matched content" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Severity(Enum):\n", + " \"\"\"Action to take when a filter matches.\"\"\"\n", + " BLOCK = \"block\"\n", + " WARN = \"warn\"\n", + " REDACT = \"redact\"\n", + "\n", + "\n", + "@dataclass\n", + "class FilterResult:\n", + " \"\"\"Result of a content filter evaluation.\"\"\"\n", + " passed: bool\n", + " filter_name: str\n", + " severity: Severity\n", + " message: Optional[str] = None\n", + " redacted_text: Optional[str] = None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Base Class: ContentFilter\n", + "\n", + "All content filters inherit from this base class. Subclasses must override the `evaluate()` method to implement custom filtering logic." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class ContentFilter:\n", + " \"\"\"Base class for content filters.\n", + "\n", + " Subclasses must override the `evaluate()` method to implement\n", + " custom filtering logic.\n", + " \"\"\"\n", + "\n", + " def __init__(self, name: str, severity: Severity = Severity.BLOCK):\n", + " self.name = name\n", + " self.severity = severity\n", + "\n", + " def evaluate(self, text: str) -> FilterResult:\n", + " \"\"\"Evaluate text against this filter. Override in subclasses.\"\"\"\n", + " raise NotImplementedError(\"Subclasses must implement evaluate()\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## RegexContentFilter\n", + "\n", + "Uses regex pattern matching to detect structured sensitive information like emails, phone numbers, and SSNs. When severity is `REDACT`, matched patterns are replaced with `[REDACTED]`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class RegexContentFilter(ContentFilter):\n", + " \"\"\"Content filter using regex pattern matching.\"\"\"\n", + "\n", + " def __init__(self, name: str, patterns: list[str], severity: Severity = Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " self.patterns = [re.compile(p) for p in patterns]\n", + "\n", + " def evaluate(self, text: str) -> FilterResult:\n", + " for pattern in self.patterns:\n", + " if pattern.search(text):\n", + " if self.severity == Severity.REDACT:\n", + " redacted = text\n", + " for p in self.patterns:\n", + " redacted = p.sub(\"[REDACTED]\", redacted)\n", + " return FilterResult(\n", + " passed=False,\n", + " filter_name=self.name,\n", + " severity=self.severity,\n", + " message=f\"Pattern matched: {pattern.pattern}\",\n", + " redacted_text=redacted,\n", + " )\n", + " return FilterResult(\n", + " passed=False,\n", + " filter_name=self.name,\n", + " 
severity=self.severity,\n", + " message=f\"Pattern matched: {pattern.pattern}\",\n", + " )\n", + " return FilterResult(passed=True, filter_name=self.name, severity=self.severity)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## KeywordContentFilter\n", + "\n", + "Performs case-insensitive matching against a list of prohibited keywords or phrases. Useful for blocking off-topic requests or detecting harmful content categories." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class KeywordContentFilter(ContentFilter):\n", + " \"\"\"Content filter using keyword-based topic detection.\"\"\"\n", + "\n", + " def __init__(self, name: str, keywords: list[str], severity: Severity = Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " self.keywords = [kw.lower() for kw in keywords]\n", + "\n", + " def evaluate(self, text: str) -> FilterResult:\n", + " text_lower = text.lower()\n", + " for keyword in self.keywords:\n", + " if keyword in text_lower:\n", + " return FilterResult(\n", + " passed=False,\n", + " filter_name=self.name,\n", + " severity=self.severity,\n", + " message=f\"Prohibited keyword detected: '{keyword}'\",\n", + " )\n", + " return FilterResult(passed=True, filter_name=self.name, severity=self.severity)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FormatComplianceFilter\n", + "\n", + "Validates output format compliance \u2014 ensures responses don't include code execution instructions or other format violations. This is an example of a domain-specific filter." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class FormatComplianceFilter(ContentFilter):\n", + " \"\"\"Content filter that validates output format compliance.\"\"\"\n", + "\n", + " EXECUTION_PATTERNS = [\n", + " re.compile(r\"\\b(run|execute|eval)\\s*\\(\", re.IGNORECASE),\n", + " re.compile(r\"```\\s*(bash|shell|sh)\\b\", re.IGNORECASE),\n", + " re.compile(r\"\\$\\s*\\w+\"), # Shell variable references\n", + " re.compile(r\"sudo\\s+\\w+\", re.IGNORECASE),\n", + " ]\n", + "\n", + " def __init__(self, name: str = \"format_compliance\", severity: Severity = Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + "\n", + " def evaluate(self, text: str) -> FilterResult:\n", + " for pattern in self.EXECUTION_PATTERNS:\n", + " if pattern.search(text):\n", + " return FilterResult(\n", + " passed=False,\n", + " filter_name=self.name,\n", + " severity=self.severity,\n", + " message=f\"Output contains code execution instruction: {pattern.pattern}\",\n", + " )\n", + " return FilterResult(passed=True, filter_name=self.name, severity=self.severity)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pipeline Helper: `run_filters()`\n", + "\n", + "Evaluates text against a list of filters in order. The first filter that fails has its result returned immediately. If all filters pass, returns `None`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_filters(text: str, filters: list[ContentFilter]) -> Optional[FilterResult]:\n", + " \"\"\"Evaluate text against a list of filters, returning the first violation.\n", + "\n", + " Args:\n", + " text: The text to evaluate.\n", + " filters: Ordered list of content filters to apply.\n", + "\n", + " Returns:\n", + " The FilterResult of the first failing filter, or None if all pass.\n", + " \"\"\"\n", + " for content_filter in filters:\n", + " result = content_filter.evaluate(text)\n", + " if not result.passed:\n", + " return result\n", + " return None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demo: RegexContentFilter (PII Detection with REDACT)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pii_filter = RegexContentFilter(\n", + " name=\"pii_detector\",\n", + " patterns=[\n", + " r\"\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b\", # Email\n", + " r\"\\b\\d{3}[-.]?\\d{3}[-.]?\\d{4}\\b\", # Phone number\n", + " r\"\\b\\d{3}-\\d{2}-\\d{4}\\b\", # SSN\n", + " ],\n", + " severity=Severity.REDACT,\n", + ")\n", + "\n", + "test_text = \"Contact me at john@example.com or call 555-123-4567\"\n", + "result = pii_filter.evaluate(test_text)\n", + "print(f\"Input: {test_text}\")\n", + "print(f\"Passed: {result.passed}\")\n", + "print(f\"Severity: {result.severity.value}\")\n", + "print(f\"Message: {result.message}\")\n", + "print(f\"Redacted: {result.redacted_text}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demo: KeywordContentFilter (Topic Blocking)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "topic_filter = KeywordContentFilter(\n", + " name=\"topic_blocker\",\n", + " keywords=[\"hack\", \"exploit\", \"bypass security\"],\n", + " 
severity=Severity.BLOCK,\n", + ")\n", + "\n", + "# Blocked text\n", + "blocked_text = \"How do I hack into a system?\"\n", + "result = topic_filter.evaluate(blocked_text)\n", + "print(f\"Input: {blocked_text}\")\n", + "print(f\"Passed: {result.passed}\")\n", + "print(f\"Message: {result.message}\")\n", + "\n", + "# Safe text\n", + "safe_text = \"How do I set up a firewall?\"\n", + "result = topic_filter.evaluate(safe_text)\n", + "print(f\"\\nInput: {safe_text}\")\n", + "print(f\"Passed: {result.passed}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demo: FormatComplianceFilter (Output Validation)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "format_filter = FormatComplianceFilter()\n", + "\n", + "# Unsafe output with code execution instruction\n", + "unsafe_output = \"To fix this, run: sudo rm -rf /tmp/cache\"\n", + "result = format_filter.evaluate(unsafe_output)\n", + "print(f\"Input: {unsafe_output}\")\n", + "print(f\"Passed: {result.passed}\")\n", + "print(f\"Message: {result.message}\")\n", + "\n", + "# Safe output\n", + "safe_output = \"The recommended approach is to clear the cache manually.\"\n", + "result = format_filter.evaluate(safe_output)\n", + "print(f\"\\nInput: {safe_output}\")\n", + "print(f\"Passed: {result.passed}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demo: Pipeline Evaluation with `run_filters()`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "filters = [topic_filter, pii_filter, format_filter]\n", + "\n", + "# Text that violates the keyword filter (first in the list)\n", + "violation_text = \"How to exploit a vulnerability? 
Email me at attacker@evil.com\"\n", + "pipeline_result = run_filters(violation_text, filters)\n", + "print(f\"Input: {violation_text}\")\n", + "print(f\"First violation from: {pipeline_result.filter_name}\")\n", + "print(f\"Message: {pipeline_result.message}\")\n", + "\n", + "# Clean text that passes all filters\n", + "clean_text = \"What are best practices for application security?\"\n", + "pipeline_result = run_filters(clean_text, filters)\n", + "print(f\"\\nInput: {clean_text}\")\n", + "print(f\"Result: {'All filters passed' if pipeline_result is None else 'Violation found'}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this notebook you learned:\n", + "1. The `Severity` enum and `FilterResult` dataclass that drive guardrail behavior\n", + "2. The `ContentFilter` base class pattern for building custom filters\n", + "3. `RegexContentFilter` \u2014 pattern-based detection with optional redaction\n", + "4. `KeywordContentFilter` \u2014 case-insensitive keyword/phrase blocking\n", + "5. `FormatComplianceFilter` \u2014 domain-specific output validation\n", + "6. `run_filters()` \u2014 composing filters into a pipeline\n", + "\n", + "**Next Steps:** See `04_guardrail_plugin.ipynb` to learn how to package guardrails as a reusable Plugin." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/python/01-learn/18-input-output-guardrails/04_guardrail_plugin.ipynb b/python/01-learn/18-input-output-guardrails/04_guardrail_plugin.ipynb new file mode 100644 index 00000000..dab05acc --- /dev/null +++ b/python/01-learn/18-input-output-guardrails/04_guardrail_plugin.ipynb @@ -0,0 +1,579 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Guardrail Plugin\n", + "\n", + "This notebook demonstrates how to build a **full-featured guardrail plugin** using the Strands Agents SDK's `HookProvider` pattern.\n", + "\n", + "The plugin bundles input validation, output validation, and tool call enforcement into a single reusable component that can be attached to any agent.\n", + "\n", + "## Plugin Architecture\n", + "\n", + "
\n", + " \n", + "
\n", + "\n", + "Key concepts:\n", + "- Implement `HookProvider` to create a reusable guardrail component\n", + "- Register callbacks for `BeforeInvocationEvent`, `AfterInvocationEvent`, and `BeforeToolCallEvent`\n", + "- Combine input, output, and tool call guardrails in one provider\n", + "- Configure filters, tool allowlists, and error handling via constructor\n", + "\n", + "**Registration pattern:**\n", + "```python\n", + "agent = Agent(hooks=[GuardrailPlugin(input_filters=[...], output_filters=[...])])\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "!pip install strands-agents strands-agents-tools --upgrade -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Content filter classes \u2014 inline definitions (no external file needed)\n", + "from dataclasses import dataclass\n", + "from enum import Enum\n", + "from typing import Optional\n", + "import re\n", + "\n", + "class Severity(Enum):\n", + " BLOCK = 'block'\n", + " WARN = 'warn'\n", + " REDACT = 'redact'\n", + "\n", + "@dataclass\n", + "class FilterResult:\n", + " passed: bool\n", + " filter_name: str\n", + " severity: Severity\n", + " message: Optional[str] = None\n", + " redacted_text: Optional[str] = None\n", + "\n", + "class ContentFilter:\n", + " def __init__(self, name, severity=Severity.BLOCK):\n", + " self.name = name\n", + " self.severity = severity\n", + " def evaluate(self, text):\n", + " raise NotImplementedError\n", + "\n", + "class RegexContentFilter(ContentFilter):\n", + " def __init__(self, name, patterns, severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " self.patterns = [re.compile(p) for p in patterns]\n", + " def evaluate(self, text):\n", + " for pattern in self.patterns:\n", + " if 
pattern.search(text):\n", + " if self.severity == Severity.REDACT:\n", + " redacted = text\n", + " for p in self.patterns:\n", + " redacted = p.sub('[REDACTED]', redacted)\n", + " return FilterResult(False, self.name, self.severity,\n", + " f'Pattern matched: {pattern.pattern}', redacted)\n", + " return FilterResult(False, self.name, self.severity,\n", + " f'Pattern matched: {pattern.pattern}')\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "class KeywordContentFilter(ContentFilter):\n", + " def __init__(self, name, keywords, severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " self.keywords = [kw.lower() for kw in keywords]\n", + " def evaluate(self, text):\n", + " text_lower = text.lower()\n", + " for keyword in self.keywords:\n", + " if keyword in text_lower:\n", + " return FilterResult(False, self.name, self.severity,\n", + " f\"Prohibited keyword: '{keyword}'\")\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "class FormatComplianceFilter(ContentFilter):\n", + " EXECUTION_PATTERNS = [\n", + " re.compile(r'\\b(run|execute|eval)\\s*\\(', re.IGNORECASE),\n", + " re.compile(r'```\\s*(bash|shell|sh)\\b', re.IGNORECASE),\n", + " re.compile(r'sudo\\s+\\w+', re.IGNORECASE),\n", + " ]\n", + " def __init__(self, name='format_compliance', severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " def evaluate(self, text):\n", + " for pattern in self.EXECUTION_PATTERNS:\n", + " if pattern.search(text):\n", + " return FilterResult(False, self.name, self.severity,\n", + " 'Code execution instruction detected')\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "def run_filters(text, filters):\n", + " for f in filters:\n", + " result = f.evaluate(text)\n", + " if not result.passed:\n", + " return result\n", + " return None\n", + "\n", + "print('Content filter classes loaded.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "import logging\n", + "logger = logging.getLogger(__name__)\n", + "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s [%(levelname)s] %(message)s\", datefmt=\"%H:%M:%S\")\n", + "from typing import Optional\n", + "\n", + "from strands.hooks import (\n", + " HookProvider,\n", + " HookRegistry,\n", + " BeforeInvocationEvent,\n", + " AfterInvocationEvent,\n", + " BeforeToolCallEvent,\n", + ")\n", + "\n", + "# Content filter classes are defined inline in the previous cell (no import needed)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The GuardrailPlugin Class\n", + "\n", + "This plugin bundles three types of validation:\n", + "1. **Input validation** \u2014 inspects user messages before model inference\n", + "2. **Output validation** \u2014 inspects model responses before returning to the user\n", + "3. **Tool call validation** \u2014 enforces a tool allowlist before tool execution\n", + "\n", + "Configuration options:\n", + "- `input_filters`: List of ContentFilter instances for user input\n", + "- `output_filters`: List of ContentFilter instances for model output\n", + "- `tool_allowlist`: List of allowed tool names (None = all tools allowed)\n", + "- `fail_open`: If True, filter exceptions allow the request through" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class GuardrailPlugin(HookProvider):\n", + " \"\"\"A reusable HookProvider that applies content guardrails to agent input, output, and tool calls.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " input_filters: list[ContentFilter] | None = None,\n", + " output_filters: list[ContentFilter] | None = None,\n", + " tool_allowlist: list[str] | None = None,\n", + " fail_open: bool = True,\n", + " ):\n", + " self.input_filters = input_filters or []\n", + " self.output_filters = output_filters or []\n", + " self.tool_allowlist = tool_allowlist\n", + " self.fail_open = fail_open\n", + "\n", + " def 
register_hooks(self, registry: HookRegistry) -> None:\n", + " \"\"\"Register all guardrail callbacks with the hook registry.\"\"\"\n", + " registry.add_callback(BeforeInvocationEvent, self._validate_input)\n", + " registry.add_callback(AfterInvocationEvent, self._validate_output)\n", + " registry.add_callback(BeforeToolCallEvent, self._validate_tool_call)\n", + "\n", + " def _validate_input(self, event: BeforeInvocationEvent) -> None:\n", + " \"\"\"Validate user input before model inference.\"\"\"\n", + " messages = event.agent.messages\n", + " text = self._extract_input_text(messages)\n", + " if not text:\n", + " return\n", + "\n", + " for content_filter in self.input_filters:\n", + " try:\n", + " result = content_filter.evaluate(text)\n", + " if not result.passed:\n", + " self._handle_input_violation(messages, result, text)\n", + " return\n", + " except Exception as e:\n", + " if not self._handle_filter_error(e, content_filter, \"input\"):\n", + " self._block_input(messages, content_filter.name)\n", + " return\n", + "\n", + " self._log_decision(\"input\", \"all_filters\", \"passed\", text)\n", + "\n", + " def _validate_output(self, event: AfterInvocationEvent) -> None:\n", + " \"\"\"Validate model output before returning to user.\"\"\"\n", + " messages = event.agent.messages\n", + " text = self._extract_output_text(messages)\n", + " if not text:\n", + " return\n", + "\n", + " for content_filter in self.output_filters:\n", + " try:\n", + " result = content_filter.evaluate(text)\n", + " if not result.passed:\n", + " self._handle_output_violation(messages, result, text)\n", + " return\n", + " except Exception as e:\n", + " if not self._handle_filter_error(e, content_filter, \"output\"):\n", + " self._block_output(messages, content_filter.name)\n", + " return\n", + "\n", + " self._log_decision(\"output\", \"all_filters\", \"passed\", text)\n", + "\n", + " def _validate_tool_call(self, event: BeforeToolCallEvent) -> None:\n", + " \"\"\"Validate tool calls against the 
configured allowlist.\"\"\"\n", + " if self.tool_allowlist is None:\n", + " return\n", + "\n", + " tool_name = event.tool_use.get(\"name\", \"\")\n", + "\n", + " if tool_name not in self.tool_allowlist:\n", + " reason = f\"Tool '{tool_name}' is not in the allowed tools list.\"\n", + " event.cancel_tool = reason\n", + " self._log_decision(\"tool_call\", tool_name, \"blocked\", tool_name)\n", + " else:\n", + " self._log_decision(\"tool_call\", tool_name, \"passed\", tool_name)\n", + "\n", + " # --- Helper methods ---\n", + "\n", + " def _extract_input_text(self, messages: list[dict]) -> str:\n", + " if not messages:\n", + " return \"\"\n", + " last_message = messages[-1]\n", + " if last_message.get(\"role\") != \"user\":\n", + " return \"\"\n", + " text_parts = []\n", + " for block in last_message.get(\"content\", []):\n", + " if \"text\" in block:\n", + " text_parts.append(block[\"text\"])\n", + " return \" \".join(text_parts)\n", + "\n", + " def _extract_output_text(self, messages: list[dict]) -> str:\n", + " if not messages:\n", + " return \"\"\n", + " for message in reversed(messages):\n", + " if message.get(\"role\") == \"assistant\":\n", + " for block in message.get(\"content\", []):\n", + " if \"text\" in block:\n", + " return block[\"text\"]\n", + " return \"\"\n", + "\n", + " def _handle_input_violation(self, messages, result, original_text):\n", + " if result.severity == Severity.BLOCK:\n", + " self._log_decision(\"input\", result.filter_name, \"blocked\", original_text, result.message)\n", + " messages.clear()\n", + " messages.append({\n", + " \"role\": \"user\",\n", + " \"content\": [{\"text\": (\n", + " \"Respond only with: I cannot process that request. 
\"\n", + " \"The input was blocked by a content safety filter.\"\n", + " )}],\n", + " })\n", + " elif result.severity == Severity.REDACT:\n", + " self._log_decision(\"input\", result.filter_name, \"redacted\", original_text, result.message)\n", + " if messages and result.redacted_text:\n", + " last_message = messages[-1]\n", + " if last_message.get(\"role\") == \"user\":\n", + " last_message[\"content\"] = [{\"text\": result.redacted_text}]\n", + "\n", + " def _handle_output_violation(self, messages, result, original_text):\n", + " if result.severity == Severity.BLOCK:\n", + " self._log_decision(\"output\", result.filter_name, \"blocked\", original_text, result.message)\n", + " self._replace_assistant_response(\n", + " messages,\n", + " \"I'm sorry, but I cannot provide that information. \"\n", + " \"The response was blocked by a content safety filter.\",\n", + " )\n", + " elif result.severity == Severity.REDACT:\n", + " self._log_decision(\"output\", result.filter_name, \"redacted\", original_text, result.message)\n", + " if result.redacted_text:\n", + " self._replace_assistant_response(messages, result.redacted_text)\n", + "\n", + " def _block_input(self, messages, filter_name):\n", + " messages.clear()\n", + " messages.append({\n", + " \"role\": \"user\",\n", + " \"content\": [{\"text\": (\n", + " \"Respond only with: I cannot process that request. \"\n", + " \"An internal error occurred during content validation.\"\n", + " )}],\n", + " })\n", + "\n", + " def _block_output(self, messages, filter_name):\n", + " self._replace_assistant_response(\n", + " messages,\n", + " \"I'm sorry, but I cannot provide a response at this time. 
\"\n", + " \"An internal error occurred during content validation.\",\n", + " )\n", + "\n", + " def _replace_assistant_response(self, messages, new_text):\n", + " if not messages:\n", + " return\n", + " for message in reversed(messages):\n", + " if message.get(\"role\") == \"assistant\":\n", + " message[\"content\"] = [{\"text\": new_text}]\n", + " return\n", + "\n", + " def _handle_filter_error(self, error, content_filter, direction):\n", + " if self.fail_open:\n", + " logger.error(f\"Filter error in {direction} (fail-open): {error}\")\n", + " return True\n", + " else:\n", + " logger.error(f\"Filter error in {direction} (fail-closed): {error}\")\n", + " return False\n", + "\n", + " def _log_decision(self, direction, filter_name, action, content, message=None):\n", + " snippet = content[:50] if content else \"\"\n", + " log_msg = f\"[GUARDRAIL] direction={direction} filter={filter_name} action={action} snippet='{snippet}'\"\n", + " if message:\n", + " log_msg += f\" message='{message}'\"\n", + " if action == \"blocked\":\n", + " logger.warning(log_msg)\n", + " else:\n", + " logger.debug(log_msg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demo: Testing Plugin Methods Directly\n", + "\n", + "We can test the plugin's validation methods using mock events \u2014 no live model needed.\n", + "\n", + "For unit testing, we call the internal `_validate_*` methods directly with mock event objects." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Mock event classes for testing\n", + "class MockAgent:\n", + " def __init__(self, messages):\n", + " self.messages = messages\n", + "\n", + "class MockBeforeEvent:\n", + " def __init__(self, messages):\n", + " self.agent = MockAgent(messages)\n", + "\n", + "class MockAfterEvent:\n", + " def __init__(self, messages):\n", + " self.agent = MockAgent(messages)\n", + "\n", + "class MockToolEvent:\n", + " def __init__(self, tool_name, tool_input=None):\n", + " self.tool_use = {\"name\": tool_name, \"input\": tool_input or {}}\n", + " self.cancel_tool = None\n", + "\n", + "\n", + "# Configure the plugin\n", + "plugin = GuardrailPlugin(\n", + " input_filters=[\n", + " KeywordContentFilter(\"prohibited_topics\", [\"hack\", \"exploit\"], Severity.BLOCK),\n", + " RegexContentFilter(\"input_pii\", [r\"\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b\"], Severity.BLOCK),\n", + " ],\n", + " output_filters=[\n", + " FormatComplianceFilter(\"output_format\", Severity.BLOCK),\n", + " RegexContentFilter(\"output_pii\", [r\"\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}\\b\"], Severity.REDACT),\n", + " ],\n", + " tool_allowlist=[\"calculator\", \"web_search\", \"file_reader\"],\n", + " fail_open=True,\n", + ")\n", + "\n", + "print(f\"Plugin type: {type(plugin).__name__}\")\n", + "print(f\"Input filters: {[f.name for f in plugin.input_filters]}\")\n", + "print(f\"Output filters: {[f.name for f in plugin.output_filters]}\")\n", + "print(f\"Tool allowlist: {plugin.tool_allowlist}\")\n", + "print(f\"Fail-open: {plugin.fail_open}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Test input BLOCK\n", + "print(\"Test: Input BLOCK \u2014 prohibited keyword\")\n", + "messages = [{\"role\": \"user\", \"content\": [{\"text\": \"How do I hack a server?\"}]}]\n", + "event = 
MockBeforeEvent(messages)\n", + "plugin._validate_input(event)\n", + "print(f\" Result: '{event.agent.messages[0]['content'][0]['text'][:60]}...'\")\n", + "assert \"cannot process\" in event.agent.messages[0][\"content\"][0][\"text\"]\n", + "\n", + "# Test input PASS\n", + "print(\"\\nTest: Input PASS \u2014 clean content\")\n", + "messages = [{\"role\": \"user\", \"content\": [{\"text\": \"What is cloud computing?\"}]}]\n", + "event = MockBeforeEvent(messages)\n", + "plugin._validate_input(event)\n", + "print(f\" Result: Message unchanged (passed)\")\n", + "assert event.agent.messages[0][\"content\"][0][\"text\"] == \"What is cloud computing?\"\n", + "\n", + "# Test output REDACT\n", + "print(\"\\nTest: Output REDACT \u2014 PII in response\")\n", + "messages = [{\"role\": \"assistant\", \"content\": [{\"text\": \"Contact us at support@company.com for help.\"}]}]\n", + "event = MockAfterEvent(messages)\n", + "plugin._validate_output(event)\n", + "result_text = event.agent.messages[0][\"content\"][0][\"text\"]\n", + "print(f\" Result: '{result_text}'\")\n", + "assert \"[REDACTED]\" in result_text\n", + "\n", + "# Test tool PASS\n", + "print(\"\\nTest: Tool PASS \u2014 allowed tool\")\n", + "event = MockToolEvent(\"calculator\")\n", + "plugin._validate_tool_call(event)\n", + "print(f\" Result: Allowed (cancel_tool={event.cancel_tool})\")\n", + "assert event.cancel_tool is None\n", + "\n", + "# Test tool BLOCK\n", + "print(\"\\nTest: Tool BLOCK \u2014 disallowed tool\")\n", + "event = MockToolEvent(\"shell_execute\")\n", + "plugin._validate_tool_call(event)\n", + "print(f\" Result: Blocked (cancel_tool='{event.cancel_tool}')\")\n", + "assert event.cancel_tool is not None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demo: Fail-Open vs Fail-Closed Behavior" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class BrokenFilter(ContentFilter):\n", + " \"\"\"A filter that 
always raises an exception.\"\"\"\n", + " def evaluate(self, text):\n", + " raise RuntimeError(\"Simulated filter failure!\")\n", + "\n", + "\n", + "# Fail-open: broken filter doesn't crash\n", + "print(\"Test: Fail-open \u2014 broken filter allows request through\")\n", + "fail_open_plugin = GuardrailPlugin(\n", + " input_filters=[BrokenFilter(\"broken\", Severity.BLOCK)],\n", + " fail_open=True,\n", + ")\n", + "messages = [{\"role\": \"user\", \"content\": [{\"text\": \"Hello world\"}]}]\n", + "event = MockBeforeEvent(messages)\n", + "fail_open_plugin._validate_input(event)\n", + "print(f\" Result: Message unchanged (error caught)\")\n", + "assert event.agent.messages[0][\"content\"][0][\"text\"] == \"Hello world\"\n", + "\n", + "# Fail-closed: broken filter blocks request\n", + "print(\"\\nTest: Fail-closed \u2014 broken filter blocks request\")\n", + "fail_closed_plugin = GuardrailPlugin(\n", + " input_filters=[BrokenFilter(\"broken\", Severity.BLOCK)],\n", + " fail_open=False,\n", + ")\n", + "messages = [{\"role\": \"user\", \"content\": [{\"text\": \"Hello world\"}]}]\n", + "event = MockBeforeEvent(messages)\n", + "fail_closed_plugin._validate_input(event)\n", + "print(f\" Result: '{event.agent.messages[0]['content'][0]['text'][:60]}...'\")\n", + "assert \"cannot process\" in event.agent.messages[0][\"content\"][0][\"text\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Attaching to a Live Agent\n", + "\n", + "Register the plugin using the `hooks` parameter:\n", + "\n", + "```python\n", + "agent = Agent(\n", + " system_prompt=\"You are a helpful assistant.\",\n", + " hooks=[plugin],\n", + ")\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from strands import Agent\n", + " from strands.models.bedrock import BedrockModel\n", + "\n", + " model = BedrockModel(model_id=\"us.anthropic.claude-sonnet-4-5-20250929-v1:0\")\n", + "\n", + " agent = 
Agent(\n", + " model=model,\n", + " system_prompt=\"You are a helpful assistant.\",\n", + " hooks=[plugin],\n", + " )\n", + "\n", + " print(\"Agent created with GuardrailPlugin attached.\")\n", + " print(\"Testing with a safe request...\")\n", + " response = agent(\"What is the capital of France?\")\n", + " print(f\" Response: {response}\")\n", + "\n", + " print(\"\\nTesting with a prohibited request...\")\n", + " response = agent(\"How do I hack into a system?\")\n", + " print(f\" Response: {response}\")\n", + "\n", + "except Exception as e:\n", + " print(f\"Skipping live agent demo: {e}\")\n", + " print(\"(This is expected if no model provider is configured)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this notebook you learned:\n", + "1. How to implement `HookProvider` for reusable guardrail components\n", + "2. Using `register_hooks` to register typed callbacks for multiple event types\n", + "3. Combining input, output, and tool call validation in one provider\n", + "4. Configuring fail-open vs fail-closed error handling\n", + "5. Testing plugin methods with mock events\n", + "\n", + "**Next Steps:** See `05_tool_call_validation.ipynb` to learn about advanced tool call guardrail patterns." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/python/01-learn/18-input-output-guardrails/05_tool_call_validation.ipynb b/python/01-learn/18-input-output-guardrails/05_tool_call_validation.ipynb new file mode 100644 index 00000000..6dc43aae --- /dev/null +++ b/python/01-learn/18-input-output-guardrails/05_tool_call_validation.ipynb @@ -0,0 +1,506 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tool Call Validation\n", + "\n", + "This notebook demonstrates how to implement **tool call guardrails** using the Strands Agents SDK's `BeforeToolCallEvent` hook.\n", + "\n", + "Tool call guardrails inspect tool names and arguments **before execution**, allowing you to:\n", + "- Enforce **allowlists** (only listed tools can run)\n", + "- Enforce **blocklists** (specific tools are blocked)\n", + "- Validate **arguments** for dangerous inputs (sensitive paths, dangerous commands)\n", + "- Log all tool call decisions for audit\n", + "\n", + "**Key concept:** Use `BeforeToolCallEvent` via a `HookProvider` to intercept tool calls. Access `event.tool_use` dict with keys `name`, `toolUseId`, `input`. Set `event.cancel_tool = \"reason\"` to block a tool call." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "!pip install strands-agents strands-agents-tools --upgrade -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "from typing import Optional\n", + "\n", + "from strands.hooks import HookProvider, HookRegistry, BeforeToolCallEvent\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "logging.basicConfig(level=logging.DEBUG, format=\"%(asctime)s [%(levelname)s] %(name)s - %(message)s\", datefmt=\"%H:%M:%S\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pattern 1: Allowlist Enforcement\n", + "\n", + "Only tools whose names appear in the allowlist are permitted to execute. All other tools are blocked with a descriptive reason." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def allowlist_validate(event, allowed_tools: list[str]) -> None:\n", + " \"\"\"Validate a tool call against an allowlist.\"\"\"\n", + " tool_name = event.tool_use.get(\"name\", \"\")\n", + "\n", + " if tool_name not in allowed_tools:\n", + " reason = f\"Tool '{tool_name}' is not permitted. Allowed tools: {allowed_tools}\"\n", + " event.cancel_tool = reason\n", + " logger.warning(f\"[TOOL GUARDRAIL] BLOCKED tool='{tool_name}' reason='not in allowlist'\")\n", + " else:\n", + " logger.debug(f\"[TOOL GUARDRAIL] ALLOWED tool='{tool_name}' reason='in allowlist'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pattern 2: Blocklist Enforcement\n", + "\n", + "Tools in the blocklist are blocked. All other tools are allowed. This is the inverse of allowlist — useful when you want to restrict a few dangerous tools but allow everything else." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def blocklist_validate(event, blocked_tools: list[str]) -> None:\n", + " \"\"\"Validate a tool call against a blocklist.\"\"\"\n", + " tool_name = event.tool_use.get(\"name\", \"\")\n", + "\n", + " if tool_name in blocked_tools:\n", + " reason = f\"Tool '{tool_name}' is explicitly blocked.\"\n", + " event.cancel_tool = reason\n", + " logger.warning(f\"[TOOL GUARDRAIL] BLOCKED tool='{tool_name}' reason='in blocklist'\")\n", + " else:\n", + " logger.debug(f\"[TOOL GUARDRAIL] ALLOWED tool='{tool_name}' reason='not in blocklist'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pattern 3: Argument Validation\n", + "\n", + "Inspects tool arguments to detect dangerous patterns:\n", + "- File operations targeting sensitive paths (`/etc/passwd`, `~/.ssh/`, `.env`)\n", + "- Shell commands containing dangerous patterns (`rm -rf`, `chmod 777`, `curl | sh`)\n", + "\n", + "This provides defense-in-depth beyond just checking tool names." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sensitive paths that should never be accessed\n", + "SENSITIVE_PATHS = [\n", + " \"/etc/passwd\", \"/etc/shadow\", \"/root/\",\n", + " \"~/.ssh/\", \"~/.aws/credentials\", \".env\",\n", + "]\n", + "\n", + "# Dangerous shell patterns\n", + "DANGEROUS_COMMANDS = [\n", + " \"rm -rf\", \"mkfs\", \"dd if=\", \"> /dev/\",\n", + " \"chmod 777\", \"curl | sh\", \"wget | sh\",\n", + "]\n", + "\n", + "\n", + "def argument_validate(\n", + " event,\n", + " sensitive_paths: Optional[list[str]] = None,\n", + " dangerous_commands: Optional[list[str]] = None,\n", + ") -> None:\n", + " \"\"\"Validate tool arguments for dangerous patterns.\"\"\"\n", + " paths = sensitive_paths or SENSITIVE_PATHS\n", + " commands = dangerous_commands or DANGEROUS_COMMANDS\n", + "\n", + " tool_name = event.tool_use.get(\"name\", \"\")\n", + " tool_input = event.tool_use.get(\"input\", {})\n", + "\n", + " for arg_name, arg_value in tool_input.items():\n", + " if not isinstance(arg_value, str):\n", + " continue\n", + "\n", + " for sensitive_path in paths:\n", + " if sensitive_path in arg_value:\n", + " reason = (\n", + " f\"Tool '{tool_name}' argument '{arg_name}' references \"\n", + " f\"sensitive path: '{sensitive_path}'\"\n", + " )\n", + " event.cancel_tool = reason\n", + " logger.warning(f\"[TOOL GUARDRAIL] BLOCKED tool='{tool_name}' reason='sensitive path'\")\n", + " return\n", + "\n", + " for dangerous_cmd in commands:\n", + " if dangerous_cmd in arg_value:\n", + " reason = (\n", + " f\"Tool '{tool_name}' argument '{arg_name}' contains \"\n", + " f\"dangerous command pattern: '{dangerous_cmd}'\"\n", + " )\n", + " event.cancel_tool = reason\n", + " logger.warning(f\"[TOOL GUARDRAIL] BLOCKED tool='{tool_name}' reason='dangerous command'\")\n", + " return\n", + "\n", + " logger.debug(f\"[TOOL GUARDRAIL] ALLOWED tool='{tool_name}' reason='arguments validated'\")" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pattern 4: Combined Guardrail with Audit Logging\n", + "\n", + "Combines all three patterns into a single comprehensive guardrail:\n", + "1. Allowlist check (if configured)\n", + "2. Blocklist check (if configured)\n", + "3. Argument validation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def combined_tool_validate(\n", + " event,\n", + " allowed_tools: Optional[list[str]] = None,\n", + " blocked_tools: Optional[list[str]] = None,\n", + " sensitive_paths: Optional[list[str]] = None,\n", + " dangerous_commands: Optional[list[str]] = None,\n", + ") -> None:\n", + " \"\"\"Comprehensive tool call validation combining multiple strategies.\"\"\"\n", + " paths = sensitive_paths or SENSITIVE_PATHS\n", + " commands = dangerous_commands or DANGEROUS_COMMANDS\n", + "\n", + " tool_name = event.tool_use.get(\"name\", \"\")\n", + " tool_input = event.tool_use.get(\"input\", {})\n", + "\n", + " # Step 1: Allowlist check\n", + " if allowed_tools is not None and tool_name not in allowed_tools:\n", + " reason = f\"Tool '{tool_name}' is not in the allowed tools list.\"\n", + " event.cancel_tool = reason\n", + " _audit_log(tool_name, \"BLOCKED\", \"not in allowlist\", tool_input)\n", + " return\n", + "\n", + " # Step 2: Blocklist check\n", + " if blocked_tools is not None and tool_name in blocked_tools:\n", + " reason = f\"Tool '{tool_name}' is explicitly blocked.\"\n", + " event.cancel_tool = reason\n", + " _audit_log(tool_name, \"BLOCKED\", \"in blocklist\", tool_input)\n", + " return\n", + "\n", + " # Step 3: Argument validation\n", + " for arg_name, arg_value in tool_input.items():\n", + " if not isinstance(arg_value, str):\n", + " continue\n", + " for sensitive_path in paths:\n", + " if sensitive_path in arg_value:\n", + " reason = f\"Tool '{tool_name}' blocked: sensitive path '{sensitive_path}'\"\n", + " event.cancel_tool = reason\n", + " 
_audit_log(tool_name, \"BLOCKED\", f\"sensitive path: {sensitive_path}\", tool_input)\n", + " return\n", + " for dangerous_cmd in commands:\n", + " if dangerous_cmd in arg_value:\n", + " reason = f\"Tool '{tool_name}' blocked: dangerous pattern '{dangerous_cmd}'\"\n", + " event.cancel_tool = reason\n", + " _audit_log(tool_name, \"BLOCKED\", f\"dangerous command: {dangerous_cmd}\", tool_input)\n", + " return\n", + "\n", + " _audit_log(tool_name, \"ALLOWED\", \"all checks passed\", tool_input)\n", + "\n", + "\n", + "def _audit_log(tool_name, decision, reason, tool_input):\n", + " input_summary = {k: str(v)[:50] for k, v in tool_input.items()} if tool_input else {}\n", + " log_msg = f\"[TOOL AUDIT] tool='{tool_name}' decision={decision} reason='{reason}' input={input_summary}\"\n", + " if decision == \"BLOCKED\":\n", + " logger.warning(log_msg)\n", + " else:\n", + " logger.info(log_msg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demo: Testing Each Pattern with Mock Events" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class MockToolEvent:\n", + " \"\"\"Mock BeforeToolCallEvent for demonstration.\"\"\"\n", + " def __init__(self, tool_name: str, tool_input: Optional[dict] = None):\n", + " self.tool_use = {\"name\": tool_name, \"toolUseId\": \"test-123\", \"input\": tool_input or {}}\n", + " self.cancel_tool = None\n", + "\n", + "\n", + "# --- Pattern 1: Allowlist ---\n", + "print(\"--- Pattern 1: Allowlist Enforcement ---\")\n", + "print(\"Only 'calculator' and 'file_reader' are allowed.\\n\")\n", + "\n", + "allowed = [\"calculator\", \"file_reader\"]\n", + "\n", + "event = MockToolEvent(\"calculator\", {\"expression\": \"2 + 2\"})\n", + "allowlist_validate(event, allowed)\n", + "print(f\" Tool: 'calculator' -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")\n", + "\n", + "event = MockToolEvent(\"shell_execute\", {\"command\": \"ls\"})\n", + 
"allowlist_validate(event, allowed)\n", + "print(f\" Tool: 'shell_execute' -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")\n", + "print(f\" Reason: {event.cancel_tool}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# --- Pattern 2: Blocklist ---\n", + "print(\"--- Pattern 2: Blocklist Enforcement ---\")\n", + "print(\"'shell_execute' and 'file_delete' are blocked.\\n\")\n", + "\n", + "blocked = [\"shell_execute\", \"file_delete\"]\n", + "\n", + "event = MockToolEvent(\"calculator\", {\"expression\": \"3 * 7\"})\n", + "blocklist_validate(event, blocked)\n", + "print(f\" Tool: 'calculator' -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")\n", + "\n", + "event = MockToolEvent(\"shell_execute\", {\"command\": \"whoami\"})\n", + "blocklist_validate(event, blocked)\n", + "print(f\" Tool: 'shell_execute' -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")\n", + "print(f\" Reason: {event.cancel_tool}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# --- Pattern 3: Argument Validation ---\n", + "print(\"--- Pattern 3: Argument Validation ---\")\n", + "print(\"Blocks tools that access sensitive paths or run dangerous commands.\\n\")\n", + "\n", + "# Safe file read\n", + "event = MockToolEvent(\"file_reader\", {\"path\": \"/tmp/report.txt\"})\n", + "argument_validate(event)\n", + "print(f\" file_reader('/tmp/report.txt') -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")\n", + "\n", + "# Sensitive path\n", + "event = MockToolEvent(\"file_reader\", {\"path\": \"/etc/passwd\"})\n", + "argument_validate(event)\n", + "print(f\" file_reader('/etc/passwd') -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")\n", + "print(f\" Reason: {event.cancel_tool}\")\n", + "\n", + "# Dangerous command\n", + "event = MockToolEvent(\"shell_execute\", {\"command\": \"rm -rf /tmp/data\"})\n", + 
"argument_validate(event)\n", + "print(f\" shell_execute('rm -rf /tmp/data') -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")\n", + "print(f\" Reason: {event.cancel_tool}\")\n", + "\n", + "# Safe command\n", + "event = MockToolEvent(\"shell_execute\", {\"command\": \"echo hello\"})\n", + "argument_validate(event)\n", + "print(f\" shell_execute('echo hello') -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# --- Pattern 4: Combined Guardrail ---\n", + "print(\"--- Pattern 4: Combined Guardrail ---\")\n", + "print(\"Allowlist + argument validation together.\\n\")\n", + "\n", + "# Allowed tool with safe arguments\n", + "event = MockToolEvent(\"calculator\", {\"expression\": \"100 / 4\"})\n", + "combined_tool_validate(event, allowed_tools=[\"calculator\", \"file_reader\", \"shell_execute\"])\n", + "print(f\" calculator('100 / 4') -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")\n", + "\n", + "# Tool not in allowlist\n", + "event = MockToolEvent(\"web_search\", {\"query\": \"test\"})\n", + "combined_tool_validate(event, allowed_tools=[\"calculator\", \"file_reader\", \"shell_execute\"])\n", + "print(f\" web_search('test') -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")\n", + "print(f\" Reason: {event.cancel_tool}\")\n", + "\n", + "# Allowed tool but dangerous arguments\n", + "event = MockToolEvent(\"shell_execute\", {\"command\": \"rm -rf /\"})\n", + "combined_tool_validate(event, allowed_tools=[\"calculator\", \"file_reader\", \"shell_execute\"])\n", + "print(f\" shell_execute('rm -rf /') -> {'ALLOWED' if event.cancel_tool is None else 'BLOCKED'}\")\n", + "print(f\" Reason: {event.cancel_tool}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## HookProvider: Wrapping Tool Guardrails for Agent Registration\n", + "\n", + "In strands-agents 1.40.0, hooks are registered via 
`HookProvider` classes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class ToolGuardrailHook(HookProvider):\n", + " \"\"\"HookProvider that validates tool calls using the combined guardrail.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " allowed_tools: Optional[list[str]] = None,\n", + " blocked_tools: Optional[list[str]] = None,\n", + " sensitive_paths: Optional[list[str]] = None,\n", + " dangerous_commands: Optional[list[str]] = None,\n", + " ):\n", + " self.allowed_tools = allowed_tools\n", + " self.blocked_tools = blocked_tools\n", + " self.sensitive_paths = sensitive_paths\n", + " self.dangerous_commands = dangerous_commands\n", + "\n", + " def register_hooks(self, registry: HookRegistry) -> None:\n", + " registry.add_callback(BeforeToolCallEvent, self._validate_tool)\n", + "\n", + " def _validate_tool(self, event: BeforeToolCallEvent) -> None:\n", + " combined_tool_validate(\n", + " event,\n", + " allowed_tools=self.allowed_tools,\n", + " blocked_tools=self.blocked_tools,\n", + " sensitive_paths=self.sensitive_paths,\n", + " dangerous_commands=self.dangerous_commands,\n", + " )\n", + "\n", + "\n", + "print(\"ToolGuardrailHook defined successfully.\")\n", + "print(\"Register with: Agent(hooks=[ToolGuardrailHook(allowed_tools=[...])])\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Attaching to a Live Agent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from strands import Agent, tool\n", + " from strands.models.bedrock import BedrockModel\n", + "\n", + " @tool\n", + " def calculator(expression: str) -> str:\n", + " \"\"\"Evaluate a math expression.\"\"\"\n", + " return str(eval(expression))\n", + "\n", + " @tool\n", + " def file_reader(path: str) -> str:\n", + " \"\"\"Read a file from disk.\"\"\"\n", + " return f\"Contents of {path}\"\n", + "\n", + " @tool\n", + " 
def shell_execute(command: str) -> str:\n", + " \"\"\"Execute a shell command.\"\"\"\n", + " return f\"Executed: {command}\"\n", + "\n", + " model = BedrockModel(model_id=\"us.anthropic.claude-sonnet-4-5-20250929-v1:0\")\n", + "\n", + " tool_hook = ToolGuardrailHook(\n", + " allowed_tools=[\"calculator\", \"file_reader\"],\n", + " sensitive_paths=SENSITIVE_PATHS,\n", + " dangerous_commands=DANGEROUS_COMMANDS,\n", + " )\n", + "\n", + " agent = Agent(\n", + " model=model,\n", + " system_prompt=\"You are a helpful assistant with access to tools.\",\n", + " tools=[calculator, file_reader, shell_execute],\n", + " hooks=[tool_hook],\n", + " )\n", + "\n", + " print(\"Agent created with tool call guardrail.\")\n", + " print(\"Testing: What is 25 * 4?\")\n", + " response = agent(\"What is 25 * 4?\")\n", + " print(f\" Response: {response}\")\n", + "\n", + "except Exception as e:\n", + " print(f\"Skipping live agent demo: {e}\")\n", + " print(\"(This is expected if no model provider is configured)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this notebook you learned four tool call validation patterns:\n", + "1. **Allowlist** — only listed tools can run\n", + "2. **Blocklist** — specific tools are blocked, all others allowed\n", + "3. **Argument validation** — inspect arguments for dangerous patterns\n", + "4. **Combined guardrail** — all patterns together with audit logging\n", + "\n", + "And how to wrap them in a `HookProvider` for agent registration.\n", + "\n", + "**Next Steps:** See `06_error_handling.ipynb` to learn about fail-open vs fail-closed error handling patterns." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/python/01-learn/18-input-output-guardrails/06_error_handling.ipynb b/python/01-learn/18-input-output-guardrails/06_error_handling.ipynb new file mode 100644 index 00000000..e12e7a08 --- /dev/null +++ b/python/01-learn/18-input-output-guardrails/06_error_handling.ipynb @@ -0,0 +1,550 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Error Handling Patterns for Guardrails\n", + "\n", + "This notebook demonstrates how to build **resilient guardrails** that handle errors gracefully.\n", + "\n", + "In production systems, content filters can fail due to timeouts, malformed input, or unexpected exceptions. The key design decision is:\n", + "\n", + "- **Fail-open**: Prioritize availability \u2014 if a filter crashes, allow the request through\n", + "- **Fail-closed**: Prioritize safety \u2014 if a filter crashes, block the request\n", + "\n", + "This notebook also covers:\n", + "- Structured audit logging for every guardrail decision\n", + "- Graceful handling of malformed messages\n", + "- Comparison of both modes with the same inputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "!pip install strands-agents strands-agents-tools --upgrade -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Content filter classes \u2014 inline definitions (no external file needed)\n", + "from dataclasses import dataclass\n", + "from enum import Enum\n", + "from typing import Optional\n", + "import re\n", + "\n", + "class 
Severity(Enum):\n", + " BLOCK = 'block'\n", + " WARN = 'warn'\n", + " REDACT = 'redact'\n", + "\n", + "@dataclass\n", + "class FilterResult:\n", + " passed: bool\n", + " filter_name: str\n", + " severity: Severity\n", + " message: Optional[str] = None\n", + " redacted_text: Optional[str] = None\n", + "\n", + "class ContentFilter:\n", + " def __init__(self, name, severity=Severity.BLOCK):\n", + " self.name = name\n", + " self.severity = severity\n", + " def evaluate(self, text):\n", + " raise NotImplementedError\n", + "\n", + "class RegexContentFilter(ContentFilter):\n", + " def __init__(self, name, patterns, severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " self.patterns = [re.compile(p) for p in patterns]\n", + " def evaluate(self, text):\n", + " for pattern in self.patterns:\n", + " if pattern.search(text):\n", + " if self.severity == Severity.REDACT:\n", + " redacted = text\n", + " for p in self.patterns:\n", + " redacted = p.sub('[REDACTED]', redacted)\n", + " return FilterResult(False, self.name, self.severity,\n", + " f'Pattern matched: {pattern.pattern}', redacted)\n", + " return FilterResult(False, self.name, self.severity,\n", + " f'Pattern matched: {pattern.pattern}')\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "class KeywordContentFilter(ContentFilter):\n", + " def __init__(self, name, keywords, severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " self.keywords = [kw.lower() for kw in keywords]\n", + " def evaluate(self, text):\n", + " text_lower = text.lower()\n", + " for keyword in self.keywords:\n", + " if keyword in text_lower:\n", + " return FilterResult(False, self.name, self.severity,\n", + " f\"Prohibited keyword: '{keyword}'\")\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "class FormatComplianceFilter(ContentFilter):\n", + " EXECUTION_PATTERNS = [\n", + " re.compile(r'\\b(run|execute|eval)\\s*\\(', re.IGNORECASE),\n", + " 
re.compile(r'```\\s*(bash|shell|sh)\\b', re.IGNORECASE),\n", + " re.compile(r'sudo\\s+\\w+', re.IGNORECASE),\n", + " ]\n", + " def __init__(self, name='format_compliance', severity=Severity.BLOCK):\n", + " super().__init__(name, severity)\n", + " def evaluate(self, text):\n", + " for pattern in self.EXECUTION_PATTERNS:\n", + " if pattern.search(text):\n", + " return FilterResult(False, self.name, self.severity,\n", + " 'Code execution instruction detected')\n", + " return FilterResult(True, self.name, self.severity)\n", + "\n", + "def run_filters(text, filters):\n", + " for f in filters:\n", + " result = f.evaluate(text)\n", + " if not result.passed:\n", + " return result\n", + " return None\n", + "\n", + "print('Content filter classes loaded.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "logger = logging.getLogger(__name__)\n", + "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s [%(levelname)s] %(message)s\", datefmt=\"%H:%M:%S\")\n", + "import time\n", + "from dataclasses import dataclass\n", + "from datetime import datetime, timezone\n", + "from typing import Optional\n", + "\n", + "# Content filters are defined inline in the cell above (no shared module to import)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Audit Logging Setup\n", + "\n", + "Structured audit logging captures every guardrail decision for compliance and debugging." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass\n", + "class AuditEntry:\n", + " \"\"\"Structured audit log entry for a guardrail decision.\"\"\"\n", + " timestamp: str\n", + " direction: str # \"input\", \"output\", or \"tool_call\"\n", + " filter_name: str\n", + " action: str # \"blocked\", \"redacted\", \"warned\", \"passed\", \"error_allow\", \"error_block\"\n", + " content_snippet: str\n", + " message: Optional[str] = None\n", + "\n", + " def to_log_string(self) -> str:\n", + " snippet = self.content_snippet[:50].replace(\"\\n\", \" \")\n", + " parts = [\n", + " f\"direction={self.direction}\",\n", + " f\"filter={self.filter_name}\",\n", + " f\"action={self.action}\",\n", + " f'snippet=\"{snippet}\"',\n", + " ]\n", + " if self.message:\n", + " parts.append(f'message=\"{self.message}\"')\n", + " return \" \".join(parts)\n", + "\n", + "\n", + "# Configure audit logging\n", + "audit_logger = logging.getLogger(\"guardrails.audit\")\n", + "audit_logger.setLevel(logging.DEBUG)\n", + "if not audit_logger.handlers:\n", + " handler = logging.StreamHandler()\n", + " handler.setFormatter(logging.Formatter(\"[GUARDRAIL] %(message)s\"))\n", + " audit_logger.addHandler(handler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example Filters That Fail\n", + "\n", + "These filters simulate real-world failure scenarios: network errors, timeouts, and bugs." 
class BrokenFilter(ContentFilter):
    """A filter whose ``evaluate`` always raises ``RuntimeError``.

    Used to simulate network errors, resource exhaustion, or bugs inside a
    content filter.
    """

    def __init__(self, name: str = "broken_filter", error_message: str = "Internal filter error"):
        super().__init__(name, Severity.BLOCK)
        self.error_message = error_message

    def evaluate(self, text: str) -> FilterResult:
        """Always raise ``RuntimeError`` carrying ``error_message``."""
        raise RuntimeError(self.error_message)


class SlowFilter(ContentFilter):
    """A filter that sleeps and then raises ``TimeoutError`` to simulate a timeout."""

    def __init__(self, name: str = "slow_filter", delay_seconds: float = 2.0):
        super().__init__(name, Severity.BLOCK)
        self.delay_seconds = delay_seconds

    def evaluate(self, text: str) -> FilterResult:
        """Block for ``delay_seconds`` and then raise ``TimeoutError``."""
        time.sleep(self.delay_seconds)
        raise TimeoutError(f"Filter '{self.name}' timed out after {self.delay_seconds}s")


class ResilientGuardrail:
    """A guardrail wrapper that handles filter errors gracefully.

    Filters run in the order given. A filter that raises is either skipped
    (``fail_open=True``) or causes the request to be blocked
    (``fail_open=False``). Every decision is appended to ``audit_log`` and
    emitted on the ``guardrails.audit`` logger.
    """

    def __init__(self, filters: list[ContentFilter], fail_open: bool = True, direction: str = "input"):
        self.filters = filters
        self.fail_open = fail_open
        self.direction = direction
        self.audit_log: list[AuditEntry] = []
        self._logger = logging.getLogger("guardrails.audit")

    def _create_audit_entry(self, filter_name, action, content, message=None):
        """Build an ``AuditEntry`` stamped with the current UTC time, store it, and return it.

        Only the first 50 characters of ``content`` are kept; falsy content
        is recorded as an empty snippet.
        """
        entry = AuditEntry(
            timestamp=datetime.now(timezone.utc).isoformat(),
            direction=self.direction,
            filter_name=filter_name,
            action=action,
            content_snippet=content[:50] if content else "",
            message=message,
        )
        self.audit_log.append(entry)
        return entry

    def evaluate(self, text: str) -> tuple[bool, Optional[str], Optional[str]]:
        """Evaluate text through all filters with error handling.

        Args:
            text: The message content to check. ``None`` and non-string
                values are rejected; empty or whitespace-only strings are
                allowed without running any filter.

        Returns:
            ``(allowed, response_text, redacted_text)`` where
            ``response_text`` explains a block (``None`` when allowed) and
            ``redacted_text`` is the text after all REDACT filters were
            applied (``None`` when nothing was redacted or the request was
            blocked).
        """
        # Handle malformed/empty input
        if text is None:
            entry = self._create_audit_entry("input_validation", "blocked", "", "Received None")
            self._logger.warning(entry.to_log_string())
            return False, "Invalid input: message content is missing.", None

        if not isinstance(text, str):
            entry = self._create_audit_entry("input_validation", "blocked", str(text)[:50],
                                             f"Expected string, got {type(text).__name__}")
            self._logger.warning(entry.to_log_string())
            return False, f"Invalid input: expected text, got {type(text).__name__}.", None

        if not text.strip():
            entry = self._create_audit_entry("input_validation", "passed", "", "Empty message")
            self._logger.debug(entry.to_log_string())
            return True, None, None

        # Set once any REDACT filter rewrites the text, so the caller can
        # receive the sanitized version.
        redacted_text: Optional[str] = None

        # Run each filter with error handling. Result handling stays inside
        # the ``try`` so a malformed result object is treated like any other
        # filter failure.
        for content_filter in self.filters:
            try:
                result = content_filter.evaluate(text)

                if not result.passed:
                    if result.severity == Severity.BLOCK:
                        entry = self._create_audit_entry(content_filter.name, "blocked", text, result.message)
                        self._logger.warning(entry.to_log_string())
                        return False, f"Request blocked by '{content_filter.name}': {result.message}", None
                    elif result.severity == Severity.REDACT:
                        # ``is not None`` (not truthiness) so a redaction
                        # that yields an empty string is still honored.
                        if result.redacted_text is not None:
                            text = result.redacted_text
                            redacted_text = text
                        # Audit the post-redaction text so the log does not
                        # retain the content that was just removed.
                        entry = self._create_audit_entry(content_filter.name, "redacted", text, result.message)
                        self._logger.info(entry.to_log_string())
                    elif result.severity == Severity.WARN:
                        entry = self._create_audit_entry(content_filter.name, "warned", text, result.message)
                        self._logger.info(entry.to_log_string())
                else:
                    entry = self._create_audit_entry(content_filter.name, "passed", text)
                    self._logger.debug(entry.to_log_string())

            except Exception as e:
                # Deliberately broad: any filter failure is routed through
                # the fail-open / fail-closed policy instead of crashing.
                if self.fail_open:
                    entry = self._create_audit_entry(content_filter.name, "error_allow", text, f"Filter error: {e}")
                    self._logger.error(f"{entry.to_log_string()} | Allowing request (fail-open).")
                else:
                    entry = self._create_audit_entry(content_filter.name, "error_block", text, f"Filter error: {e}")
                    self._logger.error(f"{entry.to_log_string()} | Blocking request (fail-closed).")
                    return False, "Request blocked due to an internal guardrail error.", None

        return True, None, redacted_text
# Set up filters
keyword_filter = KeywordContentFilter(name="topic_blocker", keywords=["hack", "exploit"], severity=Severity.BLOCK)
pii_filter = RegexContentFilter(
    name="pii_redactor",
    # Inside a character class "|" is a literal pipe, not alternation, so the
    # TLD class is written as [A-Za-z] to avoid matching "|" characters.
    patterns=[r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"],
    severity=Severity.REDACT,
)
broken_filter = BrokenFilter(name="unstable_classifier", error_message="Connection to ML service refused")

# Inputs cover: a clean request, a blocked keyword, an e-mail address to
# redact, an empty string, and a normal request that reaches the broken filter.
test_inputs = [
    "What are best practices for application security?",
    "How do I hack into a system?",
    "Contact me at user@example.com for details.",
    "",
    "Normal request that will hit the broken filter",
]

print("=" * 60)
print("MODE: FAIL-OPEN (availability > safety)")
print("Filters: keyword_filter -> broken_filter -> pii_filter")
print("=" * 60)

fail_open_guardrail = ResilientGuardrail(
    filters=[keyword_filter, broken_filter, pii_filter],
    fail_open=True,
    direction="input",
)

for text in test_inputs:
    display_text = text if text else ""
    print(f"\n Input: \"{display_text}\"")
    allowed, message, redacted = fail_open_guardrail.evaluate(text)
    print(f" Allowed: {allowed}")
    if message:
        print(f" Message: {message}")
    if redacted:
        print(f" Redacted: {redacted}")
# --- Demo: fail-closed mode ---
# Same filter chain as the fail-open demo, but a filter error now blocks.
rule = "=" * 60
print(rule)
print("MODE: FAIL-CLOSED (safety > availability)")
print("Filters: keyword_filter -> broken_filter -> pii_filter")
print(rule)

fail_closed_guardrail = ResilientGuardrail(
    [keyword_filter, broken_filter, pii_filter],
    fail_open=False,
    direction="input",
)

for text in test_inputs:
    display_text = text or ""
    print(f'\n Input: "{display_text}"')
    allowed, message, redacted = fail_closed_guardrail.evaluate(text)
    print(f" Allowed: {allowed}")
    if message:
        print(f" Message: {message}")

# --- Demo: malformed message handling ---
# Feed non-string and empty values to show how the guardrail reacts.
rule = "=" * 60
print(rule)
print("MALFORMED MESSAGE HANDLING")
print(rule)

guardrail = ResilientGuardrail([keyword_filter], fail_open=True, direction="input")

# Each pair is (value passed to the guardrail, human-readable label).
malformed_inputs = [
    (None, "None value"),
    (123, "Integer instead of string"),
    (["a", "list"], "List instead of string"),
    ("", "Empty string"),
    (" ", "Whitespace only"),
]

for value, description in malformed_inputs:
    print(f"\n Input ({description}): {value!r}")
    allowed, message, _ = guardrail.evaluate(value)
    print(f" Allowed: {allowed}")
    if message:
        print(f" Message: {message}")
# --- Comparison table: same inputs, both modes ---
# Suppress audit logging for cleaner output. The previous level is remembered
# and restored in ``finally`` so an error in the loop cannot leave the audit
# logger silenced for the rest of the session.
previous_level = audit_logger.level
audit_logger.setLevel(logging.CRITICAL)
try:
    comparison_open = ResilientGuardrail(
        filters=[keyword_filter, broken_filter, pii_filter], fail_open=True, direction="input"
    )
    comparison_closed = ResilientGuardrail(
        filters=[keyword_filter, broken_filter, pii_filter], fail_open=False, direction="input"
    )

    print(f"{'Input':<45} {'Fail-Open':<12} {'Fail-Closed':<12}")
    print(f"{'-'*45} {'-'*12} {'-'*12}")

    for text in test_inputs:
        # Truncate long inputs so the first column keeps its 45-char width.
        display = (text[:42] + "...") if len(text) > 42 else text
        if not display:
            display = ""

        open_allowed, _, _ = comparison_open.evaluate(text)
        closed_allowed, _, _ = comparison_closed.evaluate(text)

        open_status = "ALLOWED" if open_allowed else "BLOCKED"
        closed_status = "ALLOWED" if closed_allowed else "BLOCKED"

        print(f"{display:<45} {open_status:<12} {closed_status:<12}")
finally:
    # Restore logging
    audit_logger.setLevel(previous_level)

# --- Audit log review ---
print("AUDIT LOG SUMMARY (from fail-open guardrail)")
print("=" * 60)

for entry in fail_open_guardrail.audit_log:
    snippet = entry.content_snippet
    # Only mark the snippet with an ellipsis when it was actually cut.
    shown = snippet[:30] + "..." if len(snippet) > 30 else snippet
    print(f" [{entry.action.upper():12s}] filter={entry.filter_name:20s} "
          f'snippet="{shown}"')
**Fail-closed** | Log error, block request | High-security environments, compliance-critical systems |\n", + "\n", + "In this notebook you learned:\n", + "1. The fail-open vs fail-closed design decision\n", + "2. How to build a `ResilientGuardrail` that handles filter errors gracefully\n", + "3. Structured audit logging for compliance\n", + "4. Graceful handling of malformed inputs\n", + "\n", + "**Next Steps:** See `07_testing_guardrails.ipynb` to learn how to test guardrails with example-based and property-based tests." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/python/01-learn/18-input-output-guardrails/README.md b/python/01-learn/18-input-output-guardrails/README.md new file mode 100644 index 00000000..e25eac57 --- /dev/null +++ b/python/01-learn/18-input-output-guardrails/README.md @@ -0,0 +1,442 @@ +# Input/Output Guardrails + +Build custom input validation, output validation, and content filtering for Strands Agents using hooks — entirely in Python. + +## Overview + +This tutorial teaches you how to implement guardrails that inspect and control what goes into and comes out of your agent. You'll build: + +- **Input guardrails** that block harmful or non-compliant user messages before they reach the model +- **Output guardrails** that redact PII or replace unsafe responses before they reach the user +- **Tool call guardrails** that restrict which tools the agent can invoke +- **A reusable HookProvider** that bundles all guardrail logic into a single component + +### Architecture + +
+ +
+ +### How is this different from `05-guardrails`? + +The `05-guardrails` tutorial uses **Amazon Bedrock Guardrails** — a managed service you configure through AWS and attach via model parameters. It's great when you want AWS to handle content filtering for you. + +This tutorial takes a different approach: you build guardrails **in pure Python** using the Strands SDK's `HookProvider` infrastructure. This gives you: + +- Full control over validation logic (regex, keywords, custom classifiers) +- Model-agnostic implementation (works with any provider) +- Testable in isolation without a live model connection +- Composable filters you can mix and match per use case + +## Prerequisites + +- Python 3.10+ +- `strands-agents` SDK 1.40.0+ installed +- A configured model provider (AWS Bedrock, Anthropic, etc.) — optional for testing + +Install dependencies: + +```bash +pip install strands-agents strands-agents-tools --upgrade +``` + +## Tutorial Structure + +| File | Description | +|------|-------------| +| `01_input_guardrail.ipynb` | Input validation using `BeforeInvocationEvent` via `HookProvider` | +| `02_output_guardrail.ipynb` | Output validation with BLOCK and REDACT behaviors | +| `03_content_filters.ipynb` | Reusable content filter classes (regex, keyword, format) | +| `04_guardrail_plugin.ipynb` | Full `GuardrailPlugin` class implementing `HookProvider` | +| `05_tool_call_validation.ipynb` | Tool call restriction using `BeforeToolCallEvent` | +| `06_error_handling.ipynb` | Fail-open vs fail-closed error handling patterns | +| `content_filters.py` | Shared content filter module imported by all notebooks | + +--- + +## Step 1: Input Guardrails + +**File:** `01_input_guardrail.ipynb` + +Input guardrails intercept user messages *before* they reach the model. The Strands SDK fires a `BeforeInvocationEvent` at the start of each agent invocation, giving you access to `event.agent.messages` — the full conversation history you can inspect and modify. 
+ +### How it works + +1. Extract text from the last user message +2. Run it through your content filters +3. If a violation is found, replace the messages with a rejection prompt + +### Key code + +```python +from strands.hooks import HookProvider, HookRegistry, BeforeInvocationEvent + +class InputGuardrailHook(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(BeforeInvocationEvent, self._validate_input) + + def _validate_input(self, event: BeforeInvocationEvent) -> None: + messages = event.agent.messages + if not messages: + return + + last_message = messages[-1] + if last_message.get("role") != "user": + return + + text = _extract_text_from_message(last_message) + result = keyword_filter.evaluate(text) + + if not result.passed: + # Replace messages with a rejection prompt + messages.clear() + messages.append({ + "role": "user", + "content": [{"text": ( + "Respond only with: I cannot process that request. " + "The input was blocked by a content safety filter." + )}], + }) +``` + +### Registering the hook + +```python +from strands import Agent + +agent = Agent( + system_prompt="You are a helpful assistant.", + hooks=[InputGuardrailHook()], +) +``` + +### Composing multiple filters + +The tutorial also demonstrates a `combined_input_guardrail_logic` that chains keyword detection and PII detection in sequence — the first violation stops evaluation: + +```python +filters = [keyword_filter, pii_filter] +violation = run_filters(text, filters) + +if violation is not None: + # Block the request with the violation's message + messages.clear() + messages.append(...) +``` + +--- + +## Step 2: Output Guardrails + +**File:** `02_output_guardrail.ipynb` + +Output guardrails inspect model responses *after* inference completes. The SDK fires an `AfterInvocationEvent` with `event.agent` — access the conversation history via `event.agent.messages` which includes the assistant's reply. 
+ +### Two behaviors: BLOCK vs REDACT + +| Behavior | What happens | Use case | +|----------|-------------|----------| +| **BLOCK** | Replace the entire response with a safe fallback | Prohibited topics, classified info | +| **REDACT** | Replace only matched patterns with `[REDACTED]` | PII in responses (emails, phone numbers) | + +### BLOCK example + +```python +from strands.hooks import HookProvider, HookRegistry, AfterInvocationEvent + +class OutputGuardrailHook(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(AfterInvocationEvent, self._validate_output) + + def _validate_output(self, event: AfterInvocationEvent) -> None: + messages = event.agent.messages + # Find last assistant message and evaluate it + response_text = _get_last_assistant_text(messages) + result = output_keyword_filter.evaluate(response_text) + + if not result.passed: + _replace_assistant_response(messages, "Blocked by content filter.") +``` + +### REDACT example + +```python +def pii_output_guardrail_logic(messages: list) -> None: + """Redact PII from model responses.""" + response_text = _get_last_assistant_text(messages) + result = pii_redaction_filter.evaluate(response_text) + + if not result.passed: + _redact_assistant_response(messages, result.redacted_text) +``` + +Given input `"The user's email is john@example.com and phone is 555-123-4567."`, the output becomes: +`"The user's email is [REDACTED] and phone is [REDACTED]."` + +### Registering the hook + +```python +agent = Agent( + system_prompt="You are a helpful assistant.", + hooks=[OutputGuardrailHook()], +) +``` + +--- + +## Step 3: Custom Content Filters + +**File:** `03_content_filters.ipynb` and `content_filters.py` + +Content filters are the building blocks of guardrails. This module defines a composable filter architecture with a base class, concrete implementations, and a pipeline runner. 
+ +### Architecture + +``` +ContentFilter (base class) +├── RegexContentFilter — pattern matching (PII, structured data) +├── KeywordContentFilter — keyword/phrase detection (topic blocking) +└── FormatComplianceFilter — structural validation (no code execution instructions) +``` + +
+ +
+ +### Severity levels + +```python +class Severity(Enum): + BLOCK = "block" # Reject the entire request/response + WARN = "warn" # Log a warning but allow through + REDACT = "redact" # Remove/replace the matched content +``` + +### Pipeline runner + +`run_filters()` evaluates text against a list of filters in order, returning the first violation: + +```python +def run_filters(text: str, filters: list[ContentFilter]) -> Optional[FilterResult]: + """Return the first failing filter's result, or None if all pass.""" + for content_filter in filters: + result = content_filter.evaluate(text) + if not result.passed: + return result + return None +``` + +--- + +## Step 4: Guardrail Plugin (HookProvider) + +**File:** `04_guardrail_plugin.ipynb` + +For production use, package your guardrails as a **HookProvider**. This bundles input, output, and tool call validation into a single reusable component. + +### The HookProvider pattern + +
+ +
+ +```python +from strands.hooks import HookProvider, HookRegistry, BeforeInvocationEvent, AfterInvocationEvent, BeforeToolCallEvent + +class GuardrailPlugin(HookProvider): + def __init__( + self, + input_filters: list[ContentFilter] | None = None, + output_filters: list[ContentFilter] | None = None, + tool_allowlist: list[str] | None = None, + fail_open: bool = True, + ): + self.input_filters = input_filters or [] + self.output_filters = output_filters or [] + self.tool_allowlist = tool_allowlist + self.fail_open = fail_open + + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(BeforeInvocationEvent, self._validate_input) + registry.add_callback(AfterInvocationEvent, self._validate_output) + registry.add_callback(BeforeToolCallEvent, self._validate_tool_call) + + def _validate_input(self, event: BeforeInvocationEvent) -> None: + """Validate user input before model inference.""" + ... + + def _validate_output(self, event: AfterInvocationEvent) -> None: + """Validate model output before returning to user.""" + ... + + def _validate_tool_call(self, event: BeforeToolCallEvent) -> None: + """Validate tool calls against the configured allowlist.""" + ... 
+``` + +### Key features + +- **`register_hooks`** — registers callbacks for all three event types in one place +- **Configurable filters** — pass different filter lists for input vs output +- **Tool allowlist** — restrict which tools the agent can call (set to `None` to allow all) +- **Fail-open/fail-closed** — control error handling behavior via constructor parameter +- **Audit logging** — all decisions are logged with structured formatting + +### Attaching to an agent + +```python +plugin = GuardrailPlugin( + input_filters=[ + KeywordContentFilter("topics", ["hack", "exploit"], Severity.BLOCK), + RegexContentFilter("pii", [r"\b\d{3}-\d{2}-\d{4}\b"], Severity.BLOCK), + ], + output_filters=[ + RegexContentFilter("pii_redactor", [r"\b\d{3}-\d{2}-\d{4}\b"], Severity.REDACT), + ], + tool_allowlist=["calculator", "web_search"], + fail_open=True, +) + +agent = Agent(hooks=[plugin]) +``` + +--- + +## Step 5: Tool Call Validation + +**File:** `05_tool_call_validation.ipynb` + +Tool call guardrails intercept tool invocations *before* execution using `BeforeToolCallEvent`. This lets you enforce which tools the agent can use and validate their arguments. + +### Event structure + +```python +event.tool_use = { + "name": "shell_execute", + "toolUseId": "abc-123", + "input": {"command": "rm -rf /"} +} +event.cancel_tool = None # Set this to a string to block the tool call +``` + +### Pattern 1: Allowlist + +Only listed tools can execute. Everything else is blocked: + +```python +def allowlist_validate(event, allowed_tools: list[str]) -> None: + tool_name = event.tool_use.get("name", "") + if tool_name not in allowed_tools: + event.cancel_tool = f"Tool '{tool_name}' is not permitted." +``` + +### Pattern 2: Blocklist + +Specific tools are blocked. Everything else is allowed. + +### Pattern 3: Argument validation + +Inspect tool arguments for dangerous patterns (sensitive paths, destructive commands). 
+ +### Registering tool call hooks + +```python +class ToolGuardrailHook(HookProvider): + def __init__(self, allowed_tools=None): + self.allowed_tools = allowed_tools + + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(BeforeToolCallEvent, self._validate_tool) + + def _validate_tool(self, event: BeforeToolCallEvent) -> None: + ... + +agent = Agent( + tools=[calculator, file_reader, shell_execute], + hooks=[ToolGuardrailHook(allowed_tools=["calculator", "file_reader"])], +) +``` + +--- + +## Step 6: Error Handling + +**File:** `06_error_handling.ipynb` + +In production, content filters can fail — network timeouts, malformed input, bugs in custom classifiers. The key design decision: should a filter failure **allow** or **block** the request? + +### Fail-open vs fail-closed + +| Mode | On filter exception | Best for | +|------|-------------------|----------| +| `fail_open=True` | Log error, allow request through | Production systems prioritizing availability | +| `fail_open=False` | Log error, block request | High-security systems prioritizing safety | + +### When to use each + +- **Fail-open**: Customer-facing chatbots, general assistants — a crashed filter shouldn't break the user experience +- **Fail-closed**: Financial compliance, healthcare, legal — safety is non-negotiable + +### Audit logging + +Every guardrail decision is logged with structured formatting for compliance: + +```python +audit_logger.warning( + f"direction=input filter={filter_name} action=blocked " + f'snippet="{content[:50]}" message="{reason}"' +) +``` + +--- + +## Summary + +You've learned how to build a complete guardrail system for Strands Agents: + +1. **Input guardrails** intercept messages via `BeforeInvocationEvent` and block/modify them before inference +2. **Output guardrails** intercept responses via `AfterInvocationEvent` and can BLOCK or REDACT content +3. 
**Content filters** are composable building blocks with severity levels (BLOCK, WARN, REDACT) +4. **The HookProvider pattern** bundles everything into a reusable component +5. **Tool call validation** restricts which tools the agent can invoke via `BeforeToolCallEvent` +6. **Error handling** lets you choose fail-open (availability) or fail-closed (safety) + +### Key takeaways + +- Guardrails are `HookProvider` classes that register callbacks for lifecycle events +- Access messages via `event.agent.messages` in both `BeforeInvocationEvent` and `AfterInvocationEvent` +- The `hooks=` parameter on `Agent` expects a list of `HookProvider` instances +- Always test guardrails in isolation before deploying with a live model +- Choose fail-open vs fail-closed based on your application's risk profile +- Audit logging is essential for compliance and debugging + +### API Quick Reference (strands-agents 1.40.0) + +```python +from strands.hooks import HookProvider, HookRegistry, BeforeInvocationEvent, AfterInvocationEvent, BeforeToolCallEvent + +class MyHook(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(BeforeInvocationEvent, self._before) + registry.add_callback(AfterInvocationEvent, self._after) + registry.add_callback(BeforeToolCallEvent, self._tool) + + def _before(self, event: BeforeInvocationEvent) -> None: + messages = event.agent.messages # Access/modify messages + + def _after(self, event: AfterInvocationEvent) -> None: + messages = event.agent.messages # Access/modify messages (includes assistant reply) + + def _tool(self, event: BeforeToolCallEvent) -> None: + event.tool_use # {"name": ..., "toolUseId": ..., "input": {...}} + event.cancel_tool = "reason" # Set to block the tool call + +agent = Agent(hooks=[MyHook()]) +``` + +### Next steps + +- Explore the [hooks lifecycle tutorial](../16-hooks-lifecycle/) for more lifecycle event patterns +- Check out [05-guardrails](../05-guardrails/) for the managed Bedrock 
Guardrails approach +- Add custom ML-based classifiers (toxicity, sentiment) as `ContentFilter` subclasses +- Integrate with external moderation APIs by wrapping them in the `ContentFilter` interface diff --git a/python/01-learn/18-input-output-guardrails/images/filter_pipeline.png b/python/01-learn/18-input-output-guardrails/images/filter_pipeline.png new file mode 100644 index 00000000..c58c3076 Binary files /dev/null and b/python/01-learn/18-input-output-guardrails/images/filter_pipeline.png differ diff --git a/python/01-learn/18-input-output-guardrails/images/guardrail_architecture.png b/python/01-learn/18-input-output-guardrails/images/guardrail_architecture.png new file mode 100644 index 00000000..29bc2b68 Binary files /dev/null and b/python/01-learn/18-input-output-guardrails/images/guardrail_architecture.png differ diff --git a/python/01-learn/18-input-output-guardrails/images/plugin_architecture.png b/python/01-learn/18-input-output-guardrails/images/plugin_architecture.png new file mode 100644 index 00000000..b114d1a2 Binary files /dev/null and b/python/01-learn/18-input-output-guardrails/images/plugin_architecture.png differ diff --git a/python/01-learn/18-input-output-guardrails/requirements.txt b/python/01-learn/18-input-output-guardrails/requirements.txt new file mode 100644 index 00000000..80a728e4 --- /dev/null +++ b/python/01-learn/18-input-output-guardrails/requirements.txt @@ -0,0 +1,4 @@ +strands-agents +strands-agents-tools +hypothesis +pytest